Skip to content

Commit 6a85377

Browse files
committed
Update tracer.py
1 parent dbac51b commit 6a85377

File tree

7 files changed

+250
-322
lines changed

7 files changed

+250
-322
lines changed

codeflash/api/aiservice.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def optimize_python_code_line_profiler(
172172

173173
logger.info("Generating optimized candidates…")
174174
console.rule()
175-
if line_profiler_results=="":
175+
if line_profiler_results == "":
176176
logger.info("No LineProfiler results were provided, Skipping optimization.")
177177
console.rule()
178178
return []
@@ -204,7 +204,6 @@ def optimize_python_code_line_profiler(
204204
console.rule()
205205
return []
206206

207-
208207
def log_results(
209208
self,
210209
function_trace_id: str,

codeflash/code_utils/line_profile_utils.py

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)"""
2+
23
from collections import defaultdict
34
from pathlib import Path
45
from typing import Union
@@ -12,7 +13,7 @@
1213
class LineProfilerDecoratorAdder(cst.CSTTransformer):
1314
"""Transformer that adds a decorator to a function with a specific qualified name."""
1415

15-
#TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
16+
# TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
1617
def __init__(self, qualified_name: str, decorator_name: str):
1718
"""Initialize the transformer.
1819
@@ -45,24 +46,19 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
4546
function_name = original_node.name.value
4647

4748
# Check if the current context path matches our target qualified name
48-
if self.context_stack==self.qualified_name_parts:
49+
if self.context_stack == self.qualified_name_parts:
4950
# Check if the decorator is already present
5051
has_decorator = any(
51-
self._is_target_decorator(decorator.decorator)
52-
for decorator in original_node.decorators
52+
self._is_target_decorator(decorator.decorator) for decorator in original_node.decorators
5353
)
5454

5555
# Only add the decorator if it's not already there
5656
if not has_decorator:
57-
new_decorator = cst.Decorator(
58-
decorator=cst.Name(value=self.decorator_name)
59-
)
57+
new_decorator = cst.Decorator(decorator=cst.Name(value=self.decorator_name))
6058

6159
# Add our new decorator to the existing decorators
6260
updated_decorators = [new_decorator] + list(updated_node.decorators)
63-
updated_node = updated_node.with_changes(
64-
decorators=tuple(updated_decorators)
65-
)
61+
updated_node = updated_node.with_changes(decorators=tuple(updated_decorators))
6662

6763
# Pop the context when we leave a function
6864
self.context_stack.pop()
@@ -76,8 +72,9 @@ def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cs
7672
return decorator_node.func.value == self.decorator_name
7773
return False
7874

75+
7976
class ProfileEnableTransformer(cst.CSTTransformer):
80-
def __init__(self,filename):
77+
def __init__(self, filename):
8178
# Flag to track if we found the import statement
8279
self.found_import = False
8380
# Track indentation of the import statement
@@ -86,12 +83,14 @@ def __init__(self,filename):
8683

8784
def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
8885
# Check if this is the line profiler import statement
89-
if (isinstance(original_node.module, cst.Name) and
90-
original_node.module.value == "line_profiler" and
91-
any(name.name.value == "profile" and
92-
(not name.asname or name.asname.name.value == "codeflash_line_profile")
93-
for name in original_node.names)):
94-
86+
if (
87+
isinstance(original_node.module, cst.Name)
88+
and original_node.module.value == "line_profiler"
89+
and any(
90+
name.name.value == "profile" and (not name.asname or name.asname.name.value == "codeflash_line_profile")
91+
for name in original_node.names
92+
)
93+
):
9594
self.found_import = True
9695
# Get the indentation from the original node
9796
if hasattr(original_node, "leading_lines"):
@@ -113,28 +112,31 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
113112
if isinstance(stmt, cst.SimpleStatementLine):
114113
for small_stmt in stmt.body:
115114
if isinstance(small_stmt, cst.ImportFrom):
116-
if (isinstance(small_stmt.module, cst.Name) and
117-
small_stmt.module.value == "line_profiler" and
118-
any(name.name.value == "profile" and
119-
(not name.asname or name.asname.name.value == "codeflash_line_profile")
120-
for name in small_stmt.names)):
115+
if (
116+
isinstance(small_stmt.module, cst.Name)
117+
and small_stmt.module.value == "line_profiler"
118+
and any(
119+
name.name.value == "profile"
120+
and (not name.asname or name.asname.name.value == "codeflash_line_profile")
121+
for name in small_stmt.names
122+
)
123+
):
121124
import_index = i
122125
break
123126
if import_index is not None:
124127
break
125128

126129
if import_index is not None:
127130
# Create the new enable statement to insert after the import
128-
enable_statement = cst.parse_statement(
129-
f"codeflash_line_profile.enable(output_prefix='{self.filename}')"
130-
)
131+
enable_statement = cst.parse_statement(f"codeflash_line_profile.enable(output_prefix='{self.filename}')")
131132

132133
# Insert the new statement after the import statement
133134
new_body.insert(import_index + 1, enable_statement)
134135

135136
# Create a new module with the updated body
136137
return updated_node.with_changes(body=new_body)
137138

139+
138140
def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
139141
"""Add a decorator to a function with the exact qualified name in the source code.
140142
@@ -156,6 +158,7 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
156158
# Convert the modified CST back to source code
157159
return modified_module
158160

161+
159162
def add_profile_enable(original_code: str, line_profile_output_file: str) -> str:
160163
# TODO modify by using a libcst transformer
161164
module = cst.parse_module(original_code)
@@ -178,9 +181,7 @@ def leave_Module(self, original_node, updated_node):
178181
import_node = cst.parse_statement(self.import_statement)
179182

180183
# Add the import to the module's body
181-
return updated_node.with_changes(
182-
body=[import_node] + list(updated_node.body)
183-
)
184+
return updated_node.with_changes(body=[import_node] + list(updated_node.body))
184185

185186
def visit_ImportFrom(self, node):
186187
# Check if the profile is already imported from line_profiler
@@ -192,15 +193,15 @@ def visit_ImportFrom(self, node):
192193

193194
def add_decorator_imports(function_to_optimize, code_context):
194195
"""Adds a profile decorator to a function in a Python file and all its helper functions."""
195-
#self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
196-
#grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
196+
# self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
197+
# grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
197198
file_paths = defaultdict(list)
198199
line_profile_output_file = get_run_tmp_file(Path("baseline_lprof"))
199200
file_paths[function_to_optimize.file_path].append(function_to_optimize.qualified_name)
200201
for elem in code_context.helper_functions:
201202
file_paths[elem.file_path].append(elem.qualified_name)
202-
for file_path,fns_present in file_paths.items():
203-
#open file
203+
for file_path, fns_present in file_paths.items():
204+
# open file
204205
file_contents = file_path.read_text("utf-8")
205206
# parse to cst
206207
module_node = cst.parse_module(file_contents)
@@ -216,8 +217,8 @@ def add_decorator_imports(function_to_optimize, code_context):
216217
# write to file
217218
with open(file_path, "w", encoding="utf-8") as file:
218219
file.write(modified_code)
219-
#Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
220+
# Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
220221
file_contents = function_to_optimize.file_path.read_text("utf-8")
221-
modified_code = add_profile_enable(file_contents,str(line_profile_output_file))
222-
function_to_optimize.file_path.write_text(modified_code,"utf-8")
222+
modified_code = add_profile_enable(file_contents, str(line_profile_output_file))
223+
function_to_optimize.file_path.write_text(modified_code, "utf-8")
223224
return line_profile_output_file

0 commit comments

Comments
 (0)