@@ -466,4 +466,82 @@ def static_method_b():
466466
467467 # Compare the modified content with expected content
468468 assert modified_content_1 .strip () == expected_content_1 .strip ()
469- assert modified_content_2 .strip () == expected_content_2 .strip ()
469+ assert modified_content_2 .strip () == expected_content_2 .strip ()
470+
471+
472+ def test_add_decorator_to_method_after_nested_class () -> None :
473+ """Test adding decorator to a method that appears after a nested class definition."""
474+ code = """
475+ class OuterClass:
476+ class NestedClass:
477+ def nested_method(self):
478+ return "Hello from nested class method"
479+
480+ def target_method(self):
481+ return "Hello from target method after nested class"
482+ """
483+
484+ fto = FunctionToOptimize (
485+ function_name = "target_method" ,
486+ file_path = Path ("dummy_path.py" ),
487+ parents = [FunctionParent (name = "OuterClass" , type = "ClassDef" )]
488+ )
489+
490+ modified_code = add_codeflash_decorator_to_code (
491+ code = code ,
492+ functions_to_optimize = [fto ]
493+ )
494+
495+ expected_code = """
496+ from codeflash.benchmarking.codeflash_trace import codeflash_trace
497+ class OuterClass:
498+ class NestedClass:
499+ def nested_method(self):
500+ return "Hello from nested class method"
501+
502+ @codeflash_trace
503+ def target_method(self):
504+ return "Hello from target method after nested class"
505+ """
506+
507+ assert modified_code .strip () == expected_code .strip ()
508+
509+
510+ def test_add_decorator_to_function_after_nested_function () -> None :
511+ """Test adding decorator to a function that appears after a function with a nested function."""
512+ code = """
513+ def function_with_nested():
514+ def inner_function():
515+ return "Hello from inner function"
516+
517+ return inner_function()
518+
519+ def target_function():
520+ return "Hello from target function after nested function"
521+ """
522+
523+ fto = FunctionToOptimize (
524+ function_name = "target_function" ,
525+ file_path = Path ("dummy_path.py" ),
526+ parents = []
527+ )
528+
529+ modified_code = add_codeflash_decorator_to_code (
530+ code = code ,
531+ functions_to_optimize = [fto ]
532+ )
533+
534+ expected_code = """
535+ from codeflash.benchmarking.codeflash_trace import codeflash_trace
536+ def function_with_nested():
537+ def inner_function():
538+ return "Hello from inner function"
539+
540+ return inner_function()
541+
542+ @codeflash_trace
543+ def target_function():
544+ return "Hello from target function after nested function"
545+ """
546+
547+ assert modified_code .strip () == expected_code .strip ()
0 commit comments