Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 57 additions & 3 deletions pyrefly/lib/alt/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ use crate::types::read_only::ReadOnlyReason;
use crate::types::type_var::Restriction;
use crate::types::typed_dict::TypedDict;
use crate::types::types::AnyStyle;
use crate::types::types::BoundMethodType;
use crate::types::types::Overload;
use crate::types::types::SuperObj;
use crate::types::types::Type;
Expand Down Expand Up @@ -438,6 +439,9 @@ enum AttributeBase1 {
TypedDict(TypedDict),
/// Attribute lookup on a base as part of a subset check against a protocol.
ProtocolSubset(Box<AttributeBase1>),
/// Bound methods prefer exposing builtin `types.MethodType` attributes but fall back to the
/// underlying function's attributes when the builtin ones are missing.
BoundMethod(BoundMethodType),
Intersect(Vec<AttributeBase1>, Vec<AttributeBase1>),
}

Expand Down Expand Up @@ -1240,6 +1244,38 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
self.lookup_attr_from_attribute_base1((**protocol_base).clone(), attr_name, acc)
}
}
AttributeBase1::TypeQuantified(quantified, class) => {
if let Some(attr) = self.get_bounded_quantified_class_attribute(
quantified.clone(),
class,
attr_name,
) {
acc.found_class_attribute(attr, base);
} else {
acc.not_found(NotFoundOn::ClassObject(class.class_object().dupe(), base));
}
}
AttributeBase1::BoundMethod(bound_func) => {
let method_type_base =
AttributeBase1::ClassInstance(self.stdlib.method_type().clone());
let found_len = acc.found.len();
let not_found_len = acc.not_found.len();
let error_len = acc.internal_error.len();
self.lookup_attr_from_attribute_base1(method_type_base, attr_name, acc);
if acc.found.len() == found_len {
acc.not_found.truncate(not_found_len);
acc.internal_error.truncate(error_len);
let mut func_bases = Vec::new();
self.as_attribute_base1(bound_func.clone().as_type(), &mut func_bases);
for base1 in func_bases {
self.lookup_attr_from_attribute_base1(base1, attr_name, acc);
}
} else {
acc.not_found.truncate(not_found_len);
acc.internal_error.truncate(error_len);
}
}

AttributeBase1::ClassObject(class) => {
let attr = match class {
ClassBase::Quantified(quantified, class) => self
Expand Down Expand Up @@ -1758,9 +1794,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
self.stdlib.function_type().clone()
},
)),
Type::BoundMethod(_) => acc.push(AttributeBase1::ClassInstance(
self.stdlib.method_type().clone(),
)),
Type::BoundMethod(bound_method) => {
acc.push(AttributeBase1::BoundMethod(bound_method.func.clone()));
}
Type::Ellipsis => {
if let Some(cls) = self.stdlib.ellipsis_type() {
acc.push(AttributeBase1::ClassInstance(cls.clone()))
Expand Down Expand Up @@ -2191,6 +2227,24 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
AttributeBase1::ClassObject(class) => {
self.completions_class(class.class_object(), expected_attribute_name, res)
}
AttributeBase1::TypeQuantified(_, class) => {
self.completions_class(class.class_object(), expected_attribute_name, res)
}
AttributeBase1::BoundMethod(bound_func) => {
let before = res.len();
self.completions_class_type(
self.stdlib.method_type(),
expected_attribute_name,
res,
);
if res.len() == before {
let mut func_bases = Vec::new();
self.as_attribute_base1(bound_func.clone().as_type(), &mut func_bases);
for base1 in func_bases {
self.completions_inner1(&base1, expected_attribute_name, res);
}
}
}
AttributeBase1::TypeAny(_) | AttributeBase1::TypeNever => self.completions_class_type(
self.stdlib.builtins_type(),
expected_attribute_name,
Expand Down
35 changes: 35 additions & 0 deletions pyrefly/lib/test/descriptors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,41 @@ C().d = "42"
"#,
);

testcase!(
test_bound_method_preserves_function_attributes_from_descriptor,
r#"
from __future__ import annotations

from typing import Callable


class CachedMethod:
def __init__(self, fn: Callable[[Constraint], int]) -> None:
self._fn = fn

def __get__(self, obj: Constraint | None, owner: type[Constraint]) -> CachedMethod:
return self

def __call__(self, obj: Constraint) -> int:
return self._fn(obj)

def clear_cache(self, obj: Constraint) -> None: ...


def cache_on_self(fn: Callable[[Constraint], int]) -> CachedMethod:
return CachedMethod(fn)


class Constraint:
@cache_on_self
def pointwise_read_writes(self) -> int:
return 0

def clear_cache(self) -> None:
self.pointwise_read_writes.clear_cache(self)
"#,
);

testcase!(
test_class_property_descriptor,
r#"
Expand Down
Loading