Skip to content

Commit 7ffe25f

Browse files
committed
works, optimization list length will be incorrectly displayed as we are dynamically popping and appending to the list, right now appending to just the control opt candidates, need to refactor that
1 parent c00e324 commit 7ffe25f

File tree

12 files changed

+361
-2300
lines changed

12 files changed

+361
-2300
lines changed

code_to_optimize/bubble_sort_deps.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1-
from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer
2-
from code_to_optimize.bubble_sort_dep2_swap import dep2_swap
3-
4-
51
def sorter_deps(arr):
6-
for i in range(len(arr)):
7-
for j in range(len(arr) - 1):
8-
if dep1_comparer(arr, j):
9-
dep2_swap(arr, j)
2+
n = len(arr)
3+
for i in range(n):
4+
# We use a flag to check if the array is already sorted
5+
swapped = False
6+
# Reduce the range of j, since the last i elements are already sorted
7+
for j in range(n - 1 - i):
8+
if arr[j] > arr[j + 1]:
9+
# Swap without a helper function
10+
arr[j], arr[j + 1] = arr[j + 1], arr[j]
11+
swapped = True
12+
# If no elements were swapped in the inner loop, break
13+
if not swapped:
14+
break
1015
return arr
11-
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from code_to_optimize.bubble_sort_in_nested_class import WrapperClass
2+
from line_profiler import profile as codeflash_line_profile
3+
4+
5+
@codeflash_line_profile
6+
def sort_classmethod(x):
7+
y = WrapperClass.BubbleSortClass()
8+
return y.sorter(x)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from code_to_optimize.bubble_sort_deps import sorter_deps
2+
3+
4+
def test_sort():
5+
input = [5, 4, 3, 2, 1, 0]
6+
output = sorter_deps(input)
7+
assert output == [0, 1, 2, 3, 4, 5]
8+
9+
input = [5.0, 4.0, 3.0, 2.0, 1.0, 0.0]
10+
output = sorter_deps(input)
11+
assert output == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
12+
13+
input = list(reversed(range(5000)))
14+
output = sorter_deps(input)
15+
assert output == list(range(5000))

codeflash/code_utils/lprof_utils.py renamed to codeflash/code_utils/line_profile_utils.py

Lines changed: 88 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
1+
"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)"""
2+
from collections import defaultdict
3+
14
import isort
25
import libcst as cst
36
from pathlib import Path
4-
from typing import Union
7+
from typing import Union, List
8+
from libcst import ImportFrom, ImportAlias, Name
59

610
from codeflash.code_utils.code_utils import get_run_tmp_file
711

812

9-
class DecoratorAdder(cst.CSTTransformer):
13+
class LineProfilerDecoratorAdder(cst.CSTTransformer):
1014
"""Transformer that adds a decorator to a function with a specific qualified name."""
11-
15+
#Todo we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
1216
def __init__(self, qualified_name: str, decorator_name: str):
1317
"""
1418
Initialize the transformer.
@@ -21,7 +25,7 @@ def __init__(self, qualified_name: str, decorator_name: str):
2125
self.qualified_name_parts = qualified_name.split(".")
2226
self.decorator_name = decorator_name
2327

24-
# Track our current context path
28+
# Track our current context path, only add when we encounter a class
2529
self.context_stack = []
2630

2731
def visit_ClassDef(self, node: cst.ClassDef) -> None:
@@ -83,6 +87,64 @@ def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cs
8387
return decorator_node.func.value == self.decorator_name
8488
return False
8589

90+
class ProfileEnableTransformer(cst.CSTTransformer):
91+
def __init__(self,filename):
92+
# Flag to track if we found the import statement
93+
self.found_import = False
94+
# Track indentation of the import statement
95+
self.import_indentation = None
96+
self.filename = filename
97+
98+
def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
99+
# Check if this is the line profiler import statement
100+
if (isinstance(original_node.module, cst.Name) and
101+
original_node.module.value == "line_profiler" and
102+
any(name.name.value == "profile" and
103+
(not name.asname or name.asname.name.value == "codeflash_line_profile")
104+
for name in original_node.names)):
105+
106+
self.found_import = True
107+
# Get the indentation from the original node
108+
if hasattr(original_node, "leading_lines"):
109+
leading_whitespace = original_node.leading_lines[-1].whitespace if original_node.leading_lines else ""
110+
self.import_indentation = leading_whitespace
111+
112+
return updated_node
113+
114+
def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module:
115+
if not self.found_import:
116+
return updated_node
117+
118+
# Create a list of statements from the original module
119+
new_body = list(updated_node.body)
120+
121+
# Find the index of the import statement
122+
import_index = None
123+
for i, stmt in enumerate(new_body):
124+
if isinstance(stmt, cst.SimpleStatementLine):
125+
for small_stmt in stmt.body:
126+
if isinstance(small_stmt, cst.ImportFrom):
127+
if (isinstance(small_stmt.module, cst.Name) and
128+
small_stmt.module.value == "line_profiler" and
129+
any(name.name.value == "profile" and
130+
(not name.asname or name.asname.name.value == "codeflash_line_profile")
131+
for name in small_stmt.names)):
132+
import_index = i
133+
break
134+
if import_index is not None:
135+
break
136+
137+
if import_index is not None:
138+
# Create the new enable statement to insert after the import
139+
enable_statement = cst.parse_statement(
140+
f"codeflash_line_profile.enable(output_prefix='{self.filename}')"
141+
)
142+
143+
# Insert the new statement after the import statement
144+
new_body.insert(import_index + 1, enable_statement)
145+
146+
# Create a new module with the updated body
147+
return updated_node.with_changes(body=new_body)
86148

87149
def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
88150
"""
@@ -99,48 +161,22 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
99161
# Parse the source code into a CST
100162

101163
# Apply our transformer
102-
transformer = DecoratorAdder(qualified_name, decorator_name)
164+
transformer = LineProfilerDecoratorAdder(qualified_name, decorator_name)
103165
modified_module = module.visit(transformer)
104166

105167
# Convert the modified CST back to source code
106168
return modified_module
107169

108-
def add_profile_enable(original_code: str, db_file: str) -> str:
170+
def add_profile_enable(original_code: str, line_profile_output_file: str) -> str:
171+
# todo modify by using a libcst transformer
109172
module = cst.parse_module(original_code)
110-
found_index = -1
111-
112-
for idx, statement in enumerate(module.body):
113-
if isinstance(statement, cst.SimpleStatementLine):
114-
for stmt in statement.body:
115-
if isinstance(stmt, cst.ImportFrom):
116-
if stmt.module and stmt.module.value == 'line_profiler':
117-
for name in stmt.names:
118-
if isinstance(name, cst.ImportAlias):
119-
if name.name.value == 'profile' and name.asname is None:
120-
found_index = idx
121-
break
122-
if found_index != -1:
123-
break
124-
if found_index != -1:
125-
break
126-
127-
if found_index == -1:
128-
return original_code # or raise an exception if the import is not found
129-
130-
# Create the new line to insert
131-
new_line = f"profile.enable(output_prefix='{db_file}')\n"
132-
new_statement = cst.parse_statement(new_line)
133-
134-
# Insert the new statement into the module's body
135-
new_body = list(module.body)
136-
new_body.insert(found_index + 1, new_statement)
137-
modified_module = module.with_changes(body=new_body)
138-
173+
transformer = ProfileEnableTransformer(line_profile_output_file)
174+
modified_module = module.visit(transformer)
139175
return modified_module.code
140176

141177

142178
class ImportAdder(cst.CSTTransformer):
143-
def __init__(self, import_statement='from line_profiler import profile'):
179+
def __init__(self, import_statement):
144180
self.import_statement = import_statement
145181
self.has_import = False
146182

@@ -166,45 +202,36 @@ def visit_ImportFrom(self, node):
166202

167203

168204
def add_decorator_imports(function_to_optimize, code_context):
205+
"""Adds a profile decorator to a function in a Python file and all its helper functions."""
169206
#self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
170-
#todo change function signature to get filepaths of fn, helpers and db
171-
# modify libcst parser to visit with qualified name
172-
file_paths = list()
173-
fn_list = list()
174-
db_file = get_run_tmp_file(Path("baseline"))
175-
file_paths.append(function_to_optimize.file_path)
176-
fn_list.append(function_to_optimize.qualified_name)
207+
#grouped iteration, file to fns to optimize, from line_profiler import profile as codeflash_line_profile
208+
file_paths = defaultdict(list)
209+
line_profile_output_file = get_run_tmp_file(Path("baseline_lprof"))
210+
file_paths[function_to_optimize.file_path].append(function_to_optimize.qualified_name)
177211
for elem in code_context.helper_functions:
178-
file_paths.append(elem.file_path)
179-
fn_list.append(elem.qualified_name)
180-
"""Adds a decorator to a function in a Python file."""
181-
for file_path, fn_name in zip(file_paths, fn_list):
212+
file_paths[elem.file_path].append(elem.qualified_name)
213+
for file_path,fns_present in file_paths.items():
182214
#open file
183215
with open(file_path, "r", encoding="utf-8") as file:
184216
file_contents = file.read()
185217
# parse to cst
186218
module_node = cst.parse_module(file_contents)
187-
# add decorator
188-
module_node = add_decorator_to_qualified_function(module_node, fn_name, 'profile')
219+
for fn_name in fns_present:
220+
# add decorator
221+
module_node = add_decorator_to_qualified_function(module_node, fn_name, 'codeflash_line_profile')
189222
# add imports
190223
# Create a transformer to add the import
191-
transformer = ImportAdder("from line_profiler import profile")
224+
transformer = ImportAdder("from line_profiler import profile as codeflash_line_profile")
192225
# Apply the transformer to add the import
193226
module_node = module_node.visit(transformer)
194227
modified_code = isort.code(module_node.code, float_to_top=True)
195228
# write to file
196229
with open(file_path, "w", encoding="utf-8") as file:
197230
file.write(modified_code)
198-
#Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files, can use libcst seems like an overkill, will go just with some simple string manipulation
199-
with open(file_paths[0],'r') as f:
231+
#Adding profile.enable line for changing the savepath of the data, do this only for the main file and not the helper files
232+
with open(function_to_optimize.file_path,'r') as f:
200233
file_contents = f.read()
201-
modified_code = add_profile_enable(file_contents,db_file)
202-
with open(file_paths[0],'w') as f:
234+
modified_code = add_profile_enable(file_contents,str(line_profile_output_file))
235+
with open(function_to_optimize.file_path,'w') as f:
203236
f.write(modified_code)
204-
return db_file
205-
206-
207-
def prepare_lprofiler_files(prefix: str = "") -> tuple[Path]:
208-
"""Prepare line profiler output file."""
209-
lprofiler_database_file = get_run_tmp_file(Path(prefix))
210-
return lprofiler_database_file
237+
return line_profile_output_file

0 commit comments

Comments
 (0)