diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index f93c304b1..8eb671540 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -316,6 +316,11 @@ def __init__( def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: # Add timeout decorator for unittest test classes if needed if self.test_framework == "unittest": + timeout_decorator = ast.Call( + func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()), + args=[ast.Constant(value=15)], + keywords=[], + ) for item in node.body: if ( isinstance(item, ast.FunctionDef) @@ -327,11 +332,6 @@ def visit_ClassDef(self, node: ast.ClassDef) -> ast.ClassDef: for d in item.decorator_list ) ): - timeout_decorator = ast.Call( - func=ast.Name(id="timeout_decorator.timeout", ctx=ast.Load()), - args=[ast.Constant(value=15)], - keywords=[], - ) item.decorator_list.append(timeout_decorator) return self.generic_visit(node)