diff --git a/pyrefly/lib/alt/attr.rs b/pyrefly/lib/alt/attr.rs index 53a82ed0ca..c7e6f25f3a 100644 --- a/pyrefly/lib/alt/attr.rs +++ b/pyrefly/lib/alt/attr.rs @@ -50,6 +50,7 @@ use crate::types::read_only::ReadOnlyReason; use crate::types::type_var::Restriction; use crate::types::typed_dict::TypedDict; use crate::types::types::AnyStyle; +use crate::types::types::BoundMethodType; use crate::types::types::Overload; use crate::types::types::SuperObj; use crate::types::types::Type; @@ -438,6 +439,9 @@ enum AttributeBase1 { TypedDict(TypedDict), /// Attribute lookup on a base as part of a subset check against a protocol. ProtocolSubset(Box), + /// Bound methods prefer exposing builtin `types.MethodType` attributes but fall back to the + /// underlying function's attributes when the builtin ones are missing. + BoundMethod(BoundMethodType), Intersect(Vec, Vec), } @@ -1240,6 +1244,38 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.lookup_attr_from_attribute_base1((**protocol_base).clone(), attr_name, acc) } } + AttributeBase1::TypeQuantified(quantified, class) => { + if let Some(attr) = self.get_bounded_quantified_class_attribute( + quantified.clone(), + class, + attr_name, + ) { + acc.found_class_attribute(attr, base); + } else { + acc.not_found(NotFoundOn::ClassObject(class.class_object().dupe(), base)); + } + } + AttributeBase1::BoundMethod(bound_func) => { + let method_type_base = + AttributeBase1::ClassInstance(self.stdlib.method_type().clone()); + let found_len = acc.found.len(); + let not_found_len = acc.not_found.len(); + let error_len = acc.internal_error.len(); + self.lookup_attr_from_attribute_base1(method_type_base, attr_name, acc); + if acc.found.len() == found_len { + acc.not_found.truncate(not_found_len); + acc.internal_error.truncate(error_len); + let mut func_bases = Vec::new(); + self.as_attribute_base1(bound_func.clone().as_type(), &mut func_bases); + for base1 in func_bases { + self.lookup_attr_from_attribute_base1(base1, attr_name, acc); + } + } else { + acc.not_found.truncate(not_found_len); + acc.internal_error.truncate(error_len); + } + } + AttributeBase1::ClassObject(class) => { let attr = match class { ClassBase::Quantified(quantified, class) => self @@ -1758,9 +1794,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.stdlib.function_type().clone() }, )), - Type::BoundMethod(_) => acc.push(AttributeBase1::ClassInstance( - self.stdlib.method_type().clone(), - )), + Type::BoundMethod(bound_method) => { + acc.push(AttributeBase1::BoundMethod(bound_method.func.clone())); + } Type::Ellipsis => { if let Some(cls) = self.stdlib.ellipsis_type() { acc.push(AttributeBase1::ClassInstance(cls.clone())) @@ -2191,6 +2227,24 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { AttributeBase1::ClassObject(class) => { self.completions_class(class.class_object(), expected_attribute_name, res) } + AttributeBase1::TypeQuantified(_, class) => { + self.completions_class(class.class_object(), expected_attribute_name, res) + } + AttributeBase1::BoundMethod(bound_func) => { + let before = res.len(); + self.completions_class_type( + self.stdlib.method_type(), + expected_attribute_name, + res, + ); + if res.len() == before { + let mut func_bases = Vec::new(); + self.as_attribute_base1(bound_func.clone().as_type(), &mut func_bases); + for base1 in func_bases { + self.completions_inner1(&base1, expected_attribute_name, res); + } + } + } AttributeBase1::TypeAny(_) | AttributeBase1::TypeNever => self.completions_class_type( self.stdlib.builtins_type(), expected_attribute_name, diff --git a/pyrefly/lib/test/descriptors.rs b/pyrefly/lib/test/descriptors.rs index d119482496..b3bd4edd9e 100644 --- a/pyrefly/lib/test/descriptors.rs +++ b/pyrefly/lib/test/descriptors.rs @@ -206,6 +206,41 @@ C().d = "42" "#, ); +testcase!( + test_bound_method_preserves_function_attributes_from_descriptor, + r#" +from __future__ import annotations + +from typing import Callable + + +class CachedMethod: + def __init__(self, fn: Callable[[Constraint], int]) -> None: + self._fn = fn + + def __get__(self, obj: Constraint | None, owner: type[Constraint]) -> CachedMethod: + return self + + def __call__(self, obj: Constraint) -> int: + return self._fn(obj) + + def clear_cache(self, obj: Constraint) -> None: ... + + +def cache_on_self(fn: Callable[[Constraint], int]) -> CachedMethod: + return CachedMethod(fn) + + +class Constraint: + @cache_on_self + def pointwise_read_writes(self) -> int: + return 0 + + def clear_cache(self) -> None: + self.pointwise_read_writes.clear_cache(self) + "#, +); + testcase!( test_class_property_descriptor, r#"