Skip to content

Commit bdccb37

Browse files
authored
[ty] Apply function specialization to all overloads (astral-sh#18020)
Function literals have an optional specialization, which is applied to the parameter/return type annotations lazily when the function's signature is requested. We were previously only applying this specialization to the final overload of an overloaded function. This manifested most visibly for `list.__add__`, which has an overloaded definition in the typeshed: https://github.com/astral-sh/ruff/blob/b398b8363104347fe80f1d5241718f90fb637f84/crates/ty_vendored/vendor/typeshed/stdlib/builtins.pyi#L1069-L1072 Closes astral-sh/ty#314
1 parent 3ccc0ed commit bdccb37

File tree

5 files changed

+110
-86
lines changed

5 files changed

+110
-86
lines changed

crates/ty_python_semantic/resources/mdtest/generics/legacy/classes.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,28 @@ reveal_type(d.method3()) # revealed: SomeProtocol[int]
454454
reveal_type(d.method3().x) # revealed: int
455455
```
456456

457+
When a method is overloaded, the specialization is applied to all overloads.
458+
459+
```py
460+
from typing import overload, Generic, TypeVar
461+
462+
S = TypeVar("S")
463+
464+
class WithOverloadedMethod(Generic[T]):
465+
@overload
466+
def method(self, x: T) -> T:
467+
return x
468+
469+
@overload
470+
def method(self, x: S) -> S | T:
471+
return x
472+
473+
def method(self, x: S | T) -> S | T:
474+
return x
475+
476+
reveal_type(WithOverloadedMethod[int].method) # revealed: Overload[(self, x: int) -> int, (self, x: S) -> S | int]
477+
```
478+
457479
## Cyclic class definitions
458480

459481
### F-bounded quantification

crates/ty_python_semantic/resources/mdtest/generics/pep695/classes.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,26 @@ reveal_type(c.method2()) # revealed: str
347347
reveal_type(c.method3()) # revealed: LinkedList[int]
348348
```
349349

350+
When a method is overloaded, the specialization is applied to all overloads.
351+
352+
```py
353+
from typing import overload
354+
355+
class WithOverloadedMethod[T]:
356+
@overload
357+
def method(self, x: T) -> T:
358+
return x
359+
360+
@overload
361+
def method[S](self, x: S) -> S | T:
362+
return x
363+
364+
def method[S](self, x: S | T) -> S | T:
365+
return x
366+
367+
reveal_type(WithOverloadedMethod[int].method) # revealed: Overload[(self, x: int) -> int, (self, x: S) -> S | int]
368+
```
369+
350370
## Cyclic class definitions
351371

352372
### F-bounded quantification

crates/ty_python_semantic/src/types.rs

Lines changed: 43 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3418,16 +3418,10 @@ impl<'db> Type<'db> {
34183418

34193419
Type::BoundMethod(bound_method) => {
34203420
let signature = bound_method.function(db).signature(db);
3421-
Signatures::single(match signature {
3422-
FunctionSignature::Single(signature) => {
3423-
CallableSignature::single(self, signature.clone())
3424-
.with_bound_type(bound_method.self_instance(db))
3425-
}
3426-
FunctionSignature::Overloaded(signatures, _) => {
3427-
CallableSignature::from_overloads(self, signatures.iter().cloned())
3428-
.with_bound_type(bound_method.self_instance(db))
3429-
}
3430-
})
3421+
Signatures::single(
3422+
CallableSignature::from_overloads(self, signature.overloads.iter().cloned())
3423+
.with_bound_type(bound_method.self_instance(db)),
3424+
)
34313425
}
34323426

34333427
Type::MethodWrapper(
@@ -3785,14 +3779,7 @@ impl<'db> Type<'db> {
37853779
Signatures::single(signature)
37863780
}
37873781

3788-
_ => Signatures::single(match function_type.signature(db) {
3789-
FunctionSignature::Single(signature) => {
3790-
CallableSignature::single(self, signature.clone())
3791-
}
3792-
FunctionSignature::Overloaded(signatures, _) => {
3793-
CallableSignature::from_overloads(self, signatures.iter().cloned())
3794-
}
3795-
}),
3782+
_ => Signatures::single(function_type.signature(db).overloads.clone()),
37963783
},
37973784

37983785
Type::ClassLiteral(class) => match class.known(db) {
@@ -6561,46 +6548,21 @@ bitflags! {
65616548
}
65626549
}
65636550

6564-
/// A function signature, which can be either a single signature or an overloaded signature.
6551+
/// A function signature, which optionally includes an implementation signature if the function is
6552+
/// overloaded.
65656553
#[derive(Clone, Debug, PartialEq, Eq, Hash, salsa::Update)]
6566-
pub(crate) enum FunctionSignature<'db> {
6567-
/// A single function signature.
6568-
Single(Signature<'db>),
6569-
6570-
/// An overloaded function signature containing the `@overload`-ed signatures and an optional
6571-
/// implementation signature.
6572-
Overloaded(Vec<Signature<'db>>, Option<Signature<'db>>),
6554+
pub(crate) struct FunctionSignature<'db> {
6555+
pub(crate) overloads: CallableSignature<'db>,
6556+
pub(crate) implementation: Option<Signature<'db>>,
65736557
}
65746558

65756559
impl<'db> FunctionSignature<'db> {
6576-
/// Returns a slice of all signatures.
6577-
///
6578-
/// For an overloaded function, this only includes the `@overload`-ed signatures and not the
6579-
/// implementation signature.
6580-
pub(crate) fn as_slice(&self) -> &[Signature<'db>] {
6581-
match self {
6582-
Self::Single(signature) => std::slice::from_ref(signature),
6583-
Self::Overloaded(signatures, _) => signatures,
6584-
}
6585-
}
6586-
6587-
/// Returns an iterator over the signatures.
6588-
pub(crate) fn iter(&self) -> Iter<Signature<'db>> {
6589-
self.as_slice().iter()
6590-
}
6591-
65926560
/// Returns the "bottom" signature (subtype of all fully-static signatures.)
65936561
pub(crate) fn bottom(db: &'db dyn Db) -> Self {
6594-
Self::Single(Signature::bottom(db))
6595-
}
6596-
}
6597-
6598-
impl<'db> IntoIterator for &'db FunctionSignature<'db> {
6599-
type Item = &'db Signature<'db>;
6600-
type IntoIter = Iter<'db, Signature<'db>>;
6601-
6602-
fn into_iter(self) -> Self::IntoIter {
6603-
self.iter()
6562+
FunctionSignature {
6563+
overloads: CallableSignature::single(Type::any(), Signature::bottom(db)),
6564+
implementation: None,
6565+
}
66046566
}
66056567
}
66066568

@@ -6671,7 +6633,7 @@ impl<'db> FunctionType<'db> {
66716633
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> Type<'db> {
66726634
Type::Callable(CallableType::from_overloads(
66736635
db,
6674-
self.signature(db).iter().cloned(),
6636+
self.signature(db).overloads.iter().cloned(),
66756637
))
66766638
}
66776639

@@ -6739,20 +6701,32 @@ impl<'db> FunctionType<'db> {
67396701
/// would depend on the function's AST and rerun for every change in that file.
67406702
#[salsa::tracked(returns(ref), cycle_fn=signature_cycle_recover, cycle_initial=signature_cycle_initial)]
67416703
pub(crate) fn signature(self, db: &'db dyn Db) -> FunctionSignature<'db> {
6704+
let specialization = self.specialization(db);
67426705
if let Some(overloaded) = self.to_overloaded(db) {
6743-
FunctionSignature::Overloaded(
6744-
overloaded
6745-
.overloads
6746-
.iter()
6747-
.copied()
6748-
.map(|overload| overload.internal_signature(db))
6749-
.collect(),
6750-
overloaded
6751-
.implementation
6752-
.map(|implementation| implementation.internal_signature(db)),
6753-
)
6706+
FunctionSignature {
6707+
overloads: CallableSignature::from_overloads(
6708+
Type::FunctionLiteral(self),
6709+
overloaded.overloads.iter().copied().map(|overload| {
6710+
overload
6711+
.internal_signature(db)
6712+
.apply_optional_specialization(db, specialization)
6713+
}),
6714+
),
6715+
implementation: overloaded.implementation.map(|implementation| {
6716+
implementation
6717+
.internal_signature(db)
6718+
.apply_optional_specialization(db, specialization)
6719+
}),
6720+
}
67546721
} else {
6755-
FunctionSignature::Single(self.internal_signature(db))
6722+
FunctionSignature {
6723+
overloads: CallableSignature::single(
6724+
Type::FunctionLiteral(self),
6725+
self.internal_signature(db)
6726+
.apply_optional_specialization(db, specialization),
6727+
),
6728+
implementation: None,
6729+
}
67566730
}
67576731
}
67586732

@@ -6774,17 +6748,13 @@ impl<'db> FunctionType<'db> {
67746748
let index = semantic_index(db, scope.file(db));
67756749
GenericContext::from_type_params(db, index, type_params)
67766750
});
6777-
let mut signature = Signature::from_function(
6751+
Signature::from_function(
67786752
db,
67796753
generic_context,
67806754
self.inherited_generic_context(db),
67816755
definition,
67826756
function_stmt_node,
6783-
);
6784-
if let Some(specialization) = self.specialization(db) {
6785-
signature = signature.apply_specialization(db, specialization);
6786-
}
6787-
signature
6757+
)
67886758
}
67896759

67906760
pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {
@@ -6854,7 +6824,7 @@ impl<'db> FunctionType<'db> {
68546824
typevars: &mut FxOrderSet<TypeVarInstance<'db>>,
68556825
) {
68566826
let signatures = self.signature(db);
6857-
for signature in signatures {
6827+
for signature in &signatures.overloads {
68586828
signature.find_legacy_typevars(db, typevars);
68596829
}
68606830
}
@@ -7114,6 +7084,7 @@ impl<'db> BoundMethodType<'db> {
71147084
db,
71157085
self.function(db)
71167086
.signature(db)
7087+
.overloads
71177088
.iter()
71187089
.map(signatures::Signature::bind_self),
71197090
))

crates/ty_python_semantic/src/types/display.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ use crate::types::class::{ClassLiteral, ClassType, GenericAlias};
1010
use crate::types::generics::{GenericContext, Specialization};
1111
use crate::types::signatures::{Parameter, Parameters, Signature};
1212
use crate::types::{
13-
CallableType, FunctionSignature, IntersectionType, KnownClass, MethodWrapperKind, Protocol,
14-
StringLiteralType, SubclassOfInner, Type, TypeVarBoundOrConstraints, TypeVarInstance,
15-
UnionType, WrapperDescriptorKind,
13+
CallableType, IntersectionType, KnownClass, MethodWrapperKind, Protocol, StringLiteralType,
14+
SubclassOfInner, Type, TypeVarBoundOrConstraints, TypeVarInstance, UnionType,
15+
WrapperDescriptorKind,
1616
};
1717
use crate::{Db, FxOrderSet};
1818

@@ -118,8 +118,8 @@ impl Display for DisplayRepresentation<'_> {
118118
// the generic type parameters to the signature, i.e.
119119
// show `def foo[T](x: T) -> T`.
120120

121-
match signature {
122-
FunctionSignature::Single(signature) => {
121+
match signature.overloads.as_slice() {
122+
[signature] => {
123123
write!(
124124
f,
125125
// "def {name}{specialization}{signature}",
@@ -128,7 +128,7 @@ impl Display for DisplayRepresentation<'_> {
128128
signature = signature.display(self.db)
129129
)
130130
}
131-
FunctionSignature::Overloaded(signatures, _) => {
131+
signatures => {
132132
// TODO: How to display overloads?
133133
f.write_str("Overload[")?;
134134
let mut join = f.join(", ");
@@ -146,8 +146,8 @@ impl Display for DisplayRepresentation<'_> {
146146
// TODO: use the specialization from the method. Similar to the comment above
147147
// about the function specialization,
148148

149-
match function.signature(self.db) {
150-
FunctionSignature::Single(signature) => {
149+
match function.signature(self.db).overloads.as_slice() {
150+
[signature] => {
151151
write!(
152152
f,
153153
"bound method {instance}.{method}{signature}",
@@ -156,7 +156,7 @@ impl Display for DisplayRepresentation<'_> {
156156
signature = signature.bind_self().display(self.db)
157157
)
158158
}
159-
FunctionSignature::Overloaded(signatures, _) => {
159+
signatures => {
160160
// TODO: How to display overloads?
161161
f.write_str("Overload[")?;
162162
let mut join = f.join(", ");

crates/ty_python_semantic/src/types/signatures.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ impl<'db> CallableSignature<'db> {
195195
self.overloads.iter()
196196
}
197197

198+
pub(crate) fn as_slice(&self) -> &[Signature<'db>] {
199+
self.overloads.as_slice()
200+
}
201+
198202
fn replace_callable_type(&mut self, before: Type<'db>, after: Type<'db>) {
199203
if self.callable_type == before {
200204
self.callable_type = after;
@@ -309,12 +313,16 @@ impl<'db> Signature<'db> {
309313
}
310314
}
311315

312-
pub(crate) fn apply_specialization(
313-
&self,
316+
pub(crate) fn apply_optional_specialization(
317+
self,
314318
db: &'db dyn Db,
315-
specialization: Specialization<'db>,
319+
specialization: Option<Specialization<'db>>,
316320
) -> Self {
317-
self.apply_type_mapping(db, specialization.type_mapping())
321+
if let Some(specialization) = specialization {
322+
self.apply_type_mapping(db, specialization.type_mapping())
323+
} else {
324+
self
325+
}
318326
}
319327

320328
pub(crate) fn apply_type_mapping<'a>(
@@ -1743,7 +1751,10 @@ mod tests {
17431751
// With no decorators, internal and external signature are the same
17441752
assert_eq!(
17451753
func.signature(&db),
1746-
&FunctionSignature::Single(expected_sig)
1754+
&FunctionSignature {
1755+
overloads: CallableSignature::single(Type::FunctionLiteral(func), expected_sig),
1756+
implementation: None
1757+
},
17471758
);
17481759
}
17491760
}

0 commit comments

Comments
 (0)