@@ -356,6 +356,9 @@ def __init__(
356356
357357 self ._arg_infer_context_cache = None
358358
359+ self .overload_stack_depth = 0
360+ self ._args_cache = {}
361+
359362 def reset (self ) -> None :
360363 self .resolved_type = {}
361364
@@ -1613,9 +1616,10 @@ def check_call(
16131616 object_type ,
16141617 )
16151618 elif isinstance (callee , Overloaded ):
1616- return self .check_overload_call (
1617- callee , args , arg_kinds , arg_names , callable_name , object_type , context
1618- )
1619+ with self .overload_context (callee .name ()):
1620+ return self .check_overload_call (
1621+ callee , args , arg_kinds , arg_names , callable_name , object_type , context
1622+ )
16191623 elif isinstance (callee , AnyType ) or not self .chk .in_checked_function ():
16201624 return self .check_any_type_call (args , callee )
16211625 elif isinstance (callee , UnionType ):
@@ -1674,6 +1678,14 @@ def check_call(
16741678 else :
16751679 return self .msg .not_callable (callee , context ), AnyType (TypeOfAny .from_error )
16761680
1681+ @contextmanager
1682+ def overload_context (self , fn ):
1683+ self .overload_stack_depth += 1
1684+ yield
1685+ self .overload_stack_depth -= 1
1686+ if self .overload_stack_depth == 0 :
1687+ self ._args_cache .clear ()
1688+
16771689 def check_callable_call (
16781690 self ,
16791691 callee : CallableType ,
@@ -1937,6 +1949,10 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]
19371949 In short, we basically recurse on each argument without considering
19381950 in what context the argument was called.
19391951 """
1952+ can_cache = not any (isinstance (t , TempNode ) for t in args )
1953+ key = tuple (map (id , args ))
1954+ if can_cache and key in self ._args_cache :
1955+ return self ._args_cache [key ]
19401956 res : list [Type ] = []
19411957
19421958 for arg in args :
@@ -1945,6 +1961,8 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]
19451961 res .append (NoneType ())
19461962 else :
19471963 res .append (arg_type )
1964+ if can_cache :
1965+ self ._args_cache [key ] = res
19481966 return res
19491967
19501968 def infer_more_unions_for_recursive_type (self , type_context : Type ) -> bool :
@@ -2917,17 +2935,16 @@ def infer_overload_return_type(
29172935
29182936 for typ in plausible_targets :
29192937 assert self .msg is self .chk .msg
2920- with self .msg .filter_errors () as w :
2921- with self .chk .local_type_map () as m :
2922- ret_type , infer_type = self .check_call (
2923- callee = typ ,
2924- args = args ,
2925- arg_kinds = arg_kinds ,
2926- arg_names = arg_names ,
2927- context = context ,
2928- callable_name = callable_name ,
2929- object_type = object_type ,
2930- )
2938+ with self .msg .filter_errors () as w , self .chk .local_type_map () as m :
2939+ ret_type , infer_type = self .check_call (
2940+ callee = typ ,
2941+ args = args ,
2942+ arg_kinds = arg_kinds ,
2943+ arg_names = arg_names ,
2944+ context = context ,
2945+ callable_name = callable_name ,
2946+ object_type = object_type ,
2947+ )
29312948 is_match = not w .has_new_errors ()
29322949 if is_match :
29332950 # Return early if possible; otherwise record info, so we can
@@ -3474,6 +3491,10 @@ def visit_op_expr(self, e: OpExpr) -> Type:
34743491 return self .strfrm_checker .check_str_interpolation (e .left , e .right )
34753492 if isinstance (e .left , StrExpr ):
34763493 return self .strfrm_checker .check_str_interpolation (e .left , e .right )
3494+
3495+ key = id (e )
3496+ if key in self ._args_cache :
3497+ return self ._args_cache [key ]
34773498 left_type = self .accept (e .left )
34783499
34793500 proper_left_type = get_proper_type (left_type )
@@ -3543,28 +3564,30 @@ def visit_op_expr(self, e: OpExpr) -> Type:
35433564 )
35443565
35453566 if e .op in operators .op_methods :
3546- method = operators .op_methods [e .op ]
3547- if use_reverse is UseReverse .DEFAULT or use_reverse is UseReverse .NEVER :
3548- result , method_type = self .check_op (
3549- method ,
3550- base_type = left_type ,
3551- arg = e .right ,
3552- context = e ,
3553- allow_reverse = use_reverse is UseReverse .DEFAULT ,
3554- )
3555- elif use_reverse is UseReverse .ALWAYS :
3556- result , method_type = self .check_op (
3557- # The reverse operator here gives better error messages:
3558- operators .reverse_op_methods [method ],
3559- base_type = self .accept (e .right ),
3560- arg = e .left ,
3561- context = e ,
3562- allow_reverse = False ,
3563- )
3564- else :
3565- assert_never (use_reverse )
3566- e .method_type = method_type
3567- return result
3567+ with self .overload_context (e .op ):
3568+ method = operators .op_methods [e .op ]
3569+ if use_reverse is UseReverse .DEFAULT or use_reverse is UseReverse .NEVER :
3570+ result , method_type = self .check_op (
3571+ method ,
3572+ base_type = left_type ,
3573+ arg = e .right ,
3574+ context = e ,
3575+ allow_reverse = use_reverse is UseReverse .DEFAULT ,
3576+ )
3577+ elif use_reverse is UseReverse .ALWAYS :
3578+ result , method_type = self .check_op (
3579+ # The reverse operator here gives better error messages:
3580+ operators .reverse_op_methods [method ],
3581+ base_type = self .accept (e .right ),
3582+ arg = e .left ,
3583+ context = e ,
3584+ allow_reverse = False ,
3585+ )
3586+ else :
3587+ assert_never (use_reverse )
3588+ e .method_type = method_type
3589+ self ._args_cache [key ] = result
3590+ return result
35683591 else :
35693592 raise RuntimeError (f"Unknown operator { e .op } " )
35703593
0 commit comments