Skip to content

Commit 27a6488

Browse files
committed
handled edge case for instrumenting codeflash trace
1 parent b374b6e commit 27a6488

File tree

2 files changed

+91
-4
lines changed

2 files changed

+91
-4
lines changed

codeflash/benchmarking/instrument_codeflash_trace.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,38 @@ def __init__(self, target_functions: set[tuple[str, str]]) -> None:
1212
self.target_functions = target_functions
1313
self.added_codeflash_trace = False
1414
self.class_name = ""
15+
self.function_name = ""
1516
self.decorator = cst.Decorator(
1617
decorator=cst.Name(value="codeflash_trace")
1718
)
1819

1920
def leave_ClassDef(self, original_node, updated_node):
20-
self.class_name = ""
21+
if self.class_name == original_node.name.value:
22+
self.class_name = "" # Even if nested classes are not visited, this function is still called on them
2123
return updated_node
2224

2325
def visit_ClassDef(self, node):
2426
if self.class_name: # Don't go into nested class
2527
return False
2628
self.class_name = node.name.value
2729

30+
def visit_FunctionDef(self, node):
31+
if self.function_name: # Don't go into nested function
32+
return False
33+
self.function_name = node.name.value
34+
2835
def leave_FunctionDef(self, original_node, updated_node):
36+
if self.function_name == original_node.name.value:
37+
self.function_name = ""
2938
if (self.class_name, original_node.name.value) in self.target_functions:
3039
# Add the new decorator after any existing decorators, so it gets executed first
3140
updated_decorators = list(updated_node.decorators) + [self.decorator]
3241
self.added_codeflash_trace = True
3342
return updated_node.with_changes(
3443
decorators=updated_decorators
3544
)
36-
else:
37-
return updated_node
45+
46+
return updated_node
3847

3948
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
4049
# Create import statement for codeflash_trace

tests/test_instrument_codeflash_trace.py

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -466,4 +466,82 @@ def static_method_b():
466466

467467
# Compare the modified content with expected content
468468
assert modified_content_1.strip() == expected_content_1.strip()
469-
assert modified_content_2.strip() == expected_content_2.strip()
469+
assert modified_content_2.strip() == expected_content_2.strip()
470+
471+
472+
def test_add_decorator_to_method_after_nested_class() -> None:
473+
"""Test adding decorator to a method that appears after a nested class definition."""
474+
code = """
475+
class OuterClass:
476+
class NestedClass:
477+
def nested_method(self):
478+
return "Hello from nested class method"
479+
480+
def target_method(self):
481+
return "Hello from target method after nested class"
482+
"""
483+
484+
fto = FunctionToOptimize(
485+
function_name="target_method",
486+
file_path=Path("dummy_path.py"),
487+
parents=[FunctionParent(name="OuterClass", type="ClassDef")]
488+
)
489+
490+
modified_code = add_codeflash_decorator_to_code(
491+
code=code,
492+
functions_to_optimize=[fto]
493+
)
494+
495+
expected_code = """
496+
from codeflash.benchmarking.codeflash_trace import codeflash_trace
497+
class OuterClass:
498+
class NestedClass:
499+
def nested_method(self):
500+
return "Hello from nested class method"
501+
502+
@codeflash_trace
503+
def target_method(self):
504+
return "Hello from target method after nested class"
505+
"""
506+
507+
assert modified_code.strip() == expected_code.strip()
508+
509+
510+
def test_add_decorator_to_function_after_nested_function() -> None:
511+
"""Test adding decorator to a function that appears after a function with a nested function."""
512+
code = """
513+
def function_with_nested():
514+
def inner_function():
515+
return "Hello from inner function"
516+
517+
return inner_function()
518+
519+
def target_function():
520+
return "Hello from target function after nested function"
521+
"""
522+
523+
fto = FunctionToOptimize(
524+
function_name="target_function",
525+
file_path=Path("dummy_path.py"),
526+
parents=[]
527+
)
528+
529+
modified_code = add_codeflash_decorator_to_code(
530+
code=code,
531+
functions_to_optimize=[fto]
532+
)
533+
534+
expected_code = """
535+
from codeflash.benchmarking.codeflash_trace import codeflash_trace
536+
def function_with_nested():
537+
def inner_function():
538+
return "Hello from inner function"
539+
540+
return inner_function()
541+
542+
@codeflash_trace
543+
def target_function():
544+
return "Hello from target function after nested function"
545+
"""
546+
547+
assert modified_code.strip() == expected_code.strip()

0 commit comments

Comments
 (0)