@@ -357,7 +357,9 @@ def __init__(
357357 self ._arg_infer_context_cache = None
358358
359359 self .overload_stack_depth = 0
360- self ._args_cache = {}
360+ self .ops_stack_depth = 0
361+ self ._args_cache : dict [tuple [int , ...], list [Type ]] = {}
362+ self ._ops_cache : dict [int , Type ] = {}
361363
362364 def reset (self ) -> None :
363365 self .resolved_type = {}
@@ -1616,7 +1618,7 @@ def check_call(
16161618 object_type ,
16171619 )
16181620 elif isinstance (callee , Overloaded ):
1619- with self .overload_context (callee . name () ):
1621+ with self .overload_context ():
16201622 return self .check_overload_call (
16211623 callee , args , arg_kinds , arg_names , callable_name , object_type , context
16221624 )
@@ -1679,13 +1681,21 @@ def check_call(
16791681 return self .msg .not_callable (callee , context ), AnyType (TypeOfAny .from_error )
16801682
16811683 @contextmanager
1682- def overload_context (self , fn ) :
1684+ def overload_context (self ) -> Iterator [ None ] :
16831685 self .overload_stack_depth += 1
16841686 yield
16851687 self .overload_stack_depth -= 1
16861688 if self .overload_stack_depth == 0 :
16871689 self ._args_cache .clear ()
16881690
1691+ @contextmanager
1692+ def ops_context (self ) -> Iterator [None ]:
1693+ self .ops_stack_depth += 1
1694+ yield
1695+ self .ops_stack_depth -= 1
1696+ if self .ops_stack_depth == 0 :
1697+ self ._ops_cache .clear ()
1698+
16891699 def check_callable_call (
16901700 self ,
16911701 callee : CallableType ,
@@ -3493,8 +3503,8 @@ def visit_op_expr(self, e: OpExpr) -> Type:
34933503 return self .strfrm_checker .check_str_interpolation (e .left , e .right )
34943504
34953505 key = id (e )
3496- if key in self ._args_cache :
3497- return self ._args_cache [key ]
3506+ if key in self ._ops_cache :
3507+ return self ._ops_cache [key ]
34983508 left_type = self .accept (e .left )
34993509
35003510 proper_left_type = get_proper_type (left_type )
@@ -3564,7 +3574,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
35643574 )
35653575
35663576 if e .op in operators .op_methods :
3567- with self .overload_context ( e . op ):
3577+ with self .ops_context ( ):
35683578 method = operators .op_methods [e .op ]
35693579 if use_reverse is UseReverse .DEFAULT or use_reverse is UseReverse .NEVER :
35703580 result , method_type = self .check_op (
@@ -3586,7 +3596,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
35863596 else :
35873597 assert_never (use_reverse )
35883598 e .method_type = method_type
3589- self ._args_cache [key ] = result
3599+ self ._ops_cache [key ] = result
35903600 return result
35913601 else :
35923602 raise RuntimeError (f"Unknown operator { e .op } " )
0 commit comments