Skip to content

Commit f54c1ad

Browse files
Merge pull request #35 from codeflash-ai/line-profiler
Integrate Line Profiler in Codeflash CF-470
2 parents 01b7e94 + 3cab8ea commit f54c1ad

File tree

12 files changed

+2077
-98
lines changed

12 files changed

+2077
-98
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from code_to_optimize.bubble_sort_in_class import BubbleSortClass
2+
3+
4+
def sort_classmethod(x):
5+
y = BubbleSortClass()
6+
return y.sorter(x)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from code_to_optimize.bubble_sort_in_nested_class import WrapperClass
2+
3+
4+
def sort_classmethod(x):
5+
y = WrapperClass.BubbleSortClass()
6+
return y.sorter(x)

codeflash/api/aiservice.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,76 @@ def optimize_python_code(
135135
console.rule()
136136
return []
137137

138+
def optimize_python_code_line_profiler(
139+
self,
140+
source_code: str,
141+
dependency_code: str,
142+
trace_id: str,
143+
line_profiler_results: str,
144+
num_candidates: int = 10,
145+
experiment_metadata: ExperimentMetadata | None = None,
146+
) -> list[OptimizedCandidate]:
147+
"""Optimize the given python code for performance by making a request to the Django endpoint.
148+
149+
Parameters
150+
----------
151+
- source_code (str): The python code to optimize.
152+
- dependency_code (str): The dependency code used as read-only context for the optimization
153+
- trace_id (str): Trace id of optimization run
154+
- num_candidates (int): Number of optimization variants to generate. Default is 10.
155+
- experiment_metadata (Optional[ExperimentalMetadata, None]): Any available experiment metadata for this optimization
156+
157+
Returns
158+
-------
159+
- List[OptimizationCandidate]: A list of Optimization Candidates.
160+
161+
"""
162+
payload = {
163+
"source_code": source_code,
164+
"dependency_code": dependency_code,
165+
"num_variants": num_candidates,
166+
"line_profiler_results": line_profiler_results,
167+
"trace_id": trace_id,
168+
"python_version": platform.python_version(),
169+
"experiment_metadata": experiment_metadata,
170+
"codeflash_version": codeflash_version,
171+
}
172+
173+
logger.info("Generating optimized candidates…")
174+
console.rule()
175+
if line_profiler_results=="":
176+
logger.info("No LineProfiler results were provided, Skipping optimization.")
177+
console.rule()
178+
return []
179+
try:
180+
response = self.make_ai_service_request("/optimize-line-profiler", payload=payload, timeout=600)
181+
except requests.exceptions.RequestException as e:
182+
logger.exception(f"Error generating optimized candidates: {e}")
183+
ph("cli-optimize-error-caught", {"error": str(e)})
184+
return []
185+
186+
if response.status_code == 200:
187+
optimizations_json = response.json()["optimizations"]
188+
logger.info(f"Generated {len(optimizations_json)} candidates.")
189+
console.rule()
190+
return [
191+
OptimizedCandidate(
192+
source_code=opt["source_code"],
193+
explanation=opt["explanation"],
194+
optimization_id=opt["optimization_id"],
195+
)
196+
for opt in optimizations_json
197+
]
198+
try:
199+
error = response.json()["error"]
200+
except Exception:
201+
error = response.text
202+
logger.error(f"Error generating optimized candidates: {response.status_code} - {error}")
203+
ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error})
204+
console.rule()
205+
return []
206+
207+
138208
def log_results(
139209
self,
140210
function_trace_id: str,
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)"""
2+
from collections import defaultdict
3+
from pathlib import Path
4+
from typing import Union
5+
6+
import isort
7+
import libcst as cst
8+
9+
from codeflash.code_utils.code_utils import get_run_tmp_file
10+
11+
12+
class LineProfilerDecoratorAdder(cst.CSTTransformer):
13+
"""Transformer that adds a decorator to a function with a specific qualified name."""
14+
15+
#TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
16+
def __init__(self, qualified_name: str, decorator_name: str):
17+
"""Initialize the transformer.
18+
19+
Args:
20+
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
21+
decorator_name: The name of the decorator to add.
22+
23+
"""
24+
super().__init__()
25+
self.qualified_name_parts = qualified_name.split(".")
26+
self.decorator_name = decorator_name
27+
28+
# Track our current context path, only add when we encounter a class
29+
self.context_stack = []
30+
31+
def visit_ClassDef(self, node: cst.ClassDef) -> None:
32+
# Track when we enter a class
33+
self.context_stack.append(node.name.value)
34+
35+
def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef:
36+
# Pop the context when we leave a class
37+
self.context_stack.pop()
38+
return updated_node
39+
40+
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
41+
# Track when we enter a function
42+
self.context_stack.append(node.name.value)
43+
44+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
45+
function_name = original_node.name.value
46+
47+
# Check if the current context path matches our target qualified name
48+
if self.context_stack==self.qualified_name_parts:
49+
# Check if the decorator is already present
50+
has_decorator = any(
51+
self._is_target_decorator(decorator.decorator)
52+
for decorator in original_node.decorators
53+
)
54+
55+
# Only add the decorator if it's not already there
56+
if not has_decorator:
57+
new_decorator = cst.Decorator(
58+
decorator=cst.Name(value=self.decorator_name)
59+
)
60+
61+
# Add our new decorator to the existing decorators
62+
updated_decorators = [new_decorator] + list(updated_node.decorators)
63+
updated_node = updated_node.with_changes(
64+
decorators=tuple(updated_decorators)
65+
)
66+
67+
# Pop the context when we leave a function
68+
self.context_stack.pop()
69+
return updated_node
70+
71+
def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cst.Call]) -> bool:
72+
"""Check if a decorator matches our target decorator name."""
73+
if isinstance(decorator_node, cst.Name):
74+
return decorator_node.value == self.decorator_name
75+
if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
76+
return decorator_node.func.value == self.decorator_name
77+
return False
78+
79+
class ProfileEnableTransformer(cst.CSTTransformer):
80+
def __init__(self,filename):
81+
# Flag to track if we found the import statement
82+
self.found_import = False
83+
# Track indentation of the import statement
84+
self.import_indentation = None
85+
self.filename = filename
86+
87+
def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
88+
# Check if this is the line profiler import statement
89+
if (isinstance(original_node.module, cst.Name) and
90+
original_node.module.value == "line_profiler" and
91+
any(name.name.value == "profile" and
92+
(not name.asname or name.asname.name.value == "codeflash_line_profile")
93+
for name in original_node.names)):
94+
95+
self.found_import = True
96+
# Get the indentation from the original node
97+
if hasattr(original_node, "leading_lines"):
98+
leading_whitespace = original_node.leading_lines[-1].whitespace if original_node.leading_lines else ""
99+
self.import_indentation = leading_whitespace
100+
101+
return updated_node
102+
103+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
104+
if not self.found_import:
105+
return updated_node
106+
107+
# Create a list of statements from the original module
108+
new_body = list(updated_node.body)
109+
110+
# Find the index of the import statement
111+
import_index = None
112+
for i, stmt in enumerate(new_body):
113+
if isinstance(stmt, cst.SimpleStatementLine):
114+
for small_stmt in stmt.body:
115+
if isinstance(small_stmt, cst.ImportFrom):
116+
if (isinstance(small_stmt.module, cst.Name) and
117+
small_stmt.module.value == "line_profiler" and
118+
any(name.name.value == "profile" and
119+
(not name.asname or name.asname.name.value == "codeflash_line_profile")
120+
for name in small_stmt.names)):
121+
import_index = i
122+
break
123+
if import_index is not None:
124+
break
125+
126+
if import_index is not None:
127+
# Create the new enable statement to insert after the import
128+
enable_statement = cst.parse_statement(
129+
f"codeflash_line_profile.enable(output_prefix='{self.filename}')"
130+
)
131+
132+
# Insert the new statement after the import statement
133+
new_body.insert(import_index + 1, enable_statement)
134+
135+
# Create a new module with the updated body
136+
return updated_node.with_changes(body=new_body)
137+
138+
def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
139+
"""Add a decorator to a function with the exact qualified name in the source code.
140+
141+
Args:
142+
module: The Python source code as a string.
143+
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
144+
decorator_name: The name of the decorator to add.
145+
146+
Returns:
147+
The modified source code as a string.
148+
149+
"""
150+
# Parse the source code into a CST
151+
152+
# Apply our transformer
153+
transformer = LineProfilerDecoratorAdder(qualified_name, decorator_name)
154+
modified_module = module.visit(transformer)
155+
156+
# Convert the modified CST back to source code
157+
return modified_module
158+
159+
def add_profile_enable(original_code: str, line_profile_output_file: str) -> str:
160+
# TODO modify by using a libcst transformer
161+
module = cst.parse_module(original_code)
162+
transformer = ProfileEnableTransformer(line_profile_output_file)
163+
modified_module = module.visit(transformer)
164+
return modified_module.code
165+
166+
167+
class ImportAdder(cst.CSTTransformer):
168+
def __init__(self, import_statement):
169+
self.import_statement = import_statement
170+
self.has_import = False
171+
172+
def leave_Module(self, original_node, updated_node):
173+
# If the import is already there, don't add it again
174+
if self.has_import:
175+
return updated_node
176+
177+
# Parse the import statement into a CST node
178+
import_node = cst.parse_statement(self.import_statement)
179+
180+
# Add the import to the module's body
181+
return updated_node.with_changes(
182+
body=[import_node] + list(updated_node.body)
183+
)
184+
185+
def visit_ImportFrom(self, node):
186+
# Check if the profile is already imported from line_profiler
187+
if node.module and node.module.value == "line_profiler":
188+
for import_alias in node.names:
189+
if import_alias.name.value == "profile":
190+
self.has_import = True
191+
192+
193+
def add_decorator_imports(function_to_optimize, code_context):
194+
"""Adds a profile decorator to a function in a Python file and all its helper functions."""
195+
#self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
196+
#grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
197+
file_paths = defaultdict(list)
198+
line_profile_output_file = get_run_tmp_file(Path("baseline_lprof"))
199+
file_paths[function_to_optimize.file_path].append(function_to_optimize.qualified_name)
200+
for elem in code_context.helper_functions:
201+
file_paths[elem.file_path].append(elem.qualified_name)
202+
for file_path,fns_present in file_paths.items():
203+
#open file
204+
file_contents = file_path.read_text("utf-8")
205+
# parse to cst
206+
module_node = cst.parse_module(file_contents)
207+
for fn_name in fns_present:
208+
# add decorator
209+
module_node = add_decorator_to_qualified_function(module_node, fn_name, "codeflash_line_profile")
210+
# add imports
211+
# Create a transformer to add the import
212+
transformer = ImportAdder("from line_profiler import profile as codeflash_line_profile")
213+
# Apply the transformer to add the import
214+
module_node = module_node.visit(transformer)
215+
modified_code = isort.code(module_node.code, float_to_top=True)
216+
# write to file
217+
with open(file_path, "w", encoding="utf-8") as file:
218+
file.write(modified_code)
219+
#Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
220+
file_contents = function_to_optimize.file_path.read_text("utf-8")
221+
modified_code = add_profile_enable(file_contents,str(line_profile_output_file))
222+
function_to_optimize.file_path.write_text(modified_code,"utf-8")
223+
return line_profile_output_file

0 commit comments

Comments
 (0)