Skip to content

Commit 89e72b7

Browse files
committed
refactored to make a new category for line profiler tests
1 parent 81a6b78 commit 89e72b7

File tree

8 files changed

+181
-100
lines changed

8 files changed

+181
-100
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from code_to_optimize.bubble_sort_in_class import BubbleSortClass
2+
3+
4+
def sort_classmethod(x):
5+
y = BubbleSortClass()
6+
return y.sorter(x)

codeflash/code_utils/lprof_utils.py

Lines changed: 61 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,59 @@
33
from pathlib import Path
44
from codeflash.code_utils.code_utils import get_run_tmp_file
55

6-
def add_decorator_cst(module_node, function_name, decorator_name):
7-
"""Adds a decorator to a function definition in a LibCST module node."""
6+
def add_decorator_cst(module_node, function_path, decorator_name):
7+
"""
8+
Adds a decorator to a function or method definition in a LibCST module node.
9+
10+
Args:
11+
module_node: LibCST module node
12+
function_path: String path to the function (e.g., 'function_name' or 'ClassName.method_name')
13+
decorator_name: Name of the decorator to add
14+
"""
15+
path_parts = function_path.split('.')
816

917
class AddDecoratorTransformer(cst.CSTTransformer):
18+
def __init__(self):
19+
super().__init__()
20+
self.current_class = None
21+
22+
def visit_ClassDef(self, node):
23+
# Track when we enter a class that matches our path
24+
if len(path_parts) > 1 and node.name.value == path_parts[0]:
25+
self.current_class = node.name.value
26+
return True
27+
28+
def leave_ClassDef(self, original_node, updated_node):
29+
# Reset class tracking when leaving a class node
30+
if self.current_class == original_node.name.value:
31+
self.current_class = None
32+
return updated_node
33+
1034
def leave_FunctionDef(self, original_node, updated_node):
11-
if original_node.name.value == function_name:
12-
new_decorator = cst.Decorator(
13-
decorator=cst.Name(value=decorator_name)
14-
)
35+
# Handle standalone functions
36+
if len(path_parts) == 1 and original_node.name.value == path_parts[0] and self.current_class is None:
37+
return self._add_decorator(updated_node)
38+
# Handle class methods
39+
elif len(path_parts) == 2 and self.current_class == path_parts[0] and original_node.name.value == path_parts[1]:
40+
return self._add_decorator(updated_node)
41+
return updated_node
1542

16-
updated_decorators = list(updated_node.decorators)
17-
updated_decorators.insert(0, new_decorator)
43+
def _add_decorator(self, node):
44+
# Create and add the decorator
45+
new_decorator = cst.Decorator(
46+
decorator=cst.Name(value=decorator_name)
47+
)
1848

19-
return updated_node.with_changes(decorators=updated_decorators)
20-
return updated_node
49+
# Check if this decorator already exists
50+
for decorator in node.decorators:
51+
if (isinstance(decorator.decorator, cst.Name) and
52+
decorator.decorator.value == decorator_name):
53+
return node # Decorator already exists
54+
55+
updated_decorators = list(node.decorators)
56+
updated_decorators.insert(0, new_decorator)
57+
58+
return node.with_changes(decorators=updated_decorators)
2159

2260
transformer = AddDecoratorTransformer()
2361
updated_module = module_node.visit(transformer)
@@ -83,21 +121,30 @@ def visit_ImportFrom(self, node):
83121
self.has_import = True
84122

85123

86-
def add_decorator_imports(file_paths, fn_list, db_file):
124+
def add_decorator_imports(function_to_optimize, code_context):
125+
#self.function_to_optimize, file_path_to_helper_classes, self.test_cfg.tests_root
126+
#todo change function signature to get filepaths of fn, helpers and db
127+
# modify libcst parser to visit with qualified name
128+
file_paths = list()
129+
fn_list = list()
130+
db_file = get_run_tmp_file(Path("baseline"))
131+
file_paths.append(function_to_optimize.file_path)
132+
fn_list.append(function_to_optimize.qualified_name)
133+
for elem in code_context.helper_functions:
134+
file_paths.append(elem.file_path)
135+
fn_list.append(elem.qualified_name)
87136
"""Adds a decorator to a function in a Python file."""
88137
for file_path, fn_name in zip(file_paths, fn_list):
89138
#open file
90139
with open(file_path, "r", encoding="utf-8") as file:
91140
file_contents = file.read()
92-
93141
# parse to cst
94142
module_node = cst.parse_module(file_contents)
95143
# add decorator
96144
module_node = add_decorator_cst(module_node, fn_name, 'profile')
97145
# add imports
98146
# Create a transformer to add the import
99147
transformer = ImportAdder("from line_profiler import profile")
100-
101148
# Apply the transformer to add the import
102149
module_node = module_node.visit(transformer)
103150
modified_code = isort.code(module_node.code, float_to_top=True)
@@ -110,6 +157,7 @@ def add_decorator_imports(file_paths, fn_list, db_file):
110157
modified_code = add_profile_enable(file_contents,db_file)
111158
with open(file_paths[0],'w') as f:
112159
f.write(modified_code)
160+
return db_file
113161

114162

115163
def prepare_lprofiler_files(prefix: str = "") -> tuple[Path]:

codeflash/models/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,9 +216,9 @@ class FunctionParent:
216216
class OriginalCodeBaseline(BaseModel):
217217
behavioral_test_results: TestResults
218218
benchmarking_test_results: TestResults
219+
lprofiler_test_results: str
219220
runtime: int
220221
coverage_results: Optional[CoverageData]
221-
lprof_results: str
222222

223223

224224
class CoverageStatus(Enum):
@@ -512,3 +512,4 @@ class FunctionCoverage:
512512
class TestingMode(enum.Enum):
513513
BEHAVIOR = "behavior"
514514
PERFORMANCE = "performance"
515+
LPROF = "lprof"

codeflash/optimization/function_optimizer.py

Lines changed: 49 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@
3737
)
3838
from codeflash.code_utils.formatter import format_code, sort_imports
3939
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
40+
from codeflash.code_utils.lprof_utils import add_decorator_imports
4041
from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests
4142
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
4243
from codeflash.code_utils.time_utils import humanize_runtime
43-
from codeflash.code_utils.lprof_utils import add_decorator_imports, prepare_lprofiler_files
4444
from codeflash.context import code_context_extractor
4545
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
4646
from codeflash.either import Failure, Success, is_successful
@@ -65,10 +65,10 @@
6565
from codeflash.verification.concolic_testing import generate_concolic_tests
6666
from codeflash.verification.equivalence import compare_test_results
6767
from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture
68-
from codeflash.verification.parse_test_output import parse_test_results
6968
from codeflash.verification.parse_lprof_test_output import parse_lprof_results
69+
from codeflash.verification.parse_test_output import parse_test_results
7070
from codeflash.verification.test_results import TestResults, TestType
71-
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests
71+
from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_lprof_tests
7272
from codeflash.verification.verification_utils import get_test_file_path
7373
from codeflash.verification.verifier import generate_tests
7474

@@ -78,7 +78,7 @@
7878
from codeflash.either import Result
7979
from codeflash.models.models import CoverageData, FunctionSource, OptimizedCandidate
8080
from codeflash.verification.verification_utils import TestConfig
81-
from collections import deque
81+
8282

8383
class FunctionOptimizer:
8484
def __init__(
@@ -209,6 +209,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
209209
and "." in function_source.qualified_name
210210
):
211211
file_path_to_helper_classes[function_source.file_path].add(function_source.qualified_name.split(".")[0])
212+
212213
baseline_result = self.establish_original_code_baseline( # this needs better typing
213214
code_context=code_context,
214215
original_helper_code=original_helper_code,
@@ -232,27 +233,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
232233
return Failure("The threshold for test coverage was not met.")
233234

234235
best_optimization = None
235-
lprof_generated_results = []
236-
logger.info(f"Adding more candidates based on lineprof info, calling ai service")
237-
with concurrent.futures.ThreadPoolExecutor(max_workers= N_TESTS_TO_GENERATE + 2) as executor:
238-
future_optimization_candidates_lp = executor.submit(self.aiservice_client.optimize_python_code_line_profiler,
239-
source_code=code_context.read_writable_code,
240-
dependency_code=code_context.read_only_context_code,
241-
trace_id=self.function_trace_id,
242-
line_profiler_results=original_code_baseline.lprof_results,
243-
num_candidates = 10,
244-
experiment_metadata = None)
245-
future = [future_optimization_candidates_lp]
246-
concurrent.futures.wait(future)
247-
lprof_generated_results = future[0].result()
248-
if len(lprof_generated_results)==0:
249-
logger.info(f"Generated tests with line profiler failed.")
250-
else:
251-
logger.info(f"Generated tests with line profiler succeeded. Appending to optimization candidates.")
252-
logger.info(f"initial optimization candidates: {len(optimizations_set.control)}")
253-
optimizations_set.control.extend(lprof_generated_results)
254-
logger.info(f"After adding optimization candidates: {len(optimizations_set.control)}")
255-
#append to optimization candidates
236+
256237
for _u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]):
257238
if candidates is None:
258239
continue
@@ -782,7 +763,7 @@ def establish_original_code_baseline(
782763
with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"):
783764
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
784765
success = True
785-
lprof_results = ''
766+
786767
test_env = os.environ.copy()
787768
test_env["CODEFLASH_TEST_ITERATION"] = "0"
788769
test_env["CODEFLASH_TRACER_DISABLE"] = "1"
@@ -793,7 +774,6 @@ def establish_original_code_baseline(
793774
test_env["PYTHONPATH"] += os.pathsep + str(self.args.project_root)
794775

795776
coverage_results = None
796-
lprofiler_results = None
797777
# Instrument codeflash capture
798778
try:
799779
instrument_codeflash_capture(
@@ -806,7 +786,6 @@ def establish_original_code_baseline(
806786
optimization_iteration=0,
807787
testing_time=TOTAL_LOOPING_TIME,
808788
enable_coverage=test_framework == "pytest",
809-
enable_lprofiler=False,
810789
code_context=code_context,
811790
)
812791
finally:
@@ -822,42 +801,30 @@ def establish_original_code_baseline(
822801
return Failure("Failed to establish a baseline for the original code - bevhavioral tests failed.")
823802
if not coverage_critic(coverage_results, self.args.test_framework):
824803
return Failure("The threshold for test coverage was not met.")
825-
#Running lprof now
826-
try:
827-
#add decorator here and import too
828-
lprofiler_database_file = prepare_lprofiler_files("baseline")
829-
#add decorator config to file, need to delete afterwards
830-
files_to_instrument = [self.function_to_optimize.file_path]
831-
fns_to_instrument = [self.function_to_optimize.function_name]
832-
for helper_obj in code_context.helper_functions:
833-
files_to_instrument.append(helper_obj.file_path)
834-
fns_to_instrument.append(helper_obj.qualified_name)
835-
add_decorator_imports(files_to_instrument,fns_to_instrument, lprofiler_database_file)
836-
#output doesn't matter, just need to run it
837-
lprof_cmd_results, _ = self.run_and_parse_tests(
838-
testing_type=TestingMode.BEHAVIOR,
839-
test_env=test_env,
840-
test_files=self.test_files,
841-
optimization_iteration=0,
842-
testing_time=TOTAL_LOOPING_TIME,
843-
enable_coverage=False,
844-
enable_lprofiler=test_framework == "pytest",
845-
code_context=code_context,
846-
lprofiler_database_file=lprofiler_database_file,
847-
)
848-
#real magic happens here
849-
lprof_results = parse_lprof_results(lprofiler_database_file)
850-
except Exception as e:
851-
logger.warning(f"Failed to run lprof for {self.function_to_optimize.function_name}. SKIPPING OPTIMIZING THIS FUNCTION.")
852-
console.rule()
853-
console.print(f"Failed to run lprof for {self.function_to_optimize.function_name}")
854-
console.rule()
855-
finally:
856-
# Remove decorators and lineprof import
857-
self.write_code_and_helpers(
858-
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
859-
)
860804
if test_framework == "pytest":
805+
try:
806+
lprofiler_database_file = add_decorator_imports(
807+
self.function_to_optimize, code_context)
808+
lprof_results, _ = self.run_and_parse_tests(
809+
testing_type=TestingMode.LPROF,
810+
test_env=test_env,
811+
test_files=self.test_files,
812+
optimization_iteration=0,
813+
testing_time=TOTAL_LOOPING_TIME,
814+
enable_coverage=False,
815+
code_context=code_context,
816+
lprofiler_database_file=lprofiler_database_file,
817+
)
818+
finally:
819+
# Remove codeflash capture
820+
self.write_code_and_helpers(
821+
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
822+
)
823+
if not lprof_results:
824+
logger.warning(
825+
f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}"
826+
)
827+
console.rule()
861828
benchmarking_results, _ = self.run_and_parse_tests(
862829
testing_type=TestingMode.PERFORMANCE,
863830
test_env=test_env,
@@ -867,6 +834,7 @@ def establish_original_code_baseline(
867834
enable_coverage=False,
868835
code_context=code_context,
869836
)
837+
870838
else:
871839
benchmarking_results = TestResults()
872840
start_time: float = time.time()
@@ -928,7 +896,7 @@ def establish_original_code_baseline(
928896
benchmarking_test_results=benchmarking_results,
929897
runtime=total_timing,
930898
coverage_results=coverage_results,
931-
lprof_results=lprof_results,
899+
lprofiler_test_results=lprof_results,
932900
),
933901
functions_to_remove,
934902
)
@@ -1060,12 +1028,11 @@ def run_and_parse_tests(
10601028
testing_time: float = TOTAL_LOOPING_TIME,
10611029
*,
10621030
enable_coverage: bool = False,
1063-
enable_lprofiler: bool = False,
10641031
pytest_min_loops: int = 5,
10651032
pytest_max_loops: int = 100_000,
10661033
code_context: CodeOptimizationContext | None = None,
10671034
unittest_loop_index: int | None = None,
1068-
lprofiler_database_file: str | None = None,
1035+
lprofiler_database_file: Path | None = None,
10691036
) -> tuple[TestResults, CoverageData | None]:
10701037
coverage_database_file = None
10711038
coverage_config_file = None
@@ -1079,7 +1046,19 @@ def run_and_parse_tests(
10791046
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
10801047
verbose=True,
10811048
enable_coverage=enable_coverage,
1082-
enable_lprofiler=enable_lprofiler,
1049+
)
1050+
elif testing_type == TestingMode.LPROF:
1051+
result_file_path, run_result = run_lprof_tests(
1052+
test_files,
1053+
cwd=self.project_root,
1054+
test_env=test_env,
1055+
pytest_cmd=self.test_cfg.pytest_cmd,
1056+
pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT,
1057+
pytest_target_runtime_seconds=testing_time,
1058+
pytest_min_loops=pytest_min_loops,
1059+
pytest_max_loops=pytest_max_loops,
1060+
test_framework=self.test_cfg.test_framework,
1061+
lprofiler_database_file=lprofiler_database_file
10831062
)
10841063
elif testing_type == TestingMode.PERFORMANCE:
10851064
result_file_path, run_result = run_benchmarking_tests(
@@ -1108,7 +1087,7 @@ def run_and_parse_tests(
11081087
f"stdout: {run_result.stdout}\n"
11091088
f"stderr: {run_result.stderr}\n"
11101089
)
1111-
if not enable_lprofiler:
1090+
if testing_type in [TestingMode.BEHAVIOR, TestingMode.PERFORMANCE]:
11121091
results, coverage_results = parse_test_results(
11131092
test_xml_path=result_file_path,
11141093
test_files=test_files,
@@ -1122,11 +1101,9 @@ def run_and_parse_tests(
11221101
coverage_database_file=coverage_database_file,
11231102
coverage_config_file=coverage_config_file,
11241103
)
1125-
return results, coverage_results
11261104
else:
1127-
#maintaining the function signature for the lprofiler
1128-
return TestResults(), None
1129-
1105+
results, coverage_results = parse_lprof_results(lprofiler_database_file=lprofiler_database_file)
1106+
return results, coverage_results
11301107

11311108
def generate_and_instrument_tests(
11321109
self,

codeflash/verification/parse_lprof_test_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,4 +107,4 @@ def parse_lprof_results(lprofiler_database_file: Path | None) -> str:
107107
else:
108108
with open(lprofiler_database_file,'rb') as f:
109109
stats = pickle.load(f)
110-
return show_text(stats)
110+
return show_text(stats), None

codeflash/verification/parse_test_output.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
)
2222
from codeflash.discovery.discover_unit_tests import discover_parameters_unittest
2323
from codeflash.models.models import CoverageData, TestFiles
24+
from codeflash.verification.parse_lprof_test_output import parse_lprof_results
2425
from codeflash.verification.test_results import (
2526
FunctionTestInvocation,
2627
InvocationId,

codeflash/verification/pytest_plugin.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import warnings
1313
from typing import TYPE_CHECKING, Any, Callable
1414
from unittest import TestCase
15-
import line_profiler
1615

1716
# PyTest Imports
1817
import pytest

0 commit comments

Comments
 (0)