diff --git a/pyrefly/lib/binding/class.rs b/pyrefly/lib/binding/class.rs index 7daa0b33d..3b557d828 100644 --- a/pyrefly/lib/binding/class.rs +++ b/pyrefly/lib/binding/class.rs @@ -33,6 +33,7 @@ use starlark_map::small_map::SmallMap; use crate::binding::base_class::BaseClass; use crate::binding::base_class::BaseClassGeneric; +use crate::binding::base_class::BaseClassGenericKind; use crate::binding::binding::AnnotationTarget; use crate::binding::binding::Binding; use crate::binding::binding::BindingAbstractClassCheck; @@ -111,6 +112,7 @@ impl<'a> BindingsBuilder<'a> { consistent_override_check_idx: self .idx_for_promise(KeyConsistentOverrideCheck(def_index)), abstract_class_check_idx: self.idx_for_promise(KeyAbstractClassCheck(def_index)), + is_protocol: false, // Will be set after processing bases }; // The user - used for first-usage tracking of any expressions we analyze in a class definition - // is the `Idx` of the class object bound to the class name. @@ -120,7 +122,7 @@ impl<'a> BindingsBuilder<'a> { } pub fn class_def(&mut self, mut x: StmtClassDef, parent: &NestingContext) { - let (mut class_object, class_indices) = self.class_object_and_indices(&x.name); + let (mut class_object, mut class_indices) = self.class_object_and_indices(&x.name); let mut pydantic_config_dict = PydanticConfigDict::default(); let docstring_range = Docstring::range_from_stmts(x.body.as_slice()); let body = mem::take(&mut x.body); @@ -197,6 +199,17 @@ impl<'a> BindingsBuilder<'a> { base_class }); + // Check if this class directly inherits from Protocol + class_indices.is_protocol = bases.iter().any(|base| { + matches!( + base, + BaseClass::Generic(BaseClassGeneric { + kind: BaseClassGenericKind::Protocol, + .. + }) + ) + }); + let mut keywords = Vec::new(); if let Some(args) = &mut x.arguments { args.keywords.iter_mut().for_each(|keyword| { diff --git a/pyrefly/lib/binding/function.rs b/pyrefly/lib/binding/function.rs index 4f7281e15..f1096a330 100644 --- a/pyrefly/lib/binding/function.rs +++ b/pyrefly/lib/binding/function.rs @@ -526,10 +526,33 @@ impl<'a> BindingsBuilder<'a> { undecorated_idx: Idx, class_key: Option>, ) -> (FunctionStubOrImpl, Option) { + // A method in a Protocol with a stub-like body (docstring-only, docstring + pass, + // docstring + ellipsis) is treated as a stub. This matches Python typing spec + // behavior where Protocol methods don't need implementations. + let is_protocol_stub_body = if self.scopes.is_in_protocol_class() { + match body.as_slice() { + // Docstring only + [stmt] if is_docstring(stmt) => true, + // Docstring + pass + [first, Stmt::Pass(_)] if is_docstring(first) => true, + // Docstring + ellipsis + [first, Stmt::Expr(expr_stmt)] + if is_docstring(first) + && matches!(expr_stmt.value.as_ref(), Expr::EllipsisLiteral(_)) => + { + true + } + _ => false, + } + } else { + false + }; + let stub_or_impl = if (body.first().is_some_and(is_docstring) && decorators.is_abstract_method) || is_ellipse(&body) || (body.first().is_some_and(is_docstring) && decorators.is_overload) + || is_protocol_stub_body { FunctionStubOrImpl::Stub } else { diff --git a/pyrefly/lib/binding/scope.rs b/pyrefly/lib/binding/scope.rs index e60cbe4ed..eea665acd 100644 --- a/pyrefly/lib/binding/scope.rs +++ b/pyrefly/lib/binding/scope.rs @@ -690,6 +690,8 @@ pub struct ClassIndices { pub variance_idx: Idx, pub consistent_override_check_idx: Idx, pub abstract_class_check_idx: Idx, + /// Whether this class directly inherits from Protocol. + pub is_protocol: bool, } #[derive(Clone, Debug)] @@ -1231,6 +1233,16 @@ impl Scopes { None } + /// Check if we're currently inside a Protocol class body. + pub fn is_in_protocol_class(&self) -> bool { + for scope in self.iter_rev() { + if let ScopeKind::Class(class_scope) = &scope.kind { + return class_scope.indices.is_protocol; + } + } + false + } + /// Are we inside an async function or method? pub fn is_in_async_def(&self) -> bool { for scope in self.iter_rev() { diff --git a/pyrefly/lib/test/protocol.rs b/pyrefly/lib/test/protocol.rs index d1f43b4d3..29e87779f 100644 --- a/pyrefly/lib/test/protocol.rs +++ b/pyrefly/lib/test/protocol.rs @@ -847,3 +847,33 @@ def test(): p: TrickyProtocol[int] = t # E: "#, ); + +// Regression test for https://github.com/facebook/pyrefly/issues/1916 +// Protocol methods with only a docstring should not emit "missing explicit return" errors +testcase!( + test_protocol_method_with_docstring, + r#" +from typing import Protocol + +class SortState: + pass + +class View(Protocol): + """A protocol with methods that have docstrings but no body.""" + + @property + def sort_state(self) -> SortState: + """Return the current sorting/grouping settings.""" + + def get_value(self) -> int: + """Get a value.""" + + def method_with_ellipsis(self) -> str: + """This one uses ellipsis.""" + ... + + def method_with_pass(self) -> str: + """This one uses pass - should still be allowed in protocol.""" + pass +"#, +);