
Commit a09d11c

Commit message: works

Parent: 0332d2b

File tree: 3 files changed (+72, -73 lines)


codeflash/code_utils/edit_generated_tests.py

Lines changed: 4 additions & 7 deletions
@@ -4,7 +4,7 @@

 from codeflash.cli_cmds.console import logger
 from codeflash.code_utils.time_utils import format_time
-from codeflash.models.models import GeneratedTests, GeneratedTestsList, TestResults
+from codeflash.models.models import GeneratedTests, GeneratedTestsList


 def remove_functions_from_generated_tests(
@@ -33,12 +33,9 @@ def remove_functions_from_generated_tests(


 def add_runtime_comments_to_generated_tests(
-    generated_tests: GeneratedTestsList, original_test_results: TestResults, optimized_test_results: TestResults
+    generated_tests: GeneratedTestsList, original_runtimes: dict, optimized_runtimes: dict
 ) -> GeneratedTestsList:
     """Add runtime performance comments to function calls in generated tests."""
-    # Create dictionaries for fast lookup of runtime data
-    original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case()
-    optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case()

     class RuntimeCommentTransformer(cst.CSTTransformer):
         def __init__(self) -> None:
@@ -84,11 +81,11 @@ def leave_SimpleStatementLine(
                 matching_original_times = []
                 matching_optimized_times = []

-                for invocation_id, runtimes in original_runtime_by_test.items():
+                for invocation_id, runtimes in original_runtimes.items():
                     if invocation_id.test_function_name == self.current_test_name:
                         matching_original_times.extend(runtimes)

-                for invocation_id, runtimes in optimized_runtime_by_test.items():
+                for invocation_id, runtimes in optimized_runtimes.items():
                     if invocation_id.test_function_name == self.current_test_name:
                         matching_optimized_times.extend(runtimes)

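Note on this change: add_runtime_comments_to_generated_tests no longer receives TestResults objects; callers now pass plain dicts mapping invocation IDs to lists of measured runtimes (the output of usable_runtime_data_by_test_case()). A minimal sketch of the matching logic the transformer applies, using a hypothetical namedtuple stand-in for the repository's invocation-ID model:

from collections import namedtuple

# Hypothetical stand-in: only the attribute the transformer matches on.
InvocationId = namedtuple("InvocationId", ["test_function_name"])

# Keys are invocation IDs, values are lists of measured runtimes
# (assumed to be nanosecond integers, as format_time suggests).
original_runtimes = {InvocationId("test_sort"): [1500, 1200]}
optimized_runtimes = {InvocationId("test_sort"): [800, 750]}

current_test_name = "test_sort"
matching_original_times = [t for inv, times in original_runtimes.items()
                           if inv.test_function_name == current_test_name for t in times]
matching_optimized_times = [t for inv, times in optimized_runtimes.items()
                            if inv.test_function_name == current_test_name for t in times]

# The transformer compares best-case (minimum) timings when writing comments.
print(min(matching_original_times), "->", min(matching_optimized_times))  # 1200 -> 750

Taking dicts keeps the transformer decoupled from TestResults, and lets the caller build the lookup tables once and reuse them (see the next file).
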
codeflash/optimization/function_optimizer.py

Lines changed: 10 additions & 6 deletions
@@ -354,21 +354,25 @@ def optimize_function(self) -> Result[BestOptimization, str]:  # noqa: PLR0911
             generated_tests = remove_functions_from_generated_tests(
                 generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
             )
+            original_runtime_by_test = (
+                original_code_baseline.benchmarking_test_results.usable_runtime_data_by_test_case()
+            )
+            optimized_runtime_by_test = (
+                best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
+            )
             # Add runtime comments to generated tests before creating the PR
             generated_tests = add_runtime_comments_to_generated_tests(
-                generated_tests,
-                original_code_baseline.benchmarking_test_results,
-                best_optimization.winning_benchmarking_test_results,
+                generated_tests, original_runtime_by_test, optimized_runtime_by_test
             )
             generated_tests_str = "\n\n".join(
                 [test.generated_original_test_source for test in generated_tests.generated_tests]
             )
             existing_tests = existing_tests_source_for(
                 self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
                 function_to_all_tests,
-                tests_root=self.test_cfg.tests_root,
-                original_test_results=original_code_baseline.benchmarking_test_results,
-                optimized_test_results=best_optimization.winning_benchmarking_test_results,
+                test_cfg=self.test_cfg,
+                original_runtimes_all=original_runtime_by_test,
+                optimized_runtimes_all=optimized_runtime_by_test,
             )
             if concolic_test_str:
                 generated_tests_str += "\n\n" + concolic_test_str
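
The point of this hunk is to hoist the usable_runtime_data_by_test_case() calls so the invocation-ID-to-runtimes tables are built once and shared by both add_runtime_comments_to_generated_tests and existing_tests_source_for. An illustrative sketch of the pattern (class and function names here are stand-ins, not the repository's API):

# Illustrative only: derive an expensive view once, then hand the result to
# every consumer, rather than letting each consumer re-derive it.
class BenchmarkResults:
    def __init__(self, samples: dict) -> None:
        self.samples = samples

    def by_test_case(self) -> dict:
        # Stand-in for usable_runtime_data_by_test_case(); imagine this
        # walk over raw results is costly.
        return {k: list(v) for k, v in self.samples.items()}

def annotate_tests(view: dict) -> int:  # consumer 1 (stand-in)
    return len(view)

def render_report(original: dict, optimized: dict) -> str:  # consumer 2 (stand-in)
    return f"{min(original['t'])} -> {min(optimized['t'])}"

baseline = BenchmarkResults({"t": [1500, 1200]})
winning = BenchmarkResults({"t": [800, 750]})

original_view = baseline.by_test_case()    # computed once...
optimized_view = winning.by_test_case()
annotate_tests(original_view)              # ...and reused by both consumers
print(render_report(original_view, optimized_view))  # 1200 -> 750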

codeflash/result/create_pr.py

Lines changed: 58 additions & 60 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations

+import os
 from pathlib import Path
 from typing import TYPE_CHECKING, Optional

@@ -16,79 +17,76 @@
     git_root_dir,
 )
 from codeflash.code_utils.github_utils import github_pr_url
+from codeflash.code_utils.time_utils import format_time
 from codeflash.github.PrComment import FileDiffContent, PrComment

 if TYPE_CHECKING:
-    from codeflash.models.models import FunctionCalledInTest, TestResults
+    from codeflash.models.models import FunctionCalledInTest
     from codeflash.result.explanation import Explanation
+    from codeflash.verification.verification_utils import TestConfig


 def existing_tests_source_for(
     function_qualified_name_with_modules_from_root: str,
     function_to_tests: dict[str, set[FunctionCalledInTest]],
-    tests_root: Path,
-    original_test_results: Optional[TestResults] = None,
-    optimized_test_results: Optional[TestResults] = None,
+    test_cfg: TestConfig,
+    original_runtimes_all: dict,
+    optimized_runtimes_all: dict,
 ) -> str:
     test_files = function_to_tests.get(function_qualified_name_with_modules_from_root)
     if not test_files:
         return ""
-    existing_tests_unique = set()
-    # a lot of loops, need to do in a single loop
-    # original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case()
-    # optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case()
-    # Group test cases by test file
-    test_files_grouped = {}
-    for test_file in test_files:
-        file_path = Path(test_file.tests_in_file.test_file)
-        relative_path = str(file_path.relative_to(tests_root))
-
-        if relative_path not in test_files_grouped:
-            test_files_grouped[relative_path] = []
-        test_files_grouped.setdefault(relative_path, []).append(test_file)
-
-    # Create detailed report for each test file
-    # for relative_path, tests_in_file in sorted(test_files_grouped.items()):
-    file_line = f"- {relative_path}"
-
-    # Add test case details with timing information if available
-    # if original_test_results and optimized_test_results:
-    test_case_details = []
-    # Collect test function names for this file
-    test_functions_in_file = {test_file.tests_in_file.test_function for test_file in tests_in_file}
-
-    # Create timing report for each test function
-    for test_function_name in sorted(test_functions_in_file):
-        # Find matching runtime data
-        original_runtimes = []
-        optimized_runtimes = []
-
-        for invocation_id, runtimes in original_runtime_by_test.items():
-            if invocation_id.test_function_name == test_function_name:
-                original_runtimes.extend(runtimes)
-
-        for invocation_id, runtimes in optimized_runtime_by_test.items():
-            if invocation_id.test_function_name == test_function_name:
-                optimized_runtimes.extend(runtimes)
-
-        if original_runtimes and optimized_runtimes:
-            # Use minimum timing like the generated tests function does
-            original_time = min(original_runtimes)
-            optimized_time = min(optimized_runtimes)
-
-            from codeflash.code_utils.time_utils import format_time
-
-            original_str = format_time(original_time)
-            optimized_str = format_time(optimized_time)
-
-            test_case_details.append(f" - {test_function_name}: {original_str} -> {optimized_str}")
-
-    if test_case_details:
-        file_line += "\n" + "\n".join(test_case_details)
-
-    existing_tests_unique.add(file_line)
-
-    return "\n".join(sorted(existing_tests_unique))
+    output = ""
+    tests_root = test_cfg.tests_root
+    module_root = test_cfg.project_root_path
+    rel_tests_root = tests_root.relative_to(module_root)
+    original_tests_to_runtimes = {}
+    optimized_tests_to_runtimes = {}
+    # TODO: confirm that original and optimized have the same keys
+    all_invocation_ids = original_runtimes_all.keys() | optimized_runtimes_all.keys()
+    for invocation_id in all_invocation_ids:
+        rel_path = (
+            Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").relative_to(rel_tests_root)
+        )
+        if rel_path not in original_tests_to_runtimes:
+            original_tests_to_runtimes[rel_path] = {}
+        if rel_path not in optimized_tests_to_runtimes:
+            optimized_tests_to_runtimes[rel_path] = {}
+        qualified_name = (
+            invocation_id.test_class_name + "." + invocation_id.test_function_name
+            if invocation_id.test_class_name
+            else invocation_id.test_function_name
+        )
+        if qualified_name not in original_tests_to_runtimes[rel_path]:
+            original_tests_to_runtimes[rel_path][qualified_name] = 0
+        if qualified_name not in optimized_tests_to_runtimes[rel_path]:
+            optimized_tests_to_runtimes[rel_path][qualified_name] = 0
+        if invocation_id in original_runtimes_all:
+            original_tests_to_runtimes[rel_path][qualified_name] += min(original_runtimes_all[invocation_id])
+        if invocation_id in optimized_runtimes_all:
+            optimized_tests_to_runtimes[rel_path][qualified_name] += min(optimized_runtimes_all[invocation_id])
+    # Render the collected runtimes as a nested bullet-list string.
+    all_rel_paths = (
+        original_tests_to_runtimes.keys()
+    )  # both dicts share the same keys: defaults for both are assigned in the loop above
+    for filename in sorted(all_rel_paths):
+        output += f"- {filename}\n"
+        all_qualified_names = original_tests_to_runtimes[
+            filename
+        ].keys()  # both dicts share the same keys: defaults for both are assigned in the loop above
+        for qualified_name in sorted(all_qualified_names):
+            # A total of 0 means no runtime was recorded on that side; print NaN.
+            if optimized_tests_to_runtimes[filename][qualified_name] == 0:
+                print_optimized_runtime = "NaN"
+            else:
+                print_optimized_runtime = format_time(optimized_tests_to_runtimes[filename][qualified_name])
+            if original_tests_to_runtimes[filename][qualified_name] == 0:
+                print_original_runtime = "NaN"
+            else:
+                print_original_runtime = format_time(original_tests_to_runtimes[filename][qualified_name])
+            output += f" - {qualified_name}: {print_original_runtime} -> {print_optimized_runtime}\n"
+        output += "\n"
+    return output


 def check_create_pr(
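
The rewritten existing_tests_source_for makes a single pass over the union of invocation IDs: each dotted test-module path is converted to a file path relative to the tests root, per-test minimum runtimes are summed under the test's qualified name, and the result is rendered as a nested list. A self-contained sketch of that aggregation, assuming nanosecond integer runtimes and a hypothetical namedtuple in place of the real invocation-ID model:

import os
from collections import namedtuple
from pathlib import Path

# Hypothetical stand-in carrying only the fields the new code reads.
Inv = namedtuple("Inv", ["test_module_path", "test_class_name", "test_function_name"])

runtimes_all = {
    Inv("tests.unit.test_calc", None, "test_add"): [900, 850],
    Inv("tests.unit.test_calc", "TestCalc", "test_mul"): [400],
}

totals: dict[Path, dict[str, int]] = {}
for inv, times in runtimes_all.items():
    # Dotted module path -> .py file path, relative to the tests root,
    # as the hunk does with rel_tests_root.
    rel_path = Path(inv.test_module_path.replace(".", os.sep)).with_suffix(".py").relative_to("tests")
    qualified = (f"{inv.test_class_name}.{inv.test_function_name}"
                 if inv.test_class_name else inv.test_function_name)
    per_file = totals.setdefault(rel_path, {})
    # Sum the best-case (minimum) runtime across invocations of each test.
    per_file[qualified] = per_file.get(qualified, 0) + min(times)

for filename in sorted(totals):
    print(f"- {filename}")
    for name in sorted(totals[filename]):
        print(f" - {name}: {totals[filename][name]} ns")

Summing minimums per invocation mirrors the generated-tests comments, which also compare best-case timings; a recorded total of 0 is treated as "no measurement" and printed as NaN in the real function.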
