Skip to content

Commit d3788ec

Browse files
Merge pull request #874 from codeflash-ai/references-context
Python version and function references as additional context
2 parents befcc0b + 824c53f commit d3788ec

File tree

6 files changed

+62
-19
lines changed

6 files changed

+62
-19
lines changed

.github/workflows/unit-tests.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ jobs:
2828
- name: install dependencies
2929
run: uv sync
3030

31-
- name: Install test-only dependencies (Python 3.13)
32-
if: matrix.python-version == '3.13'
31+
- name: Install test-only dependencies (Python 3.9 and 3.13)
32+
if: matrix.python-version == '3.9' || matrix.python-version == '3.13'
3333
run: uv sync --group tests
3434

3535
- name: Unit tests

codeflash/api/aiservice.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,8 @@ def optimize_python_code_refinement(self, request: list[AIServiceRefinerRequest]
255255
"optimized_code_runtime": opt.optimized_code_runtime,
256256
"speedup": opt.speedup,
257257
"trace_id": opt.trace_id,
258+
"function_references": opt.function_references,
259+
"python_version": platform.python_version(),
258260
}
259261
for opt in request
260262
]
@@ -308,6 +310,7 @@ def get_new_explanation( # noqa: D417
308310
original_throughput: str | None = None,
309311
optimized_throughput: str | None = None,
310312
throughput_improvement: str | None = None,
313+
function_references: str | None = None,
311314
) -> str:
312315
"""Optimize the given python code for performance by making a request to the Django endpoint.
313316
@@ -327,6 +330,7 @@ def get_new_explanation( # noqa: D417
327330
- original_throughput: str | None - throughput for the baseline code (operations per second)
328331
- optimized_throughput: str | None - throughput for the optimized code (operations per second)
329332
- throughput_improvement: str | None - throughput improvement percentage
333+
- function_references: str | None - where the function is called in the codebase
330334
331335
Returns
332336
-------
@@ -349,6 +353,7 @@ def get_new_explanation( # noqa: D417
349353
"original_throughput": original_throughput,
350354
"optimized_throughput": optimized_throughput,
351355
"throughput_improvement": throughput_improvement,
356+
"function_references": function_references,
352357
}
353358
logger.info("loading|Generating explanation")
354359
console.rule()
@@ -373,7 +378,12 @@ def get_new_explanation( # noqa: D417
373378
return ""
374379

375380
def generate_ranking( # noqa: D417
376-
self, trace_id: str, diffs: list[str], optimization_ids: list[str], speedups: list[float]
381+
self,
382+
trace_id: str,
383+
diffs: list[str],
384+
optimization_ids: list[str],
385+
speedups: list[float],
386+
function_references: str | None = None,
377387
) -> list[int] | None:
378388
"""Optimize the given python code for performance by making a request to the Django endpoint.
379389
@@ -382,6 +392,7 @@ def generate_ranking( # noqa: D417
382392
- trace_id : unique uuid of function
383393
- diffs : list of unified diff strings of opt candidates
384394
- speedups : list of speedups of opt candidates
395+
- function_references : where the function is called in the codebase
385396
386397
Returns
387398
-------
@@ -394,6 +405,7 @@ def generate_ranking( # noqa: D417
394405
"speedups": speedups,
395406
"optimization_ids": optimization_ids,
396407
"python_version": platform.python_version(),
408+
"function_references": function_references,
397409
}
398410
logger.info("loading|Generating ranking")
399411
console.rule()
@@ -594,6 +606,7 @@ def get_optimization_review(
594606
"optimized_runtime": humanize_runtime(explanation.best_runtime_ns),
595607
"original_runtime": humanize_runtime(explanation.original_runtime_ns),
596608
"calling_fn_details": calling_fn_details,
609+
"python_version": platform.python_version(),
597610
}
598611
console.rule()
599612
try:

codeflash/code_utils/code_extractor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import ast
4+
import time
45
from dataclasses import dataclass
56
from itertools import chain
67
from pathlib import Path
@@ -1138,6 +1139,7 @@ def find_specific_function_in_file(
11381139
def get_fn_references_jedi(
11391140
source_code: str, file_path: Path, project_root: Path, target_function: str, target_class: str | None
11401141
) -> list[Path]:
1142+
start_time = time.perf_counter()
11411143
function_position: CodePosition = find_specific_function_in_file(
11421144
source_code, file_path, target_function, target_class
11431145
)
@@ -1146,6 +1148,8 @@ def get_fn_references_jedi(
11461148
# Get references to the function
11471149
references = script.get_references(line=function_position.line_no, column=function_position.col_no)
11481150
# Collect unique file paths where references are found
1151+
end_time = time.perf_counter()
1152+
logger.debug(f"Jedi for function references ran in {end_time - start_time:.2f} seconds")
11491153
reference_files = set()
11501154
for ref in references:
11511155
if ref.module_path:
@@ -1163,6 +1167,7 @@ def get_fn_references_jedi(
11631167
def get_opt_review_metrics(
11641168
source_code: str, file_path: Path, qualified_name: str, project_root: Path, tests_root: Path
11651169
) -> str:
1170+
start_time = time.perf_counter()
11661171
try:
11671172
qualified_name_split = qualified_name.rsplit(".", maxsplit=1)
11681173
if len(qualified_name_split) == 1:
@@ -1176,4 +1181,6 @@ def get_opt_review_metrics(
11761181
except Exception as e:
11771182
calling_fns_details = ""
11781183
logger.debug(f"Investigate {e}")
1184+
end_time = time.perf_counter()
1185+
logger.debug(f"Got function references in {end_time - start_time:.2f} seconds")
11791186
return calling_fns_details

codeflash/models/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ class AIServiceRefinerRequest:
4444
trace_id: str
4545
original_line_profiler_results: str
4646
optimized_line_profiler_results: str
47+
function_references: str | None = None
4748

4849

4950
# If the method spam is in the class Ham, which is at the top level of the module eggs in the package foo, the fully

codeflash/optimization/function_optimizer.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ def __init__(
244244
) = None
245245
n_tests = N_TESTS_TO_GENERATE_EFFECTIVE
246246
self.executor = concurrent.futures.ThreadPoolExecutor(
247-
max_workers=n_tests + 2 if self.experiment_id is None else n_tests + 3
247+
max_workers=n_tests + 3 if self.experiment_id is None else n_tests + 4
248248
)
249249

250250
def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
@@ -286,6 +286,7 @@ def generate_and_instrument_tests(
286286
list[Path],
287287
set[Path],
288288
dict | None,
289+
str,
289290
]
290291
]:
291292
"""Generate and instrument tests, returning all necessary data for optimization."""
@@ -323,9 +324,14 @@ def generate_and_instrument_tests(
323324

324325
generated_tests: GeneratedTestsList
325326
optimizations_set: OptimizationSet
326-
count_tests, generated_tests, function_to_concolic_tests, concolic_test_str, optimizations_set = (
327-
generated_results.unwrap()
328-
)
327+
(
328+
count_tests,
329+
generated_tests,
330+
function_to_concolic_tests,
331+
concolic_test_str,
332+
optimizations_set,
333+
function_references,
334+
) = generated_results.unwrap()
329335

330336
for i, generated_test in enumerate(generated_tests.generated_tests):
331337
with generated_test.behavior_file_path.open("w", encoding="utf8") as f:
@@ -371,6 +377,7 @@ def generate_and_instrument_tests(
371377
generated_perf_test_paths,
372378
instrumented_unittests_created_for_function,
373379
original_conftest_content,
380+
function_references,
374381
)
375382
)
376383

@@ -403,6 +410,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
403410
generated_perf_test_paths,
404411
instrumented_unittests_created_for_function,
405412
original_conftest_content,
413+
function_references,
406414
) = test_setup_result.unwrap()
407415

408416
baseline_setup_result = self.setup_and_establish_baseline(
@@ -437,6 +445,7 @@ def optimize_function(self) -> Result[BestOptimization, str]:
437445
generated_tests=generated_tests,
438446
test_functions_to_remove=test_functions_to_remove,
439447
concolic_test_str=concolic_test_str,
448+
function_references=function_references,
440449
)
441450

442451
# Add function to code context hash if in gh actions
@@ -458,6 +467,7 @@ def determine_best_candidate(
458467
original_helper_code: dict[Path, str],
459468
file_path_to_helper_classes: dict[Path, set[str]],
460469
exp_type: str,
470+
function_references: str,
461471
) -> BestOptimization | None:
462472
best_optimization: BestOptimization | None = None
463473
_best_runtime_until_now = original_code_baseline.runtime
@@ -667,6 +677,7 @@ def determine_best_candidate(
667677
else self.function_trace_id,
668678
ai_service_client=ai_service_client,
669679
executor=self.executor,
680+
function_references=function_references,
670681
)
671682
)
672683
else:
@@ -753,6 +764,7 @@ def determine_best_candidate(
753764
optimization_ids=optimization_ids,
754765
speedups=speedups_list,
755766
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
767+
function_references=function_references,
756768
)
757769
concurrent.futures.wait([future_ranking])
758770
ranking = future_ranking.result()
@@ -766,7 +778,7 @@ def determine_best_candidate(
766778
min_key = min(overall_ranking, key=overall_ranking.get)
767779
elif len(optimization_ids) == 1:
768780
min_key = 0 # only one candidate in valid _opts, already returns if there are no valid candidates
769-
else: # 0? shouldn't happen but it's there to escape potential bugs
781+
else: # 0? shouldn't happen, but it's there to escape potential bugs
770782
return None
771783
best_optimization = valid_candidates_with_shorter_code[min_key]
772784
# reassign code string which is the shortest
@@ -790,6 +802,7 @@ def refine_optimizations(
790802
trace_id: str,
791803
ai_service_client: AiServiceClient,
792804
executor: concurrent.futures.ThreadPoolExecutor,
805+
function_references: str | None = None,
793806
) -> concurrent.futures.Future:
794807
request = [
795808
AIServiceRefinerRequest(
@@ -804,6 +817,7 @@ def refine_optimizations(
804817
trace_id=trace_id,
805818
original_line_profiler_results=original_code_baseline.line_profile_results["str_out"],
806819
optimized_line_profiler_results=opt.line_profiler_test_results["str_out"],
820+
function_references=function_references,
807821
)
808822
for opt in valid_optimizations
809823
]
@@ -1089,7 +1103,7 @@ def generate_tests_and_optimizations(
10891103
generated_test_paths: list[Path],
10901104
generated_perf_test_paths: list[Path],
10911105
run_experiment: bool = False, # noqa: FBT001, FBT002
1092-
) -> Result[tuple[GeneratedTestsList, dict[str, set[FunctionCalledInTest]], OptimizationSet], str]:
1106+
) -> Result[tuple[GeneratedTestsList, dict[str, set[FunctionCalledInTest]], OptimizationSet], str, str]:
10931107
n_tests = N_TESTS_TO_GENERATE_EFFECTIVE
10941108
assert len(generated_test_paths) == n_tests
10951109
console.rule()
@@ -1116,7 +1130,15 @@ def generate_tests_and_optimizations(
11161130
future_concolic_tests = self.executor.submit(
11171131
generate_concolic_tests, self.test_cfg, self.args, self.function_to_optimize, self.function_to_optimize_ast
11181132
)
1119-
futures = [*future_tests, future_optimization_candidates, future_concolic_tests]
1133+
future_references = self.executor.submit(
1134+
get_opt_review_metrics,
1135+
self.function_to_optimize_source_code,
1136+
self.function_to_optimize.file_path,
1137+
self.function_to_optimize.qualified_name,
1138+
self.project_root,
1139+
self.test_cfg.tests_root,
1140+
)
1141+
futures = [*future_tests, future_optimization_candidates, future_concolic_tests, future_references]
11201142
if run_experiment:
11211143
future_candidates_exp = self.executor.submit(
11221144
self.local_aiservice_client.optimize_python_code,
@@ -1168,7 +1190,7 @@ def generate_tests_and_optimizations(
11681190
logger.warning(f"Failed to generate and instrument tests for {self.function_to_optimize.function_name}")
11691191
return Failure(f"/!\\ NO TESTS GENERATED for {self.function_to_optimize.function_name}")
11701192
function_to_concolic_tests, concolic_test_str = future_concolic_tests.result()
1171-
1193+
function_references = future_references.result()
11721194
count_tests = len(tests)
11731195
if concolic_test_str:
11741196
count_tests += 1
@@ -1182,6 +1204,7 @@ def generate_tests_and_optimizations(
11821204
function_to_concolic_tests,
11831205
concolic_test_str,
11841206
OptimizationSet(control=candidates, experiment=candidates_experiment),
1207+
function_references,
11851208
)
11861209
self.generate_and_instrument_tests_results = result
11871210
return Success(result)
@@ -1263,6 +1286,7 @@ def find_and_process_best_optimization(
12631286
generated_tests: GeneratedTestsList,
12641287
test_functions_to_remove: list[str],
12651288
concolic_test_str: str | None,
1289+
function_references: str,
12661290
) -> BestOptimization | None:
12671291
"""Find the best optimization candidate and process it with all required steps."""
12681292
best_optimization = None
@@ -1279,6 +1303,7 @@ def find_and_process_best_optimization(
12791303
original_helper_code=original_helper_code,
12801304
file_path_to_helper_classes=file_path_to_helper_classes,
12811305
exp_type=exp_type,
1306+
function_references=function_references,
12821307
)
12831308
ph(
12841309
"cli-optimize-function-finished",
@@ -1347,6 +1372,7 @@ def find_and_process_best_optimization(
13471372
exp_type,
13481373
original_helper_code,
13491374
code_context,
1375+
function_references,
13501376
)
13511377
return best_optimization
13521378

@@ -1364,6 +1390,7 @@ def process_review(
13641390
exp_type: str,
13651391
original_helper_code: dict[Path, str],
13661392
code_context: CodeOptimizationContext,
1393+
function_references: str,
13671394
) -> None:
13681395
coverage_message = (
13691396
original_code_baseline.coverage_results.build_message()
@@ -1430,6 +1457,7 @@ def process_review(
14301457
original_throughput=original_throughput_str,
14311458
optimized_throughput=optimized_throughput_str,
14321459
throughput_improvement=throughput_improvement_str,
1460+
function_references=function_references,
14331461
)
14341462
new_explanation = Explanation(
14351463
raw_explanation_message=new_explanation_raw_str or explanation.raw_explanation_message,
@@ -1466,16 +1494,9 @@ def process_review(
14661494
opt_review_response = ""
14671495
if raise_pr or staging_review:
14681496
data["root_dir"] = git_root_dir()
1469-
calling_fn_details = get_opt_review_metrics(
1470-
self.function_to_optimize_source_code,
1471-
self.function_to_optimize.file_path,
1472-
self.function_to_optimize.qualified_name,
1473-
self.project_root,
1474-
self.test_cfg.tests_root,
1475-
)
14761497
try:
14771498
opt_review_response = self.aiservice_client.get_optimization_review(
1478-
**data, calling_fn_details=calling_fn_details
1499+
**data, calling_fn_details=function_references
14791500
)
14801501
except Exception as e:
14811502
logger.debug(f"optimization review response failed, investigate {e}")

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ tests = [
9494
"scipy>=1.13.1",
9595
"torch>=2.8.0",
9696
"xarray>=2024.7.0",
97+
"eval_type_backport"
9798
]
9899

99100
[tool.hatch.build.targets.sdist]

0 commit comments

Comments (0)