diff --git a/pyrefly/lib/alt/special_calls.rs b/pyrefly/lib/alt/special_calls.rs index 51dd7d83e..f2e77ecba 100644 --- a/pyrefly/lib/alt/special_calls.rs +++ b/pyrefly/lib/alt/special_calls.rs @@ -372,6 +372,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { // fresh vars and solve them during the is_subset_eq check below. let protocol_instance_ty = self.instantiate_fresh_class(cls); if let Some(object_type) = &object_type { + let mut unsafe_overlap_errors = vec![]; for field_name in &protocol_metadata.members { if !self.has_attr(object_type, field_name) { // It's okay if the field is missing, since @@ -395,24 +396,29 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { "runtime_checkable_protocol_unsafe_overlap", ); if !self.is_subset_eq(&field_ty, &protocol_field_ty) { - errors.add( - range, - ErrorInfo::Kind(ErrorKind::InvalidArgument), - vec1![ - format!( - "Runtime checkable protocol `{}` has an unsafe overlap with type `{}`", - cls.name(), - self.for_display(object_type.clone()) - ), - format!( - "Attribute `{}` has incompatible types: expected `{}`, got `{}`", - field_name, - self.for_display(protocol_field_ty), - self.for_display(field_ty), - ), - ]); + unsafe_overlap_errors.push( + format!( + "Attribute `{}` has incompatible types: expected `{}`, got `{}`", + field_name, + self.for_display(protocol_field_ty), + self.for_display(field_ty), + ), + ); } } + if !unsafe_overlap_errors.is_empty() { + let mut full_msg = vec1![format!( + "Runtime checkable protocol `{}` has an unsafe overlap with type `{}`", + cls.name(), + self.for_display(object_type.clone()) + )]; + full_msg.extend(unsafe_overlap_errors); + errors.add( + range, + ErrorInfo::Kind(ErrorKind::InvalidArgument), + full_msg, + ); + } } } } diff --git a/pyrefly/lib/test/protocol.rs b/pyrefly/lib/test/protocol.rs index 9bd191b19..aa942ba5c 100644 --- a/pyrefly/lib/test/protocol.rs +++ b/pyrefly/lib/test/protocol.rs @@ -668,23 +668,22 @@ issubclass(No, UnsafeProtocol) # E: Runtime checkable protocol `UnsafeProtocol` ); testcase!( - bug = "@runtime_checkable doesn't propagate through inheritance", test_runtime_checkable_unsafe_overlap_with_inheritance, r#" from typing import Protocol, runtime_checkable @runtime_checkable class UnsafeProtocol(Protocol): def foo(self) -> int: ... -@runtime_checkable # E: @runtime_checkable can only be applied to Protocol classes -class ChildUnsafeProtocol(UnsafeProtocol): +@runtime_checkable +class ChildUnsafeProtocol(UnsafeProtocol, Protocol): def bar(self) -> str: ... class No: def foo(self) -> str: return "not an int" def bar(self) -> int: return 42 -isinstance(No(), ChildUnsafeProtocol) -issubclass(No, ChildUnsafeProtocol) +isinstance(No(), ChildUnsafeProtocol) # E: Runtime checkable protocol `ChildUnsafeProtocol` has an unsafe overlap with type `No` +issubclass(No, ChildUnsafeProtocol) # E: Runtime checkable protocol `ChildUnsafeProtocol` has an unsafe overlap with type `No` "#, );