Skip to content

Commit 7ab3bb7

Browse files
committed
[ty] impl VarianceInferable for KnownInstanceType
1 parent e64d772 commit 7ab3bb7

File tree

2 files changed

+172
-9
lines changed

2 files changed

+172
-9
lines changed

crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,65 @@ static_assert(not is_assignable_to(C[B], C[A]))
790790
static_assert(not is_assignable_to(C[A], C[B]))
791791
```
792792

793+
## Type aliases
794+
795+
The variance of the type alias matches the variance of the value type (RHS type).
796+
797+
```py
798+
from ty_extensions import static_assert, is_subtype_of
799+
from typing import Literal
800+
801+
class Covariant[T]:
802+
def get(self) -> T:
803+
raise ValueError
804+
805+
type CovariantLiteral1 = Covariant[Literal[1]]
806+
type CovariantInt = Covariant[int]
807+
type MyCovariant[T] = Covariant[T]
808+
809+
static_assert(is_subtype_of(CovariantLiteral1, CovariantInt))
810+
static_assert(is_subtype_of(MyCovariant[Literal[1]], MyCovariant[int]))
811+
812+
class Contravariant[T]:
813+
def set(self, value: T):
814+
pass
815+
816+
type ContravariantLiteral1 = Contravariant[Literal[1]]
817+
type ContravariantInt = Contravariant[int]
818+
type MyContravariant[T] = Contravariant[T]
819+
820+
static_assert(is_subtype_of(ContravariantInt, ContravariantLiteral1))
821+
static_assert(is_subtype_of(MyContravariant[int], MyContravariant[Literal[1]]))
822+
823+
class Invariant[T]:
824+
def get(self) -> T:
825+
raise ValueError
826+
827+
def set(self, value: T):
828+
pass
829+
830+
type InvariantLiteral1 = Invariant[Literal[1]]
831+
type InvariantInt = Invariant[int]
832+
type MyInvariant[T] = Invariant[T]
833+
834+
static_assert(not is_subtype_of(InvariantInt, InvariantLiteral1))
835+
static_assert(not is_subtype_of(InvariantLiteral1, InvariantInt))
836+
static_assert(not is_subtype_of(MyInvariant[Literal[1]], MyInvariant[int]))
837+
static_assert(not is_subtype_of(MyInvariant[int], MyInvariant[Literal[1]]))
838+
839+
class Bivariant[T]:
840+
pass
841+
842+
type BivariantLiteral1 = Bivariant[Literal[1]]
843+
type BivariantInt = Bivariant[int]
844+
type MyBivariant[T] = Bivariant[T]
845+
846+
static_assert(is_subtype_of(BivariantInt, BivariantLiteral1))
847+
static_assert(is_subtype_of(BivariantLiteral1, BivariantInt))
848+
static_assert(is_subtype_of(MyBivariant[Literal[1]], MyBivariant[int]))
849+
static_assert(is_subtype_of(MyBivariant[int], MyBivariant[Literal[1]]))
850+
```
851+
793852
## Inheriting from generic classes with inferred variance
794853

795854
When inheriting from a generic class with our type variable substituted in, we count its occurrences

crates/ty_python_semantic/src/types.rs

Lines changed: 113 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@ use crate::types::enums::{enum_metadata, is_single_member_enum};
5252
use crate::types::function::{
5353
DataclassTransformerParams, FunctionSpans, FunctionType, KnownFunction,
5454
};
55+
pub(crate) use crate::types::generics::GenericContext;
5556
use crate::types::generics::{
56-
GenericContext, InferableTypeVars, PartialSpecialization, Specialization, bind_typevar,
57-
typing_self, walk_generic_context,
57+
InferableTypeVars, PartialSpecialization, Specialization, bind_typevar, typing_self,
58+
walk_generic_context,
5859
};
5960
use crate::types::infer::infer_unpack_types;
6061
use crate::types::mro::{Mro, MroError, MroIterator};
@@ -7294,6 +7295,7 @@ impl<'db> VarianceInferable<'db> for Type<'db> {
72947295
.collect(),
72957296
Type::SubclassOf(subclass_of_type) => subclass_of_type.variance_of(db, typevar),
72967297
Type::TypeIs(type_is_type) => type_is_type.variance_of(db, typevar),
7298+
Type::KnownInstance(known_instance) => known_instance.variance_of(db, typevar),
72977299
Type::Dynamic(_)
72987300
| Type::Never
72997301
| Type::WrapperDescriptor(_)
@@ -7308,7 +7310,6 @@ impl<'db> VarianceInferable<'db> for Type<'db> {
73087310
| Type::LiteralString
73097311
| Type::BytesLiteral(_)
73107312
| Type::SpecialForm(_)
7311-
| Type::KnownInstance(_)
73127313
| Type::AlwaysFalsy
73137314
| Type::AlwaysTruthy
73147315
| Type::BoundSuper(_)
@@ -7528,6 +7529,17 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
75287529
}
75297530
}
75307531

7532+
impl<'db> VarianceInferable<'db> for KnownInstanceType<'db> {
7533+
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
7534+
match self {
7535+
KnownInstanceType::TypeAliasType(type_alias) => {
7536+
type_alias.raw_value_type(db).variance_of(db, typevar)
7537+
}
7538+
_ => TypeVarVariance::Bivariant,
7539+
}
7540+
}
7541+
}
7542+
75317543
impl<'db> KnownInstanceType<'db> {
75327544
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
75337545
match self {
@@ -10814,12 +10826,7 @@ impl<'db> PEP695TypeAliasType<'db> {
1081410826

1081510827
#[salsa::tracked(cycle_fn=value_type_cycle_recover, cycle_initial=value_type_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
1081610828
pub(crate) fn value_type(self, db: &'db dyn Db) -> Type<'db> {
10817-
let scope = self.rhs_scope(db);
10818-
let module = parsed_module(db, scope.file(db)).load(db);
10819-
let type_alias_stmt_node = scope.node(db).expect_type_alias();
10820-
let definition = self.definition(db);
10821-
let value_type =
10822-
definition_expression_type(db, definition, &type_alias_stmt_node.node(&module).value);
10829+
let value_type = self.raw_value_type(db);
1082310830

1082410831
if let Some(generic_context) = self.generic_context(db) {
1082510832
let specialization = self
@@ -10832,6 +10839,15 @@ impl<'db> PEP695TypeAliasType<'db> {
1083210839
}
1083310840
}
1083410841

10842+
pub(crate) fn raw_value_type(self, db: &'db dyn Db) -> Type<'db> {
10843+
let scope = self.rhs_scope(db);
10844+
let module = parsed_module(db, scope.file(db)).load(db);
10845+
let type_alias_stmt_node = scope.node(db).expect_type_alias();
10846+
let definition = self.definition(db);
10847+
10848+
definition_expression_type(db, definition, &type_alias_stmt_node.node(&module).value)
10849+
}
10850+
1083510851
pub(crate) fn apply_specialization(
1083610852
self,
1083710853
db: &'db dyn Db,
@@ -11011,6 +11027,13 @@ impl<'db> TypeAliasType<'db> {
1101111027
}
1101211028
}
1101311029

11030+
pub(crate) fn raw_value_type(self, db: &'db dyn Db) -> Type<'db> {
11031+
match self {
11032+
TypeAliasType::PEP695(type_alias) => type_alias.raw_value_type(db),
11033+
TypeAliasType::ManualPEP695(type_alias) => type_alias.value(db),
11034+
}
11035+
}
11036+
1101411037
pub(crate) fn as_pep_695_type_alias(self) -> Option<PEP695TypeAliasType<'db>> {
1101511038
match self {
1101611039
TypeAliasType::PEP695(type_alias) => Some(type_alias),
@@ -11843,4 +11866,85 @@ pub(crate) mod tests {
1184311866
.build();
1184411867
assert_eq!(intersection.display(&db).to_string(), "Never");
1184511868
}
11869+
11870+
#[test]
11871+
fn type_alias_variance() {
11872+
use crate::db::tests::TestDb;
11873+
use crate::place::global_symbol;
11874+
11875+
fn get_type_alias<'db>(db: &'db TestDb, name: &str) -> PEP695TypeAliasType<'db> {
11876+
let module = ruff_db::files::system_path_to_file(db, "/src/a.py").unwrap();
11877+
let ty = global_symbol(db, module, name).place.expect_type();
11878+
let Type::KnownInstance(KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(
11879+
type_alias,
11880+
))) = ty
11881+
else {
11882+
panic!("Expected `{name}` to be a type alias");
11883+
};
11884+
type_alias
11885+
}
11886+
fn get_bound_typevar<'db>(
11887+
db: &'db TestDb,
11888+
type_alias: PEP695TypeAliasType<'db>,
11889+
) -> BoundTypeVarInstance<'db> {
11890+
let generic_context = type_alias.generic_context(db).unwrap();
11891+
generic_context.variables(db).next().unwrap()
11892+
}
11893+
11894+
let mut db = setup_db();
11895+
db.write_dedented(
11896+
"/src/a.py",
11897+
r#"
11898+
class Covariant[T]:
11899+
def get(self) -> T:
11900+
raise ValueError
11901+
11902+
class Contravariant[T]:
11903+
def set(self, value: T):
11904+
pass
11905+
11906+
class Invariant[T]:
11907+
def get(self) -> T:
11908+
raise ValueError
11909+
def set(self, value: T):
11910+
pass
11911+
11912+
class Bivariant[T]:
11913+
pass
11914+
11915+
type CovariantAlias[T] = Covariant[T]
11916+
type ContravariantAlias[T] = Contravariant[T]
11917+
type InvariantAlias[T] = Invariant[T]
11918+
type BivariantAlias[T] = Bivariant[T]
11919+
"#,
11920+
)
11921+
.unwrap();
11922+
let covariant = get_type_alias(&db, "CovariantAlias");
11923+
assert_eq!(
11924+
KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(covariant))
11925+
.variance_of(&db, get_bound_typevar(&db, covariant)),
11926+
TypeVarVariance::Covariant
11927+
);
11928+
11929+
let contravariant = get_type_alias(&db, "ContravariantAlias");
11930+
assert_eq!(
11931+
KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(contravariant))
11932+
.variance_of(&db, get_bound_typevar(&db, contravariant)),
11933+
TypeVarVariance::Contravariant
11934+
);
11935+
11936+
let invariant = get_type_alias(&db, "InvariantAlias");
11937+
assert_eq!(
11938+
KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(invariant))
11939+
.variance_of(&db, get_bound_typevar(&db, invariant)),
11940+
TypeVarVariance::Invariant
11941+
);
11942+
11943+
let bivariant = get_type_alias(&db, "BivariantAlias");
11944+
assert_eq!(
11945+
KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(bivariant))
11946+
.variance_of(&db, get_bound_typevar(&db, bivariant)),
11947+
TypeVarVariance::Bivariant
11948+
);
11949+
}
1184611950
}

0 commit comments

Comments
 (0)