@@ -52,9 +52,10 @@ use crate::types::enums::{enum_metadata, is_single_member_enum};
5252use crate :: types:: function:: {
5353 DataclassTransformerParams , FunctionSpans , FunctionType , KnownFunction ,
5454} ;
55+ pub ( crate ) use crate :: types:: generics:: GenericContext ;
5556use crate :: types:: generics:: {
56- GenericContext , InferableTypeVars , PartialSpecialization , Specialization , bind_typevar,
57- typing_self , walk_generic_context,
57+ InferableTypeVars , PartialSpecialization , Specialization , bind_typevar, typing_self ,
58+ walk_generic_context,
5859} ;
5960use crate :: types:: infer:: infer_unpack_types;
6061use crate :: types:: mro:: { Mro , MroError , MroIterator } ;
@@ -7294,6 +7295,7 @@ impl<'db> VarianceInferable<'db> for Type<'db> {
72947295 . collect ( ) ,
72957296 Type :: SubclassOf ( subclass_of_type) => subclass_of_type. variance_of ( db, typevar) ,
72967297 Type :: TypeIs ( type_is_type) => type_is_type. variance_of ( db, typevar) ,
7298+ Type :: KnownInstance ( known_instance) => known_instance. variance_of ( db, typevar) ,
72977299 Type :: Dynamic ( _)
72987300 | Type :: Never
72997301 | Type :: WrapperDescriptor ( _)
@@ -7308,7 +7310,6 @@ impl<'db> VarianceInferable<'db> for Type<'db> {
73087310 | Type :: LiteralString
73097311 | Type :: BytesLiteral ( _)
73107312 | Type :: SpecialForm ( _)
7311- | Type :: KnownInstance ( _)
73127313 | Type :: AlwaysFalsy
73137314 | Type :: AlwaysTruthy
73147315 | Type :: BoundSuper ( _)
@@ -7528,6 +7529,17 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
75287529 }
75297530}
75307531
7532+ impl < ' db > VarianceInferable < ' db > for KnownInstanceType < ' db > {
7533+ fn variance_of ( self , db : & ' db dyn Db , typevar : BoundTypeVarInstance < ' db > ) -> TypeVarVariance {
7534+ match self {
7535+ KnownInstanceType :: TypeAliasType ( type_alias) => {
7536+ type_alias. raw_value_type ( db) . variance_of ( db, typevar)
7537+ }
7538+ _ => TypeVarVariance :: Bivariant ,
7539+ }
7540+ }
7541+ }
7542+
75317543impl < ' db > KnownInstanceType < ' db > {
75327544 fn normalized_impl ( self , db : & ' db dyn Db , visitor : & NormalizedVisitor < ' db > ) -> Self {
75337545 match self {
@@ -10814,12 +10826,7 @@ impl<'db> PEP695TypeAliasType<'db> {
1081410826
1081510827 #[ salsa:: tracked( cycle_fn=value_type_cycle_recover, cycle_initial=value_type_cycle_initial, heap_size=ruff_memory_usage:: heap_size) ]
1081610828 pub ( crate ) fn value_type ( self , db : & ' db dyn Db ) -> Type < ' db > {
10817- let scope = self . rhs_scope ( db) ;
10818- let module = parsed_module ( db, scope. file ( db) ) . load ( db) ;
10819- let type_alias_stmt_node = scope. node ( db) . expect_type_alias ( ) ;
10820- let definition = self . definition ( db) ;
10821- let value_type =
10822- definition_expression_type ( db, definition, & type_alias_stmt_node. node ( & module) . value ) ;
10829+ let value_type = self . raw_value_type ( db) ;
1082310830
1082410831 if let Some ( generic_context) = self . generic_context ( db) {
1082510832 let specialization = self
@@ -10832,6 +10839,15 @@ impl<'db> PEP695TypeAliasType<'db> {
1083210839 }
1083310840 }
1083410841
10842+ pub ( crate ) fn raw_value_type ( self , db : & ' db dyn Db ) -> Type < ' db > {
10843+ let scope = self . rhs_scope ( db) ;
10844+ let module = parsed_module ( db, scope. file ( db) ) . load ( db) ;
10845+ let type_alias_stmt_node = scope. node ( db) . expect_type_alias ( ) ;
10846+ let definition = self . definition ( db) ;
10847+
10848+ definition_expression_type ( db, definition, & type_alias_stmt_node. node ( & module) . value )
10849+ }
10850+
1083510851 pub ( crate ) fn apply_specialization (
1083610852 self ,
1083710853 db : & ' db dyn Db ,
@@ -11011,6 +11027,13 @@ impl<'db> TypeAliasType<'db> {
1101111027 }
1101211028 }
1101311029
11030+ pub ( crate ) fn raw_value_type ( self , db : & ' db dyn Db ) -> Type < ' db > {
11031+ match self {
11032+ TypeAliasType :: PEP695 ( type_alias) => type_alias. raw_value_type ( db) ,
11033+ TypeAliasType :: ManualPEP695 ( type_alias) => type_alias. value ( db) ,
11034+ }
11035+ }
11036+
1101411037 pub ( crate ) fn as_pep_695_type_alias ( self ) -> Option < PEP695TypeAliasType < ' db > > {
1101511038 match self {
1101611039 TypeAliasType :: PEP695 ( type_alias) => Some ( type_alias) ,
@@ -11843,4 +11866,85 @@ pub(crate) mod tests {
1184311866 . build ( ) ;
1184411867 assert_eq ! ( intersection. display( & db) . to_string( ) , "Never" ) ;
1184511868 }
11869+
11870+ #[ test]
11871+ fn type_alias_variance ( ) {
11872+ use crate :: db:: tests:: TestDb ;
11873+ use crate :: place:: global_symbol;
11874+
11875+ fn get_type_alias < ' db > ( db : & ' db TestDb , name : & str ) -> PEP695TypeAliasType < ' db > {
11876+ let module = ruff_db:: files:: system_path_to_file ( db, "/src/a.py" ) . unwrap ( ) ;
11877+ let ty = global_symbol ( db, module, name) . place . expect_type ( ) ;
11878+ let Type :: KnownInstance ( KnownInstanceType :: TypeAliasType ( TypeAliasType :: PEP695 (
11879+ type_alias,
11880+ ) ) ) = ty
11881+ else {
11882+ panic ! ( "Expected `{name}` to be a type alias" ) ;
11883+ } ;
11884+ type_alias
11885+ }
11886+ fn get_bound_typevar < ' db > (
11887+ db : & ' db TestDb ,
11888+ type_alias : PEP695TypeAliasType < ' db > ,
11889+ ) -> BoundTypeVarInstance < ' db > {
11890+ let generic_context = type_alias. generic_context ( db) . unwrap ( ) ;
11891+ generic_context. variables ( db) . next ( ) . unwrap ( )
11892+ }
11893+
11894+ let mut db = setup_db ( ) ;
11895+ db. write_dedented (
11896+ "/src/a.py" ,
11897+ r#"
11898+ class Covariant[T]:
11899+ def get(self) -> T:
11900+ raise ValueError
11901+
11902+ class Contravariant[T]:
11903+ def set(self, value: T):
11904+ pass
11905+
11906+ class Invariant[T]:
11907+ def get(self) -> T:
11908+ raise ValueError
11909+ def set(self, value: T):
11910+ pass
11911+
11912+ class Bivariant[T]:
11913+ pass
11914+
11915+ type CovariantAlias[T] = Covariant[T]
11916+ type ContravariantAlias[T] = Contravariant[T]
11917+ type InvariantAlias[T] = Invariant[T]
11918+ type BivariantAlias[T] = Bivariant[T]
11919+ "# ,
11920+ )
11921+ . unwrap ( ) ;
11922+ let covariant = get_type_alias ( & db, "CovariantAlias" ) ;
11923+ assert_eq ! (
11924+ KnownInstanceType :: TypeAliasType ( TypeAliasType :: PEP695 ( covariant) )
11925+ . variance_of( & db, get_bound_typevar( & db, covariant) ) ,
11926+ TypeVarVariance :: Covariant
11927+ ) ;
11928+
11929+ let contravariant = get_type_alias ( & db, "ContravariantAlias" ) ;
11930+ assert_eq ! (
11931+ KnownInstanceType :: TypeAliasType ( TypeAliasType :: PEP695 ( contravariant) )
11932+ . variance_of( & db, get_bound_typevar( & db, contravariant) ) ,
11933+ TypeVarVariance :: Contravariant
11934+ ) ;
11935+
11936+ let invariant = get_type_alias ( & db, "InvariantAlias" ) ;
11937+ assert_eq ! (
11938+ KnownInstanceType :: TypeAliasType ( TypeAliasType :: PEP695 ( invariant) )
11939+ . variance_of( & db, get_bound_typevar( & db, invariant) ) ,
11940+ TypeVarVariance :: Invariant
11941+ ) ;
11942+
11943+ let bivariant = get_type_alias ( & db, "BivariantAlias" ) ;
11944+ assert_eq ! (
11945+ KnownInstanceType :: TypeAliasType ( TypeAliasType :: PEP695 ( bivariant) )
11946+ . variance_of( & db, get_bound_typevar( & db, bivariant) ) ,
11947+ TypeVarVariance :: Bivariant
11948+ ) ;
11949+ }
1184611950}
0 commit comments