Skip to content

Commit c00e324

Browse files
committed
works for any level of nested function
1 parent 89e72b7 commit c00e324

File tree

2 files changed

+96
-52
lines changed

2 files changed

+96
-52
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from code_to_optimize.bubble_sort_in_class import BubbleSortClass
1+
from code_to_optimize.bubble_sort_in_nested_class import WrapperClass
22

33

44
def sort_classmethod(x):
5-
y = BubbleSortClass()
5+
y = WrapperClass.BubbleSortClass()
66
return y.sorter(x)

codeflash/code_utils/lprof_utils.py

Lines changed: 94 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,109 @@
11
import isort
22
import libcst as cst
33
from pathlib import Path
4+
from typing import Union
5+
46
from codeflash.code_utils.code_utils import get_run_tmp_file
57

6-
def add_decorator_cst(module_node, function_path, decorator_name):
7-
"""
8-
Adds a decorator to a function or method definition in a LibCST module node.
98

10-
Args:
11-
module_node: LibCST module node
12-
function_path: String path to the function (e.g., 'function_name' or 'ClassName.method_name')
13-
decorator_name: Name of the decorator to add
14-
"""
15-
path_parts = function_path.split('.')
16-
17-
class AddDecoratorTransformer(cst.CSTTransformer):
18-
def __init__(self):
19-
super().__init__()
20-
self.current_class = None
21-
22-
def visit_ClassDef(self, node):
23-
# Track when we enter a class that matches our path
24-
if len(path_parts) > 1 and node.name.value == path_parts[0]:
25-
self.current_class = node.name.value
26-
return True
27-
28-
def leave_ClassDef(self, original_node, updated_node):
29-
# Reset class tracking when leaving a class node
30-
if self.current_class == original_node.name.value:
31-
self.current_class = None
32-
return updated_node
9+
class DecoratorAdder(cst.CSTTransformer):
10+
"""Transformer that adds a decorator to a function with a specific qualified name."""
3311

34-
def leave_FunctionDef(self, original_node, updated_node):
35-
# Handle standalone functions
36-
if len(path_parts) == 1 and original_node.name.value == path_parts[0] and self.current_class is None:
37-
return self._add_decorator(updated_node)
38-
# Handle class methods
39-
elif len(path_parts) == 2 and self.current_class == path_parts[0] and original_node.name.value == path_parts[1]:
40-
return self._add_decorator(updated_node)
41-
return updated_node
12+
def __init__(self, qualified_name: str, decorator_name: str):
13+
"""
14+
Initialize the transformer.
15+
16+
Args:
17+
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
18+
decorator_name: The name of the decorator to add.
19+
"""
20+
super().__init__()
21+
self.qualified_name_parts = qualified_name.split(".")
22+
self.decorator_name = decorator_name
23+
24+
# Track our current context path
25+
self.context_stack = []
26+
27+
def visit_ClassDef(self, node: cst.ClassDef) -> None:
28+
# Track when we enter a class
29+
self.context_stack.append(node.name.value)
30+
31+
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
32+
# Pop the context when we leave a class
33+
self.context_stack.pop()
34+
return updated_node
35+
36+
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
37+
# Track when we enter a function
38+
self.context_stack.append(node.name.value)
39+
40+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
41+
function_name = original_node.name.value
4242

43-
def _add_decorator(self, node):
44-
# Create and add the decorator
45-
new_decorator = cst.Decorator(
46-
decorator=cst.Name(value=decorator_name)
43+
# Check if the current context path matches our target qualified name
44+
if self._matches_qualified_path():
45+
# Check if the decorator is already present
46+
has_decorator = any(
47+
self._is_target_decorator(decorator.decorator)
48+
for decorator in original_node.decorators
4749
)
4850

49-
# Check if this decorator already exists
50-
for decorator in node.decorators:
51-
if (isinstance(decorator.decorator, cst.Name) and
52-
decorator.decorator.value == decorator_name):
53-
return node # Decorator already exists
51+
# Only add the decorator if it's not already there
52+
if not has_decorator:
53+
new_decorator = cst.Decorator(
54+
decorator=cst.Name(value=self.decorator_name)
55+
)
5456

55-
updated_decorators = list(node.decorators)
56-
updated_decorators.insert(0, new_decorator)
57+
# Add our new decorator to the existing decorators
58+
updated_decorators = [new_decorator] + list(updated_node.decorators)
59+
updated_node = updated_node.with_changes(
60+
decorators=tuple(updated_decorators)
61+
)
62+
63+
# Pop the context when we leave a function
64+
self.context_stack.pop()
65+
return updated_node
66+
67+
def _matches_qualified_path(self) -> bool:
68+
"""Check if the current context stack matches the qualified name."""
69+
if len(self.context_stack) != len(self.qualified_name_parts):
70+
return False
71+
72+
for i, name in enumerate(self.qualified_name_parts):
73+
if self.context_stack[i] != name:
74+
return False
75+
76+
return True
77+
78+
def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cst.Call]) -> bool:
79+
"""Check if a decorator matches our target decorator name."""
80+
if isinstance(decorator_node, cst.Name):
81+
return decorator_node.value == self.decorator_name
82+
elif isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
83+
return decorator_node.func.value == self.decorator_name
84+
return False
85+
86+
87+
def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
88+
"""
89+
Add a decorator to a function with the exact qualified name in the source code.
90+
91+
Args:
92+
module: The Python source code as a string.
93+
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
94+
decorator_name: The name of the decorator to add.
95+
96+
Returns:
97+
The modified source code as a string.
98+
"""
99+
# Parse the source code into a CST
57100

58-
return node.with_changes(decorators=updated_decorators)
101+
# Apply our transformer
102+
transformer = DecoratorAdder(qualified_name, decorator_name)
103+
modified_module = module.visit(transformer)
59104

60-
transformer = AddDecoratorTransformer()
61-
updated_module = module_node.visit(transformer)
62-
return updated_module
105+
# Convert the modified CST back to source code
106+
return modified_module
63107

64108
def add_profile_enable(original_code: str, db_file: str) -> str:
65109
module = cst.parse_module(original_code)
@@ -141,7 +185,7 @@ def add_decorator_imports(function_to_optimize, code_context):
141185
# parse to cst
142186
module_node = cst.parse_module(file_contents)
143187
# add decorator
144-
module_node = add_decorator_cst(module_node, fn_name, 'profile')
188+
module_node = add_decorator_to_qualified_function(module_node, fn_name, 'profile')
145189
# add imports
146190
# Create a transformer to add the import
147191
transformer = ImportAdder("from line_profiler import profile")

0 commit comments

Comments
 (0)