Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion pyrefly/lib/binding/class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use starlark_map::small_map::SmallMap;

use crate::binding::base_class::BaseClass;
use crate::binding::base_class::BaseClassGeneric;
use crate::binding::base_class::BaseClassGenericKind;
use crate::binding::binding::AnnotationTarget;
use crate::binding::binding::Binding;
use crate::binding::binding::BindingAbstractClassCheck;
Expand Down Expand Up @@ -111,6 +112,7 @@ impl<'a> BindingsBuilder<'a> {
consistent_override_check_idx: self
.idx_for_promise(KeyConsistentOverrideCheck(def_index)),
abstract_class_check_idx: self.idx_for_promise(KeyAbstractClassCheck(def_index)),
is_protocol: false, // Will be set after processing bases
};
// The user - used for first-usage tracking of any expressions we analyze in a class definition -
// is the `Idx<Key>` of the class object bound to the class name.
Expand All @@ -120,7 +122,7 @@ impl<'a> BindingsBuilder<'a> {
}

pub fn class_def(&mut self, mut x: StmtClassDef, parent: &NestingContext) {
let (mut class_object, class_indices) = self.class_object_and_indices(&x.name);
let (mut class_object, mut class_indices) = self.class_object_and_indices(&x.name);
let mut pydantic_config_dict = PydanticConfigDict::default();
let docstring_range = Docstring::range_from_stmts(x.body.as_slice());
let body = mem::take(&mut x.body);
Expand Down Expand Up @@ -197,6 +199,17 @@ impl<'a> BindingsBuilder<'a> {
base_class
});

// Check if this class directly inherits from Protocol
class_indices.is_protocol = bases.iter().any(|base| {
matches!(
base,
BaseClass::Generic(BaseClassGeneric {
kind: BaseClassGenericKind::Protocol,
..
})
)
});

let mut keywords = Vec::new();
if let Some(args) = &mut x.arguments {
args.keywords.iter_mut().for_each(|keyword| {
Expand Down
23 changes: 23 additions & 0 deletions pyrefly/lib/binding/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -526,10 +526,33 @@ impl<'a> BindingsBuilder<'a> {
undecorated_idx: Idx<KeyUndecoratedFunction>,
class_key: Option<Idx<KeyClass>>,
) -> (FunctionStubOrImpl, Option<SelfAssignments>) {
// A method in a Protocol with a stub-like body (docstring-only, docstring + pass,
// docstring + ellipsis) is treated as a stub. This matches Python typing spec
// behavior where Protocol methods don't need implementations.
let is_protocol_stub_body = if self.scopes.is_in_protocol_class() {
match body.as_slice() {
// Docstring only
[stmt] if is_docstring(stmt) => true,
// Docstring + pass
[first, Stmt::Pass(_)] if is_docstring(first) => true,
// Docstring + ellipsis
[first, Stmt::Expr(expr_stmt)]
if is_docstring(first)
&& matches!(expr_stmt.value.as_ref(), Expr::EllipsisLiteral(_)) =>
{
true
}
_ => false,
}
} else {
false
};

let stub_or_impl = if (body.first().is_some_and(is_docstring)
&& decorators.is_abstract_method)
|| is_ellipse(&body)
|| (body.first().is_some_and(is_docstring) && decorators.is_overload)
|| is_protocol_stub_body
{
FunctionStubOrImpl::Stub
} else {
Expand Down
12 changes: 12 additions & 0 deletions pyrefly/lib/binding/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -690,6 +690,8 @@ pub struct ClassIndices {
pub variance_idx: Idx<KeyVariance>,
pub consistent_override_check_idx: Idx<KeyConsistentOverrideCheck>,
pub abstract_class_check_idx: Idx<KeyAbstractClassCheck>,
/// Whether this class directly inherits from Protocol.
pub is_protocol: bool,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -1231,6 +1233,16 @@ impl Scopes {
None
}

/// Check if we're currently inside a Protocol class body.
pub fn is_in_protocol_class(&self) -> bool {
for scope in self.iter_rev() {
if let ScopeKind::Class(class_scope) = &scope.kind {
return class_scope.indices.is_protocol;
}
}
false
}

/// Are we inside an async function or method?
pub fn is_in_async_def(&self) -> bool {
for scope in self.iter_rev() {
Expand Down
30 changes: 30 additions & 0 deletions pyrefly/lib/test/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -847,3 +847,33 @@ def test():
p: TrickyProtocol[int] = t # E:
"#,
);

// Regression test for https://github.com/facebook/pyrefly/issues/1916
// Protocol methods with only a docstring should not emit "missing explicit return" errors
testcase!(
test_protocol_method_with_docstring,
r#"
from typing import Protocol
class SortState:
pass
class View(Protocol):
"""A protocol with methods that have docstrings but no body."""
@property
def sort_state(self) -> SortState:
"""Return the current sorting/grouping settings."""
def get_value(self) -> int:
"""Get a value."""
def method_with_ellipsis(self) -> str:
"""This one uses ellipsis."""
...
def method_with_pass(self) -> str:
"""This one uses pass - should still be allowed in protocol."""
pass
"#,
);
Loading