Skip to content

Commit e6a798b

Browse files
authored
[red-knot] Recurse into the types of protocol members when normalizing a protocol's interface (astral-sh#17808)
## Summary Currently red-knot does not understand `Foo` and `Bar` here as being equivalent: ```py from typing import Protocol class A: ... class B: ... class C: ... class Foo(Protocol): x: A | B | C class Bar(Protocol): x: B | A | C ``` Nor does it understand `A | B | Foo` as being equivalent to `Bar | B | A`. This PR fixes that. ## Test Plan new mdtest assertions added that fail on `main`
1 parent 52b0470 commit e6a798b

File tree

3 files changed

+72
-14
lines changed

3 files changed

+72
-14
lines changed

crates/red_knot_python_semantic/resources/mdtest/protocols.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,6 +816,22 @@ class B: ...
816816
static_assert(is_equivalent_to(A | HasX | B | HasY, B | AlsoHasY | AlsoHasX | A))
817817
```
818818

819+
Protocols are considered equivalent if their members are equivalent, even if those members are
820+
differently ordered unions:
821+
822+
```py
823+
class C: ...
824+
825+
class UnionProto1(Protocol):
826+
x: A | B | C
827+
828+
class UnionProto2(Protocol):
829+
x: C | A | B
830+
831+
static_assert(is_equivalent_to(UnionProto1, UnionProto2))
832+
static_assert(is_equivalent_to(UnionProto1 | A | B, B | UnionProto2 | A))
833+
```
834+
819835
## Intersections of protocols
820836

821837
An intersection of two protocol types `X` and `Y` is equivalent to a protocol type `Z` that inherits

crates/red_knot_python_semantic/src/types/instance.rs

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use super::{ClassType, KnownClass, SubclassOfType, Type};
55
use crate::symbol::{Symbol, SymbolAndQualifiers};
66
use crate::Db;
77

8+
pub(super) use synthesized_protocol::SynthesizedProtocolType;
9+
810
impl<'db> Type<'db> {
911
pub(crate) fn instance(db: &'db dyn Db, class: ClassType<'db>) -> Self {
1012
if class.class_literal(db).0.is_protocol(db) {
@@ -164,7 +166,7 @@ impl<'db> ProtocolInstanceType<'db> {
164166
}
165167
match self.0 {
166168
Protocol::FromClass(_) => Type::ProtocolInstance(Self(Protocol::Synthesized(
167-
SynthesizedProtocolType::new(db, self.0.interface(db)),
169+
SynthesizedProtocolType::new(db, self.0.interface(db).clone()),
168170
))),
169171
Protocol::Synthesized(_) => Type::ProtocolInstance(self),
170172
}
@@ -237,9 +239,7 @@ impl<'db> ProtocolInstanceType<'db> {
237239

238240
/// An enumeration of the two kinds of protocol types: those that originate from a class
239241
/// definition in source code, and those that are synthesized from a set of members.
240-
#[derive(
241-
Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, salsa::Supertype, PartialOrd, Ord,
242-
)]
242+
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)]
243243
pub(super) enum Protocol<'db> {
244244
FromClass(ClassType<'db>),
245245
Synthesized(SynthesizedProtocolType<'db>),
@@ -260,14 +260,38 @@ impl<'db> Protocol<'db> {
260260
}
261261
}
262262

263-
/// A "synthesized" protocol type that is dissociated from a class definition in source code.
264-
///
265-
/// Two synthesized protocol types with the same members will share the same Salsa ID,
266-
/// making them easy to compare for equivalence. A synthesized protocol type is therefore
267-
/// returned by [`ProtocolInstanceType::normalized`] so that two protocols with the same members
268-
/// will be understood as equivalent even in the context of differently ordered unions or intersections.
269-
#[salsa::interned(debug)]
270-
pub(super) struct SynthesizedProtocolType<'db> {
271-
#[return_ref]
272-
pub(super) interface: ProtocolInterface<'db>,
263+
mod synthesized_protocol {
264+
use crate::db::Db;
265+
use crate::types::protocol_class::ProtocolInterface;
266+
267+
/// A "synthesized" protocol type that is dissociated from a class definition in source code.
268+
///
269+
/// Two synthesized protocol types with the same members will share the same Salsa ID,
270+
/// making them easy to compare for equivalence. A synthesized protocol type is therefore
271+
/// returned by [`super::ProtocolInstanceType::normalized`] so that two protocols with the same members
272+
/// will be understood as equivalent even in the context of differently ordered unions or intersections.
273+
///
274+
/// The constructor method of this type maintains the invariant that a synthesized protocol type
275+
/// is always constructed from a *normalized* protocol interface.
276+
#[derive(Copy, Clone, Debug, Eq, PartialEq, Hash, salsa::Update, PartialOrd, Ord)]
277+
pub(in crate::types) struct SynthesizedProtocolType<'db>(SynthesizedProtocolTypeInner<'db>);
278+
279+
impl<'db> SynthesizedProtocolType<'db> {
280+
pub(super) fn new(db: &'db dyn Db, interface: ProtocolInterface<'db>) -> Self {
281+
Self(SynthesizedProtocolTypeInner::new(
282+
db,
283+
interface.normalized(db),
284+
))
285+
}
286+
287+
pub(in crate::types) fn interface(self, db: &'db dyn Db) -> &'db ProtocolInterface<'db> {
288+
self.0.interface(db)
289+
}
290+
}
291+
292+
#[salsa::interned(debug)]
293+
struct SynthesizedProtocolTypeInner<'db> {
294+
#[return_ref]
295+
interface: ProtocolInterface<'db>,
296+
}
273297
}

crates/red_knot_python_semantic/src/types/protocol_class.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,15 @@ impl<'db> ProtocolInterface<'db> {
9696
pub(super) fn contains_todo(&self, db: &'db dyn Db) -> bool {
9797
self.members().any(|member| member.ty.contains_todo(db))
9898
}
99+
100+
pub(super) fn normalized(self, db: &'db dyn Db) -> Self {
101+
Self(
102+
self.0
103+
.into_iter()
104+
.map(|(name, data)| (name, data.normalized(db)))
105+
.collect(),
106+
)
107+
}
99108
}
100109

101110
#[derive(Debug, PartialEq, Eq, Clone, Hash, salsa::Update)]
@@ -104,6 +113,15 @@ struct ProtocolMemberData<'db> {
104113
qualifiers: TypeQualifiers,
105114
}
106115

116+
impl<'db> ProtocolMemberData<'db> {
117+
fn normalized(self, db: &'db dyn Db) -> Self {
118+
Self {
119+
ty: self.ty.normalized(db),
120+
qualifiers: self.qualifiers,
121+
}
122+
}
123+
}
124+
107125
/// A single member of a protocol interface.
108126
#[derive(Debug, PartialEq, Eq)]
109127
pub(super) struct ProtocolMember<'a, 'db> {

0 commit comments

Comments
 (0)