@@ -360,6 +360,9 @@ def __init__(
360
360
] = {}
361
361
self .in_lambda_expr = False
362
362
363
+ self ._literal_true : Instance | None = None
364
+ self ._literal_false : Instance | None = None
365
+
363
366
def reset (self ) -> None :
364
367
self .resolved_type = {}
365
368
self .expr_cache .clear ()
@@ -1072,7 +1075,7 @@ def check_typeddict_call_with_kwargs(
1072
1075
1073
1076
# We don't show any errors, just infer types in a generic TypedDict type,
1074
1077
# a custom error message will be given below, if there are errors.
1075
- with self .msg .filter_errors (), self .chk .local_type_map () :
1078
+ with self .msg .filter_errors (), self .chk .local_type_map :
1076
1079
orig_ret_type , _ = self .check_callable_call (
1077
1080
infer_callee ,
1078
1081
# We use first expression for each key to infer type variables of a generic
@@ -1437,7 +1440,7 @@ def is_generic_decorator_overload_call(
1437
1440
return None
1438
1441
if not isinstance (get_proper_type (callee_type .ret_type ), CallableType ):
1439
1442
return None
1440
- with self .chk .local_type_map () :
1443
+ with self .chk .local_type_map :
1441
1444
with self .msg .filter_errors ():
1442
1445
arg_type = get_proper_type (self .accept (args [0 ], type_context = None ))
1443
1446
if isinstance (arg_type , Overloaded ):
@@ -2719,6 +2722,7 @@ def check_overload_call(
2719
2722
# for example, when we have a fallback alternative that accepts an unrestricted
2720
2723
# typevar. See https://github.com/python/mypy/issues/4063 for related discussion.
2721
2724
erased_targets : list [CallableType ] | None = None
2725
+ inferred_types : list [Type ] | None = None
2722
2726
unioned_result : tuple [Type , Type ] | None = None
2723
2727
2724
2728
# Determine whether we need to encourage union math. This should be generally safe,
@@ -2746,13 +2750,14 @@ def check_overload_call(
2746
2750
# Record if we succeeded. Next we need to see if maybe normal procedure
2747
2751
# gives a narrower type.
2748
2752
if unioned_return :
2749
- returns , inferred_types = zip (* unioned_return )
2753
+ returns = [u [0 ] for u in unioned_return ]
2754
+ inferred_types = [u [1 ] for u in unioned_return ]
2750
2755
# Note that we use `combine_function_signatures` instead of just returning
2751
2756
# a union of inferred callables because for example a call
2752
2757
# Union[int -> int, str -> str](Union[int, str]) is invalid and
2753
2758
# we don't want to introduce internal inconsistencies.
2754
2759
unioned_result = (
2755
- make_simplified_union (list ( returns ) , context .line , context .column ),
2760
+ make_simplified_union (returns , context .line , context .column ),
2756
2761
self .combine_function_signatures (get_proper_types (inferred_types )),
2757
2762
)
2758
2763
@@ -2767,19 +2772,26 @@ def check_overload_call(
2767
2772
object_type ,
2768
2773
context ,
2769
2774
)
2770
- # If any of checks succeed, stop early.
2775
+ # If any of checks succeed, perform deprecation tests and stop early.
2771
2776
if inferred_result is not None and unioned_result is not None :
2772
2777
# Both unioned and direct checks succeeded, choose the more precise type.
2773
2778
if (
2774
2779
is_subtype (inferred_result [0 ], unioned_result [0 ])
2775
2780
and not isinstance (get_proper_type (inferred_result [0 ]), AnyType )
2776
2781
and not none_type_var_overlap
2777
2782
):
2778
- return inferred_result
2779
- return unioned_result
2780
- elif unioned_result is not None :
2783
+ unioned_result = None
2784
+ else :
2785
+ inferred_result = None
2786
+ if unioned_result is not None :
2787
+ if inferred_types is not None :
2788
+ for inferred_type in inferred_types :
2789
+ if isinstance (c := get_proper_type (inferred_type ), CallableType ):
2790
+ self .chk .warn_deprecated (c .definition , context )
2781
2791
return unioned_result
2782
- elif inferred_result is not None :
2792
+ if inferred_result is not None :
2793
+ if isinstance (c := get_proper_type (inferred_result [1 ]), CallableType ):
2794
+ self .chk .warn_deprecated (c .definition , context )
2783
2795
return inferred_result
2784
2796
2785
2797
# Step 4: Failure. At this point, we know there is no match. We fall back to trying
@@ -2917,7 +2929,7 @@ def infer_overload_return_type(
2917
2929
for typ in plausible_targets :
2918
2930
assert self .msg is self .chk .msg
2919
2931
with self .msg .filter_errors () as w :
2920
- with self .chk .local_type_map () as m :
2932
+ with self .chk .local_type_map as m :
2921
2933
ret_type , infer_type = self .check_call (
2922
2934
callee = typ ,
2923
2935
args = args ,
@@ -2933,8 +2945,6 @@ def infer_overload_return_type(
2933
2945
# check for ambiguity due to 'Any' below.
2934
2946
if not args_contain_any :
2935
2947
self .chk .store_types (m )
2936
- if isinstance (infer_type , ProperType ) and isinstance (infer_type , CallableType ):
2937
- self .chk .warn_deprecated (infer_type .definition , context )
2938
2948
return ret_type , infer_type
2939
2949
p_infer_type = get_proper_type (infer_type )
2940
2950
if isinstance (p_infer_type , CallableType ):
@@ -2971,11 +2981,6 @@ def infer_overload_return_type(
2971
2981
else :
2972
2982
# Success! No ambiguity; return the first match.
2973
2983
self .chk .store_types (type_maps [0 ])
2974
- inferred_callable = inferred_types [0 ]
2975
- if isinstance (inferred_callable , ProperType ) and isinstance (
2976
- inferred_callable , CallableType
2977
- ):
2978
- self .chk .warn_deprecated (inferred_callable .definition , context )
2979
2984
return return_types [0 ], inferred_types [0 ]
2980
2985
2981
2986
def overload_erased_call_targets (
@@ -3428,11 +3433,19 @@ def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Ty
3428
3433
if self .is_literal_context ():
3429
3434
return LiteralType (value = value , fallback = typ )
3430
3435
else :
3431
- return typ .copy_modified (
3432
- last_known_value = LiteralType (
3433
- value = value , fallback = typ , line = typ .line , column = typ .column
3434
- )
3435
- )
3436
+ if value is True :
3437
+ if self ._literal_true is None :
3438
+ self ._literal_true = typ .copy_modified (
3439
+ last_known_value = LiteralType (value = value , fallback = typ )
3440
+ )
3441
+ return self ._literal_true
3442
+ if value is False :
3443
+ if self ._literal_false is None :
3444
+ self ._literal_false = typ .copy_modified (
3445
+ last_known_value = LiteralType (value = value , fallback = typ )
3446
+ )
3447
+ return self ._literal_false
3448
+ return typ .copy_modified (last_known_value = LiteralType (value = value , fallback = typ ))
3436
3449
3437
3450
def concat_tuples (self , left : TupleType , right : TupleType ) -> TupleType :
3438
3451
"""Concatenate two fixed length tuples."""
@@ -5350,20 +5363,21 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
5350
5363
# an error, but returns the TypedDict type that matches the literal it found
5351
5364
# that would cause a second error when that TypedDict type is returned upstream
5352
5365
# to avoid the second error, we always return TypedDict type that was requested
5353
- typeddict_contexts = self .find_typeddict_context (self .type_context [- 1 ], e )
5366
+ typeddict_contexts , exhaustive = self .find_typeddict_context (self .type_context [- 1 ], e )
5354
5367
if typeddict_contexts :
5355
- if len (typeddict_contexts ) == 1 :
5368
+ if len (typeddict_contexts ) == 1 and exhaustive :
5356
5369
return self .check_typeddict_literal_in_context (e , typeddict_contexts [0 ])
5357
5370
# Multiple items union, check if at least one of them matches cleanly.
5358
5371
for typeddict_context in typeddict_contexts :
5359
- with self .msg .filter_errors () as err , self .chk .local_type_map () as tmap :
5372
+ with self .msg .filter_errors () as err , self .chk .local_type_map as tmap :
5360
5373
ret_type = self .check_typeddict_literal_in_context (e , typeddict_context )
5361
5374
if err .has_new_errors ():
5362
5375
continue
5363
5376
self .chk .store_types (tmap )
5364
5377
return ret_type
5365
5378
# No item matched without an error, so we can't unambiguously choose the item.
5366
- self .msg .typeddict_context_ambiguous (typeddict_contexts , e )
5379
+ if exhaustive :
5380
+ self .msg .typeddict_context_ambiguous (typeddict_contexts , e )
5367
5381
5368
5382
# fast path attempt
5369
5383
dt = self .fast_dict_type (e )
@@ -5425,22 +5439,29 @@ def visit_dict_expr(self, e: DictExpr) -> Type:
5425
5439
5426
5440
def find_typeddict_context (
5427
5441
self , context : Type | None , dict_expr : DictExpr
5428
- ) -> list [TypedDictType ]:
5442
+ ) -> tuple [list [TypedDictType ], bool ]:
5443
+ """Extract `TypedDict` members of the enclosing context.
5444
+
5445
+ Returns:
5446
+ a 2-tuple, (found_candidates, is_exhaustive)
5447
+ """
5429
5448
context = get_proper_type (context )
5430
5449
if isinstance (context , TypedDictType ):
5431
- return [context ]
5450
+ return [context ], True
5432
5451
elif isinstance (context , UnionType ):
5433
5452
items = []
5453
+ exhaustive = True
5434
5454
for item in context .items :
5435
- item_contexts = self .find_typeddict_context (item , dict_expr )
5455
+ item_contexts , item_exhaustive = self .find_typeddict_context (item , dict_expr )
5436
5456
for item_context in item_contexts :
5437
5457
if self .match_typeddict_call_with_dict (
5438
5458
item_context , dict_expr .items , dict_expr
5439
5459
):
5440
5460
items .append (item_context )
5441
- return items
5461
+ exhaustive = exhaustive and item_exhaustive
5462
+ return items , exhaustive
5442
5463
# No TypedDict type in context.
5443
- return []
5464
+ return [], False
5444
5465
5445
5466
def visit_lambda_expr (self , e : LambdaExpr ) -> Type :
5446
5467
"""Type check lambda expression."""
@@ -6076,15 +6097,12 @@ def accept(
6076
6097
6077
6098
def accept_maybe_cache (self , node : Expression , type_context : Type | None = None ) -> Type :
6078
6099
binder_version = self .chk .binder .version
6079
- # Micro-optimization: inline local_type_map() as it is somewhat slow in mypyc.
6080
- type_map : dict [Expression , Type ] = {}
6081
- self .chk ._type_maps .append (type_map )
6082
6100
with self .msg .filter_errors (filter_errors = True , save_filtered_errors = True ) as msg :
6083
- typ = node .accept (self )
6101
+ with self .chk .local_type_map as type_map :
6102
+ typ = node .accept (self )
6084
6103
messages = msg .filtered_errors ()
6085
6104
if binder_version == self .chk .binder .version and not self .chk .current_node_deferred :
6086
6105
self .expr_cache [(node , type_context )] = (binder_version , typ , messages , type_map )
6087
- self .chk ._type_maps .pop ()
6088
6106
self .chk .store_types (type_map )
6089
6107
self .msg .add_errors (messages )
6090
6108
return typ
0 commit comments