Skip to content

Commit de84558

Browse files
committed
Merge branch 'main' of https://github.com/codeflash-ai/codeflash into part-1-windows-fixes
2 parents 274f421 + ed6ffe4 commit de84558

File tree

7 files changed

+1129
-362
lines changed

7 files changed

+1129
-362
lines changed

codeflash/cli_cmds/cmd_init.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,7 @@ def collect_setup_info() -> SetupInfo:
231231
info_panel = Panel(
232232
Text(
233233
"📁 Let's identify your Python module directory.\n\n"
234-
"This is usually the top-level directory containing all your Python source code.\n"
235-
"We've automatically detected some directories for you.",
234+
"This is usually the top-level directory containing all your Python source code.\n",
236235
style="cyan",
237236
),
238237
title="🔍 Module Discovery",
@@ -936,7 +935,6 @@ def configure_pyproject_toml(setup_info: SetupInfo) -> None:
936935
codeflash_section["module-root"] = setup_info.module_root
937936
codeflash_section["tests-root"] = setup_info.tests_root
938937
codeflash_section["test-framework"] = setup_info.test_framework
939-
codeflash_section["benchmarks-root"] = setup_info.benchmarks_root if setup_info.benchmarks_root else ""
940938
codeflash_section["ignore-paths"] = setup_info.ignore_paths
941939
codeflash_section["disable-telemetry"] = not enable_telemetry
942940
if setup_info.git_remote not in ["", "origin"]:

codeflash/code_utils/edit_generated_tests.py

Lines changed: 156 additions & 205 deletions
Large diffs are not rendered by default.

codeflash/discovery/discover_unit_tests.py

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -624,45 +624,47 @@ def process_test_files(
624624
except Exception as e:
625625
logger.debug(str(e))
626626
continue
627-
628-
if not definition or definition[0].type != "function":
629-
continue
630-
631-
definition_obj = definition[0]
632-
definition_path = str(definition_obj.module_path)
633-
634-
project_root_str = str(project_root_path)
635-
if (
636-
definition_path.startswith(project_root_str + os.sep)
637-
and definition_obj.module_name != name.module_name
638-
and definition_obj.full_name is not None
639-
):
640-
# Pre-compute common values outside the inner loop
641-
module_prefix = definition_obj.module_name + "."
642-
full_name_without_module_prefix = definition_obj.full_name.replace(module_prefix, "", 1)
643-
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition_obj.module_path, project_root_path)}.{full_name_without_module_prefix}"
644-
645-
for test_func in test_functions_by_name[scope]:
646-
if test_func.parameters is not None:
647-
if test_framework == "pytest":
648-
scope_test_function = f"{test_func.function_name}[{test_func.parameters}]"
649-
else: # unittest
650-
scope_test_function = f"{test_func.function_name}_{test_func.parameters}"
651-
else:
652-
scope_test_function = test_func.function_name
653-
654-
function_to_test_map[qualified_name_with_modules_from_root].add(
655-
FunctionCalledInTest(
656-
tests_in_file=TestsInFile(
657-
test_file=test_file,
658-
test_class=test_func.test_class,
659-
test_function=scope_test_function,
660-
test_type=test_func.test_type,
661-
),
662-
position=CodePosition(line_no=name.line, col_no=name.column),
627+
try:
628+
if not definition or definition[0].type != "function":
629+
continue
630+
definition_obj = definition[0]
631+
definition_path = str(definition_obj.module_path)
632+
633+
project_root_str = str(project_root_path)
634+
if (
635+
definition_path.startswith(project_root_str + os.sep)
636+
and definition_obj.module_name != name.module_name
637+
and definition_obj.full_name is not None
638+
):
639+
# Pre-compute common values outside the inner loop
640+
module_prefix = definition_obj.module_name + "."
641+
full_name_without_module_prefix = definition_obj.full_name.replace(module_prefix, "", 1)
642+
qualified_name_with_modules_from_root = f"{module_name_from_file_path(definition_obj.module_path, project_root_path)}.{full_name_without_module_prefix}"
643+
644+
for test_func in test_functions_by_name[scope]:
645+
if test_func.parameters is not None:
646+
if test_framework == "pytest":
647+
scope_test_function = f"{test_func.function_name}[{test_func.parameters}]"
648+
else: # unittest
649+
scope_test_function = f"{test_func.function_name}_{test_func.parameters}"
650+
else:
651+
scope_test_function = test_func.function_name
652+
653+
function_to_test_map[qualified_name_with_modules_from_root].add(
654+
FunctionCalledInTest(
655+
tests_in_file=TestsInFile(
656+
test_file=test_file,
657+
test_class=test_func.test_class,
658+
test_function=scope_test_function,
659+
test_type=test_func.test_type,
660+
),
661+
position=CodePosition(line_no=name.line, col_no=name.column),
662+
)
663663
)
664-
)
665-
num_discovered_tests += 1
664+
num_discovered_tests += 1
665+
except Exception as e:
666+
logger.debug(str(e))
667+
continue
666668

667669
progress.advance(task_id)
668670

codeflash/optimization/function_optimizer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,15 +1012,20 @@ def find_and_process_best_optimization(
10121012
optimized_runtime_by_test = (
10131013
best_optimization.winning_benchmarking_test_results.usable_runtime_data_by_test_case()
10141014
)
1015+
qualified_name = self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root)
10151016
# Add runtime comments to generated tests before creating the PR
10161017
generated_tests = add_runtime_comments_to_generated_tests(
1017-
self.test_cfg, generated_tests, original_runtime_by_test, optimized_runtime_by_test
1018+
qualified_name,
1019+
self.test_cfg,
1020+
generated_tests,
1021+
original_runtime_by_test,
1022+
optimized_runtime_by_test,
10181023
)
10191024
generated_tests_str = "\n\n".join(
10201025
[test.generated_original_test_source for test in generated_tests.generated_tests]
10211026
)
10221027
existing_tests = existing_tests_source_for(
1023-
self.function_to_optimize.qualified_name_with_modules_from_root(self.project_root),
1028+
qualified_name,
10241029
function_to_all_tests,
10251030
test_cfg=self.test_cfg,
10261031
original_runtimes_all=original_runtime_by_test,

codeflash/result/create_pr.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,43 +42,39 @@ def existing_tests_source_for(
4242
rows = []
4343
headers = ["Test File::Test Function", "Original ⏱️", "Optimized ⏱️", "Speedup"]
4444
tests_root = test_cfg.tests_root
45-
module_root = test_cfg.project_root_path
46-
rel_tests_root = tests_root.relative_to(module_root)
4745
original_tests_to_runtimes: dict[Path, dict[str, int]] = {}
4846
optimized_tests_to_runtimes: dict[Path, dict[str, int]] = {}
4947
non_generated_tests = set()
5048
for test_file in test_files:
51-
non_generated_tests.add(Path(test_file.tests_in_file.test_file).relative_to(tests_root))
49+
non_generated_tests.add(test_file.tests_in_file.test_file)
5250
# TODO confirm that original and optimized have the same keys
5351
all_invocation_ids = original_runtimes_all.keys() | optimized_runtimes_all.keys()
5452
for invocation_id in all_invocation_ids:
55-
rel_path = (
56-
Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").relative_to(rel_tests_root)
57-
)
58-
if rel_path not in non_generated_tests:
53+
abs_path = Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve()
54+
if abs_path not in non_generated_tests:
5955
continue
60-
if rel_path not in original_tests_to_runtimes:
61-
original_tests_to_runtimes[rel_path] = {}
62-
if rel_path not in optimized_tests_to_runtimes:
63-
optimized_tests_to_runtimes[rel_path] = {}
56+
if abs_path not in original_tests_to_runtimes:
57+
original_tests_to_runtimes[abs_path] = {}
58+
if abs_path not in optimized_tests_to_runtimes:
59+
optimized_tests_to_runtimes[abs_path] = {}
6460
qualified_name = (
6561
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
6662
if invocation_id.test_class_name
6763
else invocation_id.test_function_name
6864
)
69-
if qualified_name not in original_tests_to_runtimes[rel_path]:
70-
original_tests_to_runtimes[rel_path][qualified_name] = 0 # type: ignore[index]
71-
if qualified_name not in optimized_tests_to_runtimes[rel_path]:
72-
optimized_tests_to_runtimes[rel_path][qualified_name] = 0 # type: ignore[index]
65+
if qualified_name not in original_tests_to_runtimes[abs_path]:
66+
original_tests_to_runtimes[abs_path][qualified_name] = 0 # type: ignore[index]
67+
if qualified_name not in optimized_tests_to_runtimes[abs_path]:
68+
optimized_tests_to_runtimes[abs_path][qualified_name] = 0 # type: ignore[index]
7369
if invocation_id in original_runtimes_all:
74-
original_tests_to_runtimes[rel_path][qualified_name] += min(original_runtimes_all[invocation_id]) # type: ignore[index]
70+
original_tests_to_runtimes[abs_path][qualified_name] += min(original_runtimes_all[invocation_id]) # type: ignore[index]
7571
if invocation_id in optimized_runtimes_all:
76-
optimized_tests_to_runtimes[rel_path][qualified_name] += min(optimized_runtimes_all[invocation_id]) # type: ignore[index]
72+
optimized_tests_to_runtimes[abs_path][qualified_name] += min(optimized_runtimes_all[invocation_id]) # type: ignore[index]
7773
# parse into string
78-
all_rel_paths = (
74+
all_abs_paths = (
7975
original_tests_to_runtimes.keys()
8076
) # both will have the same keys as some default values are assigned in the previous loop
81-
for filename in sorted(all_rel_paths):
77+
for filename in sorted(all_abs_paths):
8278
all_qualified_names = original_tests_to_runtimes[
8379
filename
8480
].keys() # both will have the same keys as some default values are assigned in the previous loop
@@ -90,6 +86,7 @@ def existing_tests_source_for(
9086
):
9187
print_optimized_runtime = format_time(optimized_tests_to_runtimes[filename][qualified_name])
9288
print_original_runtime = format_time(original_tests_to_runtimes[filename][qualified_name])
89+
print_filename = filename.relative_to(tests_root)
9390
greater = (
9491
optimized_tests_to_runtimes[filename][qualified_name]
9592
> original_tests_to_runtimes[filename][qualified_name]
@@ -104,7 +101,7 @@ def existing_tests_source_for(
104101
if greater:
105102
rows.append(
106103
[
107-
f"`{filename.as_posix()}::{qualified_name}`",
104+
f"`{print_filename.as_posix()}::{qualified_name}`",
108105
f"{print_original_runtime}",
109106
f"{print_optimized_runtime}",
110107
f"⚠️{perf_gain}%",
@@ -113,7 +110,7 @@ def existing_tests_source_for(
113110
else:
114111
rows.append(
115112
[
116-
f"`{filename.as_posix()}::{qualified_name}`",
113+
f"`{print_filename.as_posix()}::{qualified_name}`",
117114
f"{print_original_runtime}",
118115
f"{print_optimized_runtime}",
119116
f"✅{perf_gain}%",

0 commit comments

Comments
 (0)