@@ -26,7 +26,9 @@ use crate::types::enums::is_enum_class;
2626use crate :: types:: function:: {
2727 DataclassTransformerParams , FunctionDecorators , FunctionType , KnownFunction , OverloadLiteral ,
2828} ;
29- use crate :: types:: generics:: { Specialization , SpecializationBuilder , SpecializationError } ;
29+ use crate :: types:: generics:: {
30+ InferableTypeVars , Specialization , SpecializationBuilder , SpecializationError ,
31+ } ;
3032use crate :: types:: signatures:: { Parameter , ParameterForm , ParameterKind , Parameters } ;
3133use crate :: types:: tuple:: { TupleLength , TupleType } ;
3234use crate :: types:: {
@@ -597,7 +599,8 @@ impl<'db> Bindings<'db> {
597599 Type :: FunctionLiteral ( function_type) => match function_type. known ( db) {
598600 Some ( KnownFunction :: IsEquivalentTo ) => {
599601 if let [ Some ( ty_a) , Some ( ty_b) ] = overload. parameter_types ( ) {
600- let constraints = ty_a. when_equivalent_to ( db, * ty_b) ;
602+ let constraints =
603+ ty_a. when_equivalent_to ( db, * ty_b, InferableTypeVars :: None ) ;
601604 let tracked = TrackedConstraintSet :: new ( db, constraints) ;
602605 overload. set_return_type ( Type :: KnownInstance (
603606 KnownInstanceType :: ConstraintSet ( tracked) ,
@@ -607,7 +610,8 @@ impl<'db> Bindings<'db> {
607610
608611 Some ( KnownFunction :: IsSubtypeOf ) => {
609612 if let [ Some ( ty_a) , Some ( ty_b) ] = overload. parameter_types ( ) {
610- let constraints = ty_a. when_subtype_of ( db, * ty_b) ;
613+ let constraints =
614+ ty_a. when_subtype_of ( db, * ty_b, InferableTypeVars :: None ) ;
611615 let tracked = TrackedConstraintSet :: new ( db, constraints) ;
612616 overload. set_return_type ( Type :: KnownInstance (
613617 KnownInstanceType :: ConstraintSet ( tracked) ,
@@ -617,7 +621,8 @@ impl<'db> Bindings<'db> {
617621
618622 Some ( KnownFunction :: IsAssignableTo ) => {
619623 if let [ Some ( ty_a) , Some ( ty_b) ] = overload. parameter_types ( ) {
620- let constraints = ty_a. when_assignable_to ( db, * ty_b) ;
624+ let constraints =
625+ ty_a. when_assignable_to ( db, * ty_b, InferableTypeVars :: None ) ;
621626 let tracked = TrackedConstraintSet :: new ( db, constraints) ;
622627 overload. set_return_type ( Type :: KnownInstance (
623628 KnownInstanceType :: ConstraintSet ( tracked) ,
@@ -627,7 +632,8 @@ impl<'db> Bindings<'db> {
627632
628633 Some ( KnownFunction :: IsDisjointFrom ) => {
629634 if let [ Some ( ty_a) , Some ( ty_b) ] = overload. parameter_types ( ) {
630- let constraints = ty_a. when_disjoint_from ( db, * ty_b) ;
635+ let constraints =
636+ ty_a. when_disjoint_from ( db, * ty_b, InferableTypeVars :: None ) ;
631637 let tracked = TrackedConstraintSet :: new ( db, constraints) ;
632638 overload. set_return_type ( Type :: KnownInstance (
633639 KnownInstanceType :: ConstraintSet ( tracked) ,
@@ -1407,7 +1413,10 @@ impl<'db> CallableBinding<'db> {
14071413 let parameter_type = overload. signature . parameters ( ) [ * parameter_index]
14081414 . annotated_type ( )
14091415 . unwrap_or ( Type :: unknown ( ) ) ;
1410- if argument_type. is_assignable_to ( db, parameter_type) {
1416+ if argument_type
1417+ . when_assignable_to ( db, parameter_type, overload. inferable_typevars )
1418+ . is_always_satisfied ( )
1419+ {
14111420 is_argument_assignable_to_any_overload = true ;
14121421 break ' overload;
14131422 }
@@ -1633,7 +1642,14 @@ impl<'db> CallableBinding<'db> {
16331642 . unwrap_or ( Type :: unknown ( ) ) ;
16341643 let first_parameter_type = & mut first_parameter_types[ parameter_index] ;
16351644 if let Some ( first_parameter_type) = first_parameter_type {
1636- if !first_parameter_type. is_equivalent_to ( db, current_parameter_type) {
1645+ if !first_parameter_type
1646+ . when_equivalent_to (
1647+ db,
1648+ current_parameter_type,
1649+ overload. inferable_typevars ,
1650+ )
1651+ . is_always_satisfied ( )
1652+ {
16371653 participating_parameter_indexes. insert ( parameter_index) ;
16381654 }
16391655 } else {
@@ -1750,7 +1766,12 @@ impl<'db> CallableBinding<'db> {
17501766 matching_overloads. all ( |( _, overload) | {
17511767 overload
17521768 . return_type ( )
1753- . is_equivalent_to ( db, first_overload_return_type)
1769+ . when_equivalent_to (
1770+ db,
1771+ first_overload_return_type,
1772+ overload. inferable_typevars ,
1773+ )
1774+ . is_always_satisfied ( )
17541775 } )
17551776 } else {
17561777 // No matching overload
@@ -2461,6 +2482,7 @@ struct ArgumentTypeChecker<'a, 'db> {
24612482 call_expression_tcx : & ' a TypeContext < ' db > ,
24622483 errors : & ' a mut Vec < BindingError < ' db > > ,
24632484
2485+ inferable_typevars : InferableTypeVars < ' db , ' db > ,
24642486 specialization : Option < Specialization < ' db > > ,
24652487}
24662488
@@ -2482,6 +2504,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
24822504 parameter_tys,
24832505 call_expression_tcx,
24842506 errors,
2507+ inferable_typevars : InferableTypeVars :: None ,
24852508 specialization : None ,
24862509 }
24872510 }
@@ -2514,11 +2537,12 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25142537 }
25152538
25162539 fn infer_specialization ( & mut self ) {
2517- if self . signature . generic_context . is_none ( ) {
2540+ let Some ( generic_context ) = self . signature . generic_context else {
25182541 return ;
2519- }
2542+ } ;
25202543
2521- let mut builder = SpecializationBuilder :: new ( self . db ) ;
2544+ // TODO: Use the list of inferable typevars from the generic context of the callable.
2545+ let mut builder = SpecializationBuilder :: new ( self . db , self . inferable_typevars ) ;
25222546
25232547 // Note that we infer the annotated type _before_ the arguments if this call is part of
25242548 // an annotated assignment, to closer match the order of any unions written in the type
@@ -2563,10 +2587,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25632587 }
25642588 }
25652589
2566- self . specialization = self
2567- . signature
2568- . generic_context
2569- . map ( |gc| builder. build ( gc, * self . call_expression_tcx ) ) ;
2590+ self . specialization = Some ( builder. build ( generic_context, * self . call_expression_tcx ) ) ;
25702591 }
25712592
25722593 fn check_argument_type (
@@ -2590,7 +2611,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25902611 // constraint set that we get from this assignability check, instead of inferring and
25912612 // building them in an earlier separate step.
25922613 if argument_type
2593- . when_assignable_to ( self . db , expected_ty)
2614+ . when_assignable_to ( self . db , expected_ty, self . inferable_typevars )
25942615 . is_never_satisfied ( )
25952616 {
25962617 let positional = matches ! ( argument, Argument :: Positional | Argument :: Synthetic )
@@ -2719,7 +2740,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27192740 return ;
27202741 } ;
27212742
2722- if !key_type. is_assignable_to ( self . db , KnownClass :: Str . to_instance ( self . db ) ) {
2743+ if !key_type
2744+ . when_assignable_to (
2745+ self . db ,
2746+ KnownClass :: Str . to_instance ( self . db ) ,
2747+ self . inferable_typevars ,
2748+ )
2749+ . is_always_satisfied ( )
2750+ {
27232751 self . errors . push ( BindingError :: InvalidKeyType {
27242752 argument_index : adjusted_argument_index,
27252753 provided_ty : key_type,
@@ -2754,8 +2782,8 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27542782 }
27552783 }
27562784
2757- fn finish ( self ) -> Option < Specialization < ' db > > {
2758- self . specialization
2785+ fn finish ( self ) -> ( InferableTypeVars < ' db , ' db > , Option < Specialization < ' db > > ) {
2786+ ( self . inferable_typevars , self . specialization )
27592787 }
27602788}
27612789
@@ -2819,6 +2847,9 @@ pub(crate) struct Binding<'db> {
28192847 /// Return type of the call.
28202848 return_ty : Type < ' db > ,
28212849
2850+ /// The inferable typevars in this signature.
2851+ inferable_typevars : InferableTypeVars < ' db , ' db > ,
2852+
28222853 /// The specialization that was inferred from the argument types, if the callable is generic.
28232854 specialization : Option < Specialization < ' db > > ,
28242855
@@ -2845,6 +2876,7 @@ impl<'db> Binding<'db> {
28452876 callable_type : signature_type,
28462877 signature_type,
28472878 return_ty : Type :: unknown ( ) ,
2879+ inferable_typevars : InferableTypeVars :: None ,
28482880 specialization : None ,
28492881 argument_matches : Box :: from ( [ ] ) ,
28502882 variadic_argument_matched_to_variadic_parameter : false ,
@@ -2916,7 +2948,7 @@ impl<'db> Binding<'db> {
29162948 checker. infer_specialization ( ) ;
29172949
29182950 checker. check_argument_types ( ) ;
2919- self . specialization = checker. finish ( ) ;
2951+ ( self . inferable_typevars , self . specialization ) = checker. finish ( ) ;
29202952 if let Some ( specialization) = self . specialization {
29212953 self . return_ty = self . return_ty . apply_specialization ( db, specialization) ;
29222954 }
@@ -3010,6 +3042,7 @@ impl<'db> Binding<'db> {
30103042 fn snapshot ( & self ) -> BindingSnapshot < ' db > {
30113043 BindingSnapshot {
30123044 return_ty : self . return_ty ,
3045+ inferable_typevars : self . inferable_typevars ,
30133046 specialization : self . specialization ,
30143047 argument_matches : self . argument_matches . clone ( ) ,
30153048 parameter_tys : self . parameter_tys . clone ( ) ,
@@ -3020,13 +3053,15 @@ impl<'db> Binding<'db> {
30203053 fn restore ( & mut self , snapshot : BindingSnapshot < ' db > ) {
30213054 let BindingSnapshot {
30223055 return_ty,
3056+ inferable_typevars,
30233057 specialization,
30243058 argument_matches,
30253059 parameter_tys,
30263060 errors,
30273061 } = snapshot;
30283062
30293063 self . return_ty = return_ty;
3064+ self . inferable_typevars = inferable_typevars;
30303065 self . specialization = specialization;
30313066 self . argument_matches = argument_matches;
30323067 self . parameter_tys = parameter_tys;
@@ -3046,6 +3081,7 @@ impl<'db> Binding<'db> {
30463081 /// Resets the state of this binding to its initial state.
30473082 fn reset ( & mut self ) {
30483083 self . return_ty = Type :: unknown ( ) ;
3084+ self . inferable_typevars = InferableTypeVars :: None ;
30493085 self . specialization = None ;
30503086 self . argument_matches = Box :: from ( [ ] ) ;
30513087 self . parameter_tys = Box :: from ( [ ] ) ;
@@ -3056,6 +3092,7 @@ impl<'db> Binding<'db> {
30563092#[ derive( Clone , Debug ) ]
30573093struct BindingSnapshot < ' db > {
30583094 return_ty : Type < ' db > ,
3095+ inferable_typevars : InferableTypeVars < ' db , ' db > ,
30593096 specialization : Option < Specialization < ' db > > ,
30603097 argument_matches : Box < [ MatchedArgument < ' db > ] > ,
30613098 parameter_tys : Box < [ Option < Type < ' db > > ] > ,
@@ -3095,6 +3132,7 @@ impl<'db> CallableBindingSnapshot<'db> {
30953132
30963133 // ... and update the snapshot with the current state of the binding.
30973134 snapshot. return_ty = binding. return_ty ;
3135+ snapshot. inferable_typevars = binding. inferable_typevars ;
30983136 snapshot. specialization = binding. specialization ;
30993137 snapshot
31003138 . argument_matches
0 commit comments