@@ -351,25 +351,34 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
351351 def _process_test_function (
352352 self , node : ast .AsyncFunctionDef | ast .FunctionDef
353353 ) -> ast .AsyncFunctionDef | ast .FunctionDef :
354- if self .test_framework == "unittest" and not any (
355- isinstance (d , ast .Call ) and isinstance (d .func , ast .Name ) and d .func .id == "timeout_decorator.timeout"
356- for d in node .decorator_list
357- ):
358- timeout_decorator = ast .Call (
359- func = ast .Name (id = "timeout_decorator.timeout" , ctx = ast .Load ()),
360- args = [ast .Constant (value = 15 )],
361- keywords = [],
362- )
363- node .decorator_list .append (timeout_decorator )
354+ # Optimize the search for decorator presence
355+ if self .test_framework == "unittest" :
356+ found_timeout = False
357+ for d in node .decorator_list :
358+ # Avoid isinstance(d.func, ast.Name) if d is not ast.Call
359+ if isinstance (d , ast .Call ):
360+ f = d .func
361+ # Avoid attribute lookup if f is not ast.Name
362+ if isinstance (f , ast .Name ) and f .id == "timeout_decorator.timeout" :
363+ found_timeout = True
364+ break
365+ if not found_timeout :
366+ timeout_decorator = ast .Call (
367+ func = ast .Name (id = "timeout_decorator.timeout" , ctx = ast .Load ()),
368+ args = [ast .Constant (value = 15 )],
369+ keywords = [],
370+ )
371+ node .decorator_list .append (timeout_decorator )
364372
365373 # Initialize counter for this test function
366374 if node .name not in self .async_call_counter :
367375 self .async_call_counter [node .name ] = 0
368376
369377 new_body = []
370378
379+ # Optimize ast.walk calls inside _instrument_statement, by scanning only relevant nodes
371380 for _i , stmt in enumerate (node .body ):
372- transformed_stmt , added_env_assignment = self ._instrument_statement (stmt , node . name )
381+ transformed_stmt , added_env_assignment = self ._optimized_instrument_statement (stmt )
373382
374383 if added_env_assignment :
375384 current_call_index = self .async_call_counter [node .name ]
@@ -423,6 +432,26 @@ def _call_in_positions(self, call_node: ast.Call) -> bool:
423432
424433 return node_in_call_position (call_node , self .call_positions )
425434
435+ # Optimized version: only walk child nodes for Await
436+ def _optimized_instrument_statement (self , stmt : ast .stmt ) -> tuple [ast .stmt , bool ]:
437+ # Stack-based DFS, manual for relevant Await nodes
438+ stack = [stmt ]
439+ while stack :
440+ node = stack .pop ()
441+ # Favor direct ast.Await detection
442+ if isinstance (node , ast .Await ):
443+ val = node .value
444+ if isinstance (val , ast .Call ) and self ._is_target_call (val ) and self ._call_in_positions (val ):
445+ return stmt , True
446+ # Use _fields instead of ast.walk for less allocations
447+ for fname in getattr (node , "_fields" , ()):
448+ child = getattr (node , fname , None )
449+ if isinstance (child , list ):
450+ stack .extend (child )
451+ elif isinstance (child , ast .AST ):
452+ stack .append (child )
453+ return stmt , False
454+
426455
427456class FunctionImportedAsVisitor (ast .NodeVisitor ):
428457 """Checks if a function has been imported as an alias. We only care about the alias then.
0 commit comments