Skip to content

Commit e4384fc

Browse files
authored
[ty] impl VarianceInferable for KnownInstanceType (#20924)
## Summary Derived from #20900 Implement `VarianceInferable` for `KnownInstanceType` (especially for `KnownInstanceType::TypeAliasType`). The variance of a type alias matches its value type. In normal usage, type aliases are expanded to value types, so the variance of a type alias can be obtained without implementing this. However, for example, if we want to display the variance when hovering over a type alias, we need to be able to obtain the variance of the type alias itself (cf. #20900). ## Test Plan I couldn't come up with a way to test this in mdtest, so I'm testing it in a test submodule at the end of `types.rs`. I also added a test to `mdtest/generics/pep695/variance.md`, but it passes without the changes in this PR.
1 parent 6e7ff07 commit e4384fc

File tree

2 files changed

+183
-9
lines changed

2 files changed

+183
-9
lines changed

crates/ty_python_semantic/resources/mdtest/generics/pep695/variance.md

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,65 @@ static_assert(not is_assignable_to(C[B], C[A]))
790790
static_assert(not is_assignable_to(C[A], C[B]))
791791
```
792792

793+
## Type aliases
794+
795+
The variance of the type alias matches the variance of the value type (RHS type).
796+
797+
```py
798+
from ty_extensions import static_assert, is_subtype_of
799+
from typing import Literal
800+
801+
class Covariant[T]:
802+
def get(self) -> T:
803+
raise ValueError
804+
805+
type CovariantLiteral1 = Covariant[Literal[1]]
806+
type CovariantInt = Covariant[int]
807+
type MyCovariant[T] = Covariant[T]
808+
809+
static_assert(is_subtype_of(CovariantLiteral1, CovariantInt))
810+
static_assert(is_subtype_of(MyCovariant[Literal[1]], MyCovariant[int]))
811+
812+
class Contravariant[T]:
813+
def set(self, value: T):
814+
pass
815+
816+
type ContravariantLiteral1 = Contravariant[Literal[1]]
817+
type ContravariantInt = Contravariant[int]
818+
type MyContravariant[T] = Contravariant[T]
819+
820+
static_assert(is_subtype_of(ContravariantInt, ContravariantLiteral1))
821+
static_assert(is_subtype_of(MyContravariant[int], MyContravariant[Literal[1]]))
822+
823+
class Invariant[T]:
824+
def get(self) -> T:
825+
raise ValueError
826+
827+
def set(self, value: T):
828+
pass
829+
830+
type InvariantLiteral1 = Invariant[Literal[1]]
831+
type InvariantInt = Invariant[int]
832+
type MyInvariant[T] = Invariant[T]
833+
834+
static_assert(not is_subtype_of(InvariantInt, InvariantLiteral1))
835+
static_assert(not is_subtype_of(InvariantLiteral1, InvariantInt))
836+
static_assert(not is_subtype_of(MyInvariant[Literal[1]], MyInvariant[int]))
837+
static_assert(not is_subtype_of(MyInvariant[int], MyInvariant[Literal[1]]))
838+
839+
class Bivariant[T]:
840+
pass
841+
842+
type BivariantLiteral1 = Bivariant[Literal[1]]
843+
type BivariantInt = Bivariant[int]
844+
type MyBivariant[T] = Bivariant[T]
845+
846+
static_assert(is_subtype_of(BivariantInt, BivariantLiteral1))
847+
static_assert(is_subtype_of(BivariantLiteral1, BivariantInt))
848+
static_assert(is_subtype_of(MyBivariant[Literal[1]], MyBivariant[int]))
849+
static_assert(is_subtype_of(MyBivariant[int], MyBivariant[Literal[1]]))
850+
```
851+
793852
## Inheriting from generic classes with inferred variance
794853

795854
When inheriting from a generic class with our type variable substituted in, we count its occurrences

crates/ty_python_semantic/src/types.rs

Lines changed: 124 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ use crate::types::function::{
5555
DataclassTransformerFlags, DataclassTransformerParams, FunctionSpans, FunctionType,
5656
KnownFunction,
5757
};
58+
pub(crate) use crate::types::generics::GenericContext;
5859
use crate::types::generics::{
59-
GenericContext, InferableTypeVars, PartialSpecialization, Specialization, bind_typevar,
60-
typing_self, walk_generic_context,
60+
InferableTypeVars, PartialSpecialization, Specialization, bind_typevar, typing_self,
61+
walk_generic_context,
6162
};
6263
use crate::types::infer::infer_unpack_types;
6364
use crate::types::mro::{Mro, MroError, MroIterator};
@@ -7274,6 +7275,7 @@ impl<'db> VarianceInferable<'db> for Type<'db> {
72747275
.collect(),
72757276
Type::SubclassOf(subclass_of_type) => subclass_of_type.variance_of(db, typevar),
72767277
Type::TypeIs(type_is_type) => type_is_type.variance_of(db, typevar),
7278+
Type::KnownInstance(known_instance) => known_instance.variance_of(db, typevar),
72777279
Type::Dynamic(_)
72787280
| Type::Never
72797281
| Type::WrapperDescriptor(_)
@@ -7288,7 +7290,6 @@ impl<'db> VarianceInferable<'db> for Type<'db> {
72887290
| Type::LiteralString
72897291
| Type::BytesLiteral(_)
72907292
| Type::SpecialForm(_)
7291-
| Type::KnownInstance(_)
72927293
| Type::AlwaysFalsy
72937294
| Type::AlwaysTruthy
72947295
| Type::BoundSuper(_)
@@ -7495,6 +7496,17 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
74957496
}
74967497
}
74977498

7499+
impl<'db> VarianceInferable<'db> for KnownInstanceType<'db> {
7500+
fn variance_of(self, db: &'db dyn Db, typevar: BoundTypeVarInstance<'db>) -> TypeVarVariance {
7501+
match self {
7502+
KnownInstanceType::TypeAliasType(type_alias) => {
7503+
type_alias.raw_value_type(db).variance_of(db, typevar)
7504+
}
7505+
_ => TypeVarVariance::Bivariant,
7506+
}
7507+
}
7508+
}
7509+
74987510
impl<'db> KnownInstanceType<'db> {
74997511
fn normalized_impl(self, db: &'db dyn Db, visitor: &NormalizedVisitor<'db>) -> Self {
75007512
match self {
@@ -10693,14 +10705,10 @@ impl<'db> PEP695TypeAliasType<'db> {
1069310705
semantic_index(db, scope.file(db)).expect_single_definition(type_alias_stmt_node)
1069410706
}
1069510707

10708+
/// The RHS type of a PEP-695 style type alias with specialization applied.
1069610709
#[salsa::tracked(cycle_fn=value_type_cycle_recover, cycle_initial=value_type_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
1069710710
pub(crate) fn value_type(self, db: &'db dyn Db) -> Type<'db> {
10698-
let scope = self.rhs_scope(db);
10699-
let module = parsed_module(db, scope.file(db)).load(db);
10700-
let type_alias_stmt_node = scope.node(db).expect_type_alias();
10701-
let definition = self.definition(db);
10702-
let value_type =
10703-
definition_expression_type(db, definition, &type_alias_stmt_node.node(&module).value);
10711+
let value_type = self.raw_value_type(db);
1070410712

1070510713
if let Some(generic_context) = self.generic_context(db) {
1070610714
let specialization = self
@@ -10713,6 +10721,25 @@ impl<'db> PEP695TypeAliasType<'db> {
1071310721
}
1071410722
}
1071510723

10724+
/// The RHS type of a PEP-695 style type alias with *no* specialization applied.
10725+
///
10726+
/// ## Warning
10727+
///
10728+
/// This uses the semantic index to find the definition of the type alias. This means that if the
10729+
/// calling query is not in the same file as this type alias is defined in, then this will create
10730+
/// a cross-module dependency directly on the full AST which will lead to cache
10731+
/// over-invalidation.
10732+
/// This method also calls the type inference functions, and since type aliases can have recursive structures,
10733+
/// we should be careful not to create infinite recursions in this method (or make it tracked if necessary).
10734+
pub(crate) fn raw_value_type(self, db: &'db dyn Db) -> Type<'db> {
10735+
let scope = self.rhs_scope(db);
10736+
let module = parsed_module(db, scope.file(db)).load(db);
10737+
let type_alias_stmt_node = scope.node(db).expect_type_alias();
10738+
let definition = self.definition(db);
10739+
10740+
definition_expression_type(db, definition, &type_alias_stmt_node.node(&module).value)
10741+
}
10742+
1071610743
pub(crate) fn apply_specialization(
1071710744
self,
1071810745
db: &'db dyn Db,
@@ -10892,6 +10919,13 @@ impl<'db> TypeAliasType<'db> {
1089210919
}
1089310920
}
1089410921

10922+
pub(crate) fn raw_value_type(self, db: &'db dyn Db) -> Type<'db> {
10923+
match self {
10924+
TypeAliasType::PEP695(type_alias) => type_alias.raw_value_type(db),
10925+
TypeAliasType::ManualPEP695(type_alias) => type_alias.value(db),
10926+
}
10927+
}
10928+
1089510929
pub(crate) fn as_pep_695_type_alias(self) -> Option<PEP695TypeAliasType<'db>> {
1089610930
match self {
1089710931
TypeAliasType::PEP695(type_alias) => Some(type_alias),
@@ -11724,4 +11758,85 @@ pub(crate) mod tests {
1172411758
.build();
1172511759
assert_eq!(intersection.display(&db).to_string(), "Never");
1172611760
}
11761+
11762+
#[test]
11763+
fn type_alias_variance() {
11764+
use crate::db::tests::TestDb;
11765+
use crate::place::global_symbol;
11766+
11767+
fn get_type_alias<'db>(db: &'db TestDb, name: &str) -> PEP695TypeAliasType<'db> {
11768+
let module = ruff_db::files::system_path_to_file(db, "/src/a.py").unwrap();
11769+
let ty = global_symbol(db, module, name).place.expect_type();
11770+
let Type::KnownInstance(KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(
11771+
type_alias,
11772+
))) = ty
11773+
else {
11774+
panic!("Expected `{name}` to be a type alias");
11775+
};
11776+
type_alias
11777+
}
11778+
fn get_bound_typevar<'db>(
11779+
db: &'db TestDb,
11780+
type_alias: PEP695TypeAliasType<'db>,
11781+
) -> BoundTypeVarInstance<'db> {
11782+
let generic_context = type_alias.generic_context(db).unwrap();
11783+
generic_context.variables(db).next().unwrap()
11784+
}
11785+
11786+
let mut db = setup_db();
11787+
db.write_dedented(
11788+
"/src/a.py",
11789+
r#"
11790+
class Covariant[T]:
11791+
def get(self) -> T:
11792+
raise ValueError
11793+
11794+
class Contravariant[T]:
11795+
def set(self, value: T):
11796+
pass
11797+
11798+
class Invariant[T]:
11799+
def get(self) -> T:
11800+
raise ValueError
11801+
def set(self, value: T):
11802+
pass
11803+
11804+
class Bivariant[T]:
11805+
pass
11806+
11807+
type CovariantAlias[T] = Covariant[T]
11808+
type ContravariantAlias[T] = Contravariant[T]
11809+
type InvariantAlias[T] = Invariant[T]
11810+
type BivariantAlias[T] = Bivariant[T]
11811+
"#,
11812+
)
11813+
.unwrap();
11814+
let covariant = get_type_alias(&db, "CovariantAlias");
11815+
assert_eq!(
11816+
KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(covariant))
11817+
.variance_of(&db, get_bound_typevar(&db, covariant)),
11818+
TypeVarVariance::Covariant
11819+
);
11820+
11821+
let contravariant = get_type_alias(&db, "ContravariantAlias");
11822+
assert_eq!(
11823+
KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(contravariant))
11824+
.variance_of(&db, get_bound_typevar(&db, contravariant)),
11825+
TypeVarVariance::Contravariant
11826+
);
11827+
11828+
let invariant = get_type_alias(&db, "InvariantAlias");
11829+
assert_eq!(
11830+
KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(invariant))
11831+
.variance_of(&db, get_bound_typevar(&db, invariant)),
11832+
TypeVarVariance::Invariant
11833+
);
11834+
11835+
let bivariant = get_type_alias(&db, "BivariantAlias");
11836+
assert_eq!(
11837+
KnownInstanceType::TypeAliasType(TypeAliasType::PEP695(bivariant))
11838+
.variance_of(&db, get_bound_typevar(&db, bivariant)),
11839+
TypeVarVariance::Bivariant
11840+
);
11841+
}
1172711842
}

0 commit comments

Comments
 (0)