@@ -400,27 +400,28 @@ def functools_lru_cache_callback(ctx: mypy.plugin.FunctionContext) -> Type:
400400 """Infer a more precise return type for functools.lru_cache decorator"""
401401 if not isinstance (ctx .api , mypy .checker .TypeChecker ): # use internals
402402 return ctx .default_return_type
403-
403+
404404 # Only handle the simple case: @lru_cache (without parentheses)
405405 # where a function is passed directly as the first argument
406- if (len ( ctx . arg_types ) >= 1 and
407- len (ctx .arg_types [0 ]) == 1 and
408- len ( ctx . arg_types ) <= 2 ): # Ensure we don't have extra args indicating parameterized call
409-
406+ if (
407+ len (ctx .arg_types ) >= 1 and len ( ctx . arg_types [0 ]) == 1 and len ( ctx . arg_types ) <= 2
408+ ): # Ensure we don't have extra args indicating parameterized call
409+
410410 first_arg_type = ctx .arg_types [0 ][0 ]
411-
411+
412412 # Explicitly check that this is NOT a literal or other non-function type
413- from mypy .types import LiteralType , Instance
413+ from mypy .types import Instance , LiteralType
414+
414415 if isinstance (first_arg_type , (LiteralType , Instance )):
415416 # This is likely maxsize=128 or similar - let MyPy handle it
416417 return ctx .default_return_type
417-
418+
418419 # Try to extract callable type
419420 fn_type = ctx .api .extract_callable_type (first_arg_type , ctx = ctx .default_return_type )
420421 if fn_type is not None :
421422 # This is the @lru_cache case (function passed directly)
422423 return fn_type
423-
424+
424425 # For all parameterized cases, don't interfere
425426 return ctx .default_return_type
426427
@@ -429,17 +430,17 @@ def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
429430 """Handle calls to functools._lru_cache_wrapper objects to provide parameter validation"""
430431 if not isinstance (ctx .api , mypy .checker .TypeChecker ):
431432 return ctx .default_return_type
432-
433+
433434 # Try to find the original function signature using AST/symbol table analysis
434435 original_signature = _find_original_function_signature (ctx )
435-
436+
436437 if original_signature is not None :
437438 # Validate the call against the original function signature
438439 actual_args = []
439440 actual_arg_kinds = []
440441 actual_arg_names = []
441442 seen_args = set ()
442-
443+
443444 for i , param in enumerate (ctx .args ):
444445 for j , a in enumerate (param ):
445446 if a in seen_args :
@@ -458,55 +459,55 @@ def lru_cache_wrapper_call_callback(ctx: mypy.plugin.MethodContext) -> Type:
458459 context = ctx .context ,
459460 )
460461 return result
461-
462+
462463 return ctx .default_return_type
463464
464465
465466def _find_original_function_signature (ctx : mypy .plugin .MethodContext ) -> CallableType | None :
466467 """
467468 Attempt to find the original function signature from the call context.
468-
469+
469470 Returns the CallableType of the original function if found, None otherwise.
470471 This function safely traverses the AST structure to locate the original
471472 function signature that was decorated with @lru_cache.
472473 """
473474 from mypy .nodes import CallExpr , Decorator , NameExpr
474-
475+
475476 # Ensure we have the required context structure
476477 if not isinstance (ctx .context , CallExpr ):
477478 return None
478-
479+
479480 callee = ctx .context .callee
480481 if not isinstance (callee , NameExpr ) or not callee .name :
481482 return None
482-
483+
483484 func_name = callee .name
484-
485+
485486 # Safely access the API globals
486- if not hasattr (ctx .api , ' globals' ) or not isinstance (ctx .api .globals , dict ):
487+ if not hasattr (ctx .api , " globals" ) or not isinstance (ctx .api .globals , dict ):
487488 return None
488-
489+
489490 if func_name not in ctx .api .globals :
490491 return None
491-
492+
492493 symbol = ctx .api .globals [func_name ]
493-
494+
494495 # Validate symbol structure before accessing node
495- if not hasattr (symbol , ' node' ) or symbol .node is None :
496+ if not hasattr (symbol , " node" ) or symbol .node is None :
496497 return None
497-
498+
498499 # Check if this is a decorator node containing our function
499500 if isinstance (symbol .node , Decorator ):
500501 decorator_node = symbol .node
501-
502+
502503 # Safely access the decorated function
503- if not hasattr (decorator_node , ' func' ) or decorator_node .func is None :
504+ if not hasattr (decorator_node , " func" ) or decorator_node .func is None :
504505 return None
505-
506+
506507 func_def = decorator_node .func
507-
508+
508509 # Verify we have a callable type
509- if hasattr (func_def , ' type' ) and isinstance (func_def .type , CallableType ):
510+ if hasattr (func_def , " type" ) and isinstance (func_def .type , CallableType ):
510511 return func_def .type
511-
512+
512513 return None
0 commit comments