Skip to content

Commit a628dae

Browse files
committed
formatting changes
1 parent 7e764ec commit a628dae

File tree

3 files changed

+42
-41
lines changed

3 files changed

+42
-41
lines changed

codeflash/api/aiservice.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@ def optimize_python_code_line_profiler(
172172

173173
logger.info("Generating optimized candidates…")
174174
console.rule()
175+
if line_profiler_results=="":
176+
logger.info("No LineProfiler results were provided, Skipping optimization.")
177+
console.rule()
178+
return []
175179
try:
176180
response = self.make_ai_service_request("/optimize-line-profiler", payload=payload, timeout=600)
177181
except requests.exceptions.RequestException as e:

codeflash/code_utils/line_profile_utils.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
11
"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)"""
22
from collections import defaultdict
3+
from pathlib import Path
4+
from typing import Union
35

46
import isort
57
import libcst as cst
6-
from pathlib import Path
7-
from typing import Union, List
8-
from libcst import ImportFrom, ImportAlias, Name
98

109
from codeflash.code_utils.code_utils import get_run_tmp_file
1110

1211

1312
class LineProfilerDecoratorAdder(cst.CSTTransformer):
1413
"""Transformer that adds a decorator to a function with a specific qualified name."""
15-
#Todo we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
14+
15+
#TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure
1616
def __init__(self, qualified_name: str, decorator_name: str):
17-
"""
18-
Initialize the transformer.
17+
"""Initialize the transformer.
1918
2019
Args:
2120
qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func").
2221
decorator_name: The name of the decorator to add.
22+
2323
"""
2424
super().__init__()
2525
self.qualified_name_parts = qualified_name.split(".")
@@ -45,7 +45,7 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
4545
function_name = original_node.name.value
4646

4747
# Check if the current context path matches our target qualified name
48-
if self._matches_qualified_path():
48+
if self.context_stack==self.qualified_name_parts:
4949
# Check if the decorator is already present
5050
has_decorator = any(
5151
self._is_target_decorator(decorator.decorator)
@@ -68,22 +68,11 @@ def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.Fu
6868
self.context_stack.pop()
6969
return updated_node
7070

71-
def _matches_qualified_path(self) -> bool:
72-
"""Check if the current context stack matches the qualified name."""
73-
if len(self.context_stack) != len(self.qualified_name_parts):
74-
return False
75-
76-
for i, name in enumerate(self.qualified_name_parts):
77-
if self.context_stack[i] != name:
78-
return False
79-
80-
return True
81-
8271
def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cst.Call]) -> bool:
8372
"""Check if a decorator matches our target decorator name."""
8473
if isinstance(decorator_node, cst.Name):
8574
return decorator_node.value == self.decorator_name
86-
elif isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
75+
if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name):
8776
return decorator_node.func.value == self.decorator_name
8877
return False
8978

@@ -147,8 +136,7 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c
147136
return updated_node.with_changes(body=new_body)
148137

149138
def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
150-
"""
151-
Add a decorator to a function with the exact qualified name in the source code.
139+
"""Add a decorator to a function with the exact qualified name in the source code.
152140
153141
Args:
154142
module: The Python source code as a string.
@@ -157,6 +145,7 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
157145
158146
Returns:
159147
The modified source code as a string.
148+
160149
"""
161150
# Parse the source code into a CST
162151

@@ -168,7 +157,7 @@ def add_decorator_to_qualified_function(module, qualified_name, decorator_name):
168157
return modified_module
169158

170159
def add_profile_enable(original_code: str, line_profile_output_file: str) -> str:
171-
# todo modify by using a libcst transformer
160+
# TODO modify by using a libcst transformer
172161
module = cst.parse_module(original_code)
173162
transformer = ProfileEnableTransformer(line_profile_output_file)
174163
modified_module = module.visit(transformer)
@@ -217,7 +206,7 @@ def add_decorator_imports(function_to_optimize, code_context):
217206
module_node = cst.parse_module(file_contents)
218207
for fn_name in fns_present:
219208
# add decorator
220-
module_node = add_decorator_to_qualified_function(module_node, fn_name, 'codeflash_line_profile')
209+
module_node = add_decorator_to_qualified_function(module_node, fn_name, "codeflash_line_profile")
221210
# add imports
222211
# Create a transformer to add the import
223212
transformer = ImportAdder("from line_profiler import profile as codeflash_line_profile")
@@ -231,4 +220,4 @@ def add_decorator_imports(function_to_optimize, code_context):
231220
file_contents = function_to_optimize.file_path.read_text("utf-8")
232221
modified_code = add_profile_enable(file_contents,str(line_profile_output_file))
233222
function_to_optimize.file_path.write_text(modified_code,"utf-8")
234-
return line_profile_output_file
223+
return line_profile_output_file

codeflash/optimization/function_optimizer.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,8 @@ def optimize_function(self) -> Result[BestOptimization, str]:
232232
):
233233
cleanup_paths(paths_to_cleanup)
234234
return Failure("The threshold for test coverage was not met.")
235-
# request for new optimizations but don't block execution, check for completion later, only adding to control set right now
235+
# request for new optimizations but don't block execution, check for completion later
236+
# adding to control and experiment set but with same traceid
236237
best_optimization = None
237238

238239
for _u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
@@ -359,30 +360,36 @@ def determine_best_candidate(
359360
)
360361
console.rule()
361362
candidates = deque(candidates)
363+
# Start a new thread for AI service request, start loop in main thread
364+
# check if aiservice request is complete, when it is complete, append result to the candidates list
362365
with concurrent.futures.ThreadPoolExecutor() as executor:
363-
future_line_profile_results = executor.submit(self.aiservice_client.optimize_python_code_line_profiler,
364-
source_code=code_context.read_writable_code,
365-
dependency_code=code_context.read_only_context_code,
366-
trace_id=self.function_trace_id,
367-
line_profiler_results=original_code_baseline.line_profile_results['str_out'],
368-
num_candidates = 10,
369-
experiment_metadata = None)
366+
future_line_profile_results = executor.submit(
367+
self.aiservice_client.optimize_python_code_line_profiler,
368+
source_code=code_context.read_writable_code,
369+
dependency_code=code_context.read_only_context_code,
370+
trace_id=self.function_trace_id,
371+
line_profiler_results=original_code_baseline.line_profile_results["str_out"],
372+
num_candidates=10,
373+
experiment_metadata=None,
374+
)
370375
try:
371376
candidate_index = 0
372377
done = False
378+
original_len = len(candidates)
373379
while candidates:
374-
#for candidate_index, candidate in enumerate(candidates, start=1):
380+
# for candidate_index, candidate in enumerate(candidates, start=1):
375381
done = True if future_line_profile_results is None else future_line_profile_results.done()
376382
if done and (future_line_profile_results is not None):
377383
line_profile_results = future_line_profile_results.result()
378384
candidates.extend(line_profile_results)
379-
logger.info(f"Added result from line profiler to candidates: {len(line_profile_results)}")
385+
original_len+= len(candidates)
386+
logger.info(f"Added results from line profiler to candidates, total candidates now: {original_len}")
380387
future_line_profile_results = None
381388
candidate_index += 1
382389
candidate = candidates.popleft()
383390
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True)
384391
get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True)
385-
logger.info(f"Optimization candidate {candidate_index}/{len(candidates)}:")
392+
logger.info(f"Optimization candidate {candidate_index}/{original_len}:")
386393
code_print(candidate.source_code)
387394
try:
388395
did_update = self.replace_function_and_helpers_with_optimized_code(
@@ -397,7 +404,9 @@ def determine_best_candidate(
397404
except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e:
398405
logger.error(e)
399406
self.write_code_and_helpers(
400-
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
407+
self.function_to_optimize_source_code,
408+
original_helper_code,
409+
self.function_to_optimize.file_path,
401410
)
402411
continue
403412

@@ -781,7 +790,7 @@ def establish_original_code_baseline(
781790
original_helper_code: dict[Path, str],
782791
file_path_to_helper_classes: dict[Path, set[str]],
783792
) -> Result[tuple[OriginalCodeBaseline, list[str]], str]:
784-
line_profile_results = {'timings':{},'unit':0, 'str_out':''}
793+
line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
785794
# For the original function - run the tests and get the runtime, plus coverage
786795
with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"):
787796
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
@@ -826,8 +835,7 @@ def establish_original_code_baseline(
826835
return Failure("The threshold for test coverage was not met.")
827836
if test_framework == "pytest":
828837
try:
829-
line_profiler_output_file = add_decorator_imports(
830-
self.function_to_optimize, code_context)
838+
line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context)
831839
line_profile_results, _ = self.run_and_parse_tests(
832840
testing_type=TestingMode.LINE_PROFILE,
833841
test_env=test_env,
@@ -843,7 +851,7 @@ def establish_original_code_baseline(
843851
self.write_code_and_helpers(
844852
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
845853
)
846-
if line_profile_results['str_out']=='':
854+
if line_profile_results["str_out"] == "":
847855
logger.warning(
848856
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
849857
)
@@ -1081,7 +1089,7 @@ def run_and_parse_tests(
10811089
pytest_min_loops=1,
10821090
pytest_max_loops=1,
10831091
test_framework=self.test_cfg.test_framework,
1084-
line_profiler_output_file=line_profiler_output_file
1092+
line_profiler_output_file=line_profiler_output_file,
10851093
)
10861094
elif testing_type == TestingMode.PERFORMANCE:
10871095
result_file_path, run_result = run_benchmarking_tests(

0 commit comments

Comments
 (0)