@@ -360,6 +360,9 @@ def __init__(
360360 ] = {}
361361 self .in_lambda_expr = False
362362
363+ self ._literal_true : Instance | None = None
364+ self ._literal_false : Instance | None = None
365+
363366 def reset (self ) -> None :
364367 self .resolved_type = {}
365368 self .expr_cache .clear ()
@@ -1072,7 +1075,7 @@ def check_typeddict_call_with_kwargs(
10721075
10731076 # We don't show any errors, just infer types in a generic TypedDict type,
10741077 # a custom error message will be given below, if there are errors.
1075- with self .msg .filter_errors (), self .chk .local_type_map () :
1078+ with self .msg .filter_errors (), self .chk .local_type_map :
10761079 orig_ret_type , _ = self .check_callable_call (
10771080 infer_callee ,
10781081 # We use first expression for each key to infer type variables of a generic
@@ -1437,7 +1440,7 @@ def is_generic_decorator_overload_call(
14371440 return None
14381441 if not isinstance (get_proper_type (callee_type .ret_type ), CallableType ):
14391442 return None
1440- with self .chk .local_type_map () :
1443+ with self .chk .local_type_map :
14411444 with self .msg .filter_errors ():
14421445 arg_type = get_proper_type (self .accept (args [0 ], type_context = None ))
14431446 if isinstance (arg_type , Overloaded ):
@@ -2719,6 +2722,7 @@ def check_overload_call(
27192722 # for example, when we have a fallback alternative that accepts an unrestricted
27202723 # typevar. See https://github.com/python/mypy/issues/4063 for related discussion.
27212724 erased_targets : list [CallableType ] | None = None
2725+ inferred_types : list [Type ] | None = None
27222726 unioned_result : tuple [Type , Type ] | None = None
27232727
27242728 # Determine whether we need to encourage union math. This should be generally safe,
@@ -2746,13 +2750,14 @@ def check_overload_call(
27462750 # Record if we succeeded. Next we need to see if maybe normal procedure
27472751 # gives a narrower type.
27482752 if unioned_return :
2749- returns , inferred_types = zip (* unioned_return )
2753+ returns = [u [0 ] for u in unioned_return ]
2754+ inferred_types = [u [1 ] for u in unioned_return ]
27502755 # Note that we use `combine_function_signatures` instead of just returning
27512756 # a union of inferred callables because for example a call
27522757 # Union[int -> int, str -> str](Union[int, str]) is invalid and
27532758 # we don't want to introduce internal inconsistencies.
27542759 unioned_result = (
2755- make_simplified_union (list ( returns ) , context .line , context .column ),
2760+ make_simplified_union (returns , context .line , context .column ),
27562761 self .combine_function_signatures (get_proper_types (inferred_types )),
27572762 )
27582763
@@ -2767,19 +2772,26 @@ def check_overload_call(
27672772 object_type ,
27682773 context ,
27692774 )
2770- # If any of checks succeed, stop early.
2775+ # If any of checks succeed, perform deprecation tests and stop early.
27712776 if inferred_result is not None and unioned_result is not None :
27722777 # Both unioned and direct checks succeeded, choose the more precise type.
27732778 if (
27742779 is_subtype (inferred_result [0 ], unioned_result [0 ])
27752780 and not isinstance (get_proper_type (inferred_result [0 ]), AnyType )
27762781 and not none_type_var_overlap
27772782 ):
2778- return inferred_result
2779- return unioned_result
2780- elif unioned_result is not None :
2783+ unioned_result = None
2784+ else :
2785+ inferred_result = None
2786+ if unioned_result is not None :
2787+ if inferred_types is not None :
2788+ for inferred_type in inferred_types :
2789+ if isinstance (c := get_proper_type (inferred_type ), CallableType ):
2790+ self .chk .warn_deprecated (c .definition , context )
27812791 return unioned_result
2782- elif inferred_result is not None :
2792+ if inferred_result is not None :
2793+ if isinstance (c := get_proper_type (inferred_result [1 ]), CallableType ):
2794+ self .chk .warn_deprecated (c .definition , context )
27832795 return inferred_result
27842796
27852797 # Step 4: Failure. At this point, we know there is no match. We fall back to trying
@@ -2917,7 +2929,7 @@ def infer_overload_return_type(
29172929 for typ in plausible_targets :
29182930 assert self .msg is self .chk .msg
29192931 with self .msg .filter_errors () as w :
2920- with self .chk .local_type_map () as m :
2932+ with self .chk .local_type_map as m :
29212933 ret_type , infer_type = self .check_call (
29222934 callee = typ ,
29232935 args = args ,
@@ -2933,8 +2945,6 @@ def infer_overload_return_type(
29332945 # check for ambiguity due to 'Any' below.
29342946 if not args_contain_any :
29352947 self .chk .store_types (m )
2936- if isinstance (infer_type , ProperType ) and isinstance (infer_type , CallableType ):
2937- self .chk .warn_deprecated (infer_type .definition , context )
29382948 return ret_type , infer_type
29392949 p_infer_type = get_proper_type (infer_type )
29402950 if isinstance (p_infer_type , CallableType ):
@@ -2971,11 +2981,6 @@ def infer_overload_return_type(
29712981 else :
29722982 # Success! No ambiguity; return the first match.
29732983 self .chk .store_types (type_maps [0 ])
2974- inferred_callable = inferred_types [0 ]
2975- if isinstance (inferred_callable , ProperType ) and isinstance (
2976- inferred_callable , CallableType
2977- ):
2978- self .chk .warn_deprecated (inferred_callable .definition , context )
29792984 return return_types [0 ], inferred_types [0 ]
29802985
29812986 def overload_erased_call_targets (
@@ -3428,11 +3433,19 @@ def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Ty
34283433 if self .is_literal_context ():
34293434 return LiteralType (value = value , fallback = typ )
34303435 else :
3431- return typ .copy_modified (
3432- last_known_value = LiteralType (
3433- value = value , fallback = typ , line = typ .line , column = typ .column
3434- )
3435- )
3436+ if value is True :
3437+ if self ._literal_true is None :
3438+ self ._literal_true = typ .copy_modified (
3439+ last_known_value = LiteralType (value = value , fallback = typ )
3440+ )
3441+ return self ._literal_true
3442+ if value is False :
3443+ if self ._literal_false is None :
3444+ self ._literal_false = typ .copy_modified (
3445+ last_known_value = LiteralType (value = value , fallback = typ )
3446+ )
3447+ return self ._literal_false
3448+ return typ .copy_modified (last_known_value = LiteralType (value = value , fallback = typ ))
34363449
34373450 def concat_tuples (self , left : TupleType , right : TupleType ) -> TupleType :
34383451 """Concatenate two fixed length tuples."""
@@ -5350,20 +5363,21 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
53505363 # an error, but returns the TypedDict type that matches the literal it found
53515364 # that would cause a second error when that TypedDict type is returned upstream
53525365 # to avoid the second error, we always return TypedDict type that was requested
5353- typeddict_contexts = self .find_typeddict_context (self .type_context [- 1 ], e )
5366+ typeddict_contexts , exhaustive = self .find_typeddict_context (self .type_context [- 1 ], e )
53545367 if typeddict_contexts :
5355- if len (typeddict_contexts ) == 1 :
5368+ if len (typeddict_contexts ) == 1 and exhaustive :
53565369 return self .check_typeddict_literal_in_context (e , typeddict_contexts [0 ])
53575370 # Multiple items union, check if at least one of them matches cleanly.
53585371 for typeddict_context in typeddict_contexts :
5359- with self .msg .filter_errors () as err , self .chk .local_type_map () as tmap :
5372+ with self .msg .filter_errors () as err , self .chk .local_type_map as tmap :
53605373 ret_type = self .check_typeddict_literal_in_context (e , typeddict_context )
53615374 if err .has_new_errors ():
53625375 continue
53635376 self .chk .store_types (tmap )
53645377 return ret_type
53655378 # No item matched without an error, so we can't unambiguously choose the item.
5366- self .msg .typeddict_context_ambiguous (typeddict_contexts , e )
5379+ if exhaustive :
5380+ self .msg .typeddict_context_ambiguous (typeddict_contexts , e )
53675381
53685382 # fast path attempt
53695383 dt = self .fast_dict_type (e )
@@ -5425,22 +5439,29 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
54255439
54265440 def find_typeddict_context (
54275441 self , context : Type | None , dict_expr : DictExpr
5428- ) -> list [TypedDictType ]:
5442+ ) -> tuple [list [TypedDictType ], bool ]:
5443+ """Extract `TypedDict` members of the enclosing context.
5444+
5445+ Returns:
5446+ a 2-tuple, (found_candidates, is_exhaustive)
5447+ """
54295448 context = get_proper_type (context )
54305449 if isinstance (context , TypedDictType ):
5431- return [context ]
5450+ return [context ], True
54325451 elif isinstance (context , UnionType ):
54335452 items = []
5453+ exhaustive = True
54345454 for item in context .items :
5435- item_contexts = self .find_typeddict_context (item , dict_expr )
5455+ item_contexts , item_exhaustive = self .find_typeddict_context (item , dict_expr )
54365456 for item_context in item_contexts :
54375457 if self .match_typeddict_call_with_dict (
54385458 item_context , dict_expr .items , dict_expr
54395459 ):
54405460 items .append (item_context )
5441- return items
5461+ exhaustive = exhaustive and item_exhaustive
5462+ return items , exhaustive
54425463 # No TypedDict type in context.
5443- return []
5464+ return [], False
54445465
54455466 def visit_lambda_expr (self , e : LambdaExpr ) -> Type :
54465467 """Type check lambda expression."""
@@ -6076,15 +6097,12 @@ def accept(
60766097
60776098 def accept_maybe_cache (self , node : Expression , type_context : Type | None = None ) -> Type :
60786099 binder_version = self .chk .binder .version
6079- # Micro-optimization: inline local_type_map() as it is somewhat slow in mypyc.
6080- type_map : dict [Expression , Type ] = {}
6081- self .chk ._type_maps .append (type_map )
60826100 with self .msg .filter_errors (filter_errors = True , save_filtered_errors = True ) as msg :
6083- typ = node .accept (self )
6101+ with self .chk .local_type_map as type_map :
6102+ typ = node .accept (self )
60846103 messages = msg .filtered_errors ()
60856104 if binder_version == self .chk .binder .version and not self .chk .current_node_deferred :
60866105 self .expr_cache [(node , type_context )] = (binder_version , typ , messages , type_map )
6087- self .chk ._type_maps .pop ()
60886106 self .chk .store_types (type_map )
60896107 self .msg .add_errors (messages )
60906108 return typ
0 commit comments