|
1 | 1 | from __future__ import annotations
2 | 2 |
|
| 3 | +import os |
3 | 4 | from pathlib import Path |
4 | 5 | from typing import TYPE_CHECKING, Optional |
5 | 6 |
|
|
16 | 17 | git_root_dir, |
17 | 18 | ) |
18 | 19 | from codeflash.code_utils.github_utils import github_pr_url |
| 20 | +from codeflash.code_utils.time_utils import format_time |
19 | 21 | from codeflash.github.PrComment import FileDiffContent, PrComment |
20 | 22 |
|
21 | 23 | if TYPE_CHECKING: |
22 | | - from codeflash.models.models import FunctionCalledInTest, TestResults |
| 24 | + from codeflash.models.models import FunctionCalledInTest |
23 | 25 | from codeflash.result.explanation import Explanation |
| 26 | + from codeflash.verification.verification_utils import TestConfig |
24 | 27 |
|
25 | 28 |
|
26 | 29 | def existing_tests_source_for( |
27 | 30 | function_qualified_name_with_modules_from_root: str, |
28 | 31 | function_to_tests: dict[str, set[FunctionCalledInTest]], |
29 | | - tests_root: Path, |
30 | | - original_test_results: Optional[TestResults] = None, |
31 | | - optimized_test_results: Optional[TestResults] = None, |
| 32 | + test_cfg: TestConfig, |
| 33 | + original_runtimes_all: dict, |
| 34 | + optimized_runtimes_all: dict, |
32 | 35 | ) -> str: |
33 | 36 | test_files = function_to_tests.get(function_qualified_name_with_modules_from_root) |
34 | 37 | if not test_files: |
35 | 38 | return "" |
36 | | - existing_tests_unique = set() |
37 | | - # a lot of loops, need to do in a single loop |
38 | | - #original_runtime_by_test = original_test_results.usable_runtime_data_by_test_case() |
39 | | - #optimized_runtime_by_test = optimized_test_results.usable_runtime_data_by_test_case() |
40 | | - # Group test cases by test file |
41 | | - test_files_grouped = {} |
42 | | - for test_file in test_files: |
43 | | - file_path = Path(test_file.tests_in_file.test_file) |
44 | | - relative_path = str(file_path.relative_to(tests_root)) |
45 | | - |
46 | | - if relative_path not in test_files_grouped: |
47 | | - test_files_grouped[relative_path] = [] |
48 | | - test_files_grouped.setdefault(relative_path,[]).append(test_file) |
49 | | - |
50 | | - # Create detailed report for each test file |
51 | | - # for relative_path, tests_in_file in sorted(test_files_grouped.items()): |
52 | | - file_line = f"- {relative_path}" |
53 | | - |
54 | | - # Add test case details with timing information if available |
55 | | - #if original_test_results and optimized_test_results: |
56 | | - test_case_details = [] |
57 | | - # Collect test function names for this file |
58 | | - test_functions_in_file = {test_file.tests_in_file.test_function for test_file in tests_in_file} |
59 | | - |
60 | | - # Create timing report for each test function |
61 | | - for test_function_name in sorted(test_functions_in_file): |
62 | | - # Find matching runtime data |
63 | | - original_runtimes = [] |
64 | | - optimized_runtimes = [] |
65 | | - |
66 | | - for invocation_id, runtimes in original_runtime_by_test.items(): |
67 | | - if invocation_id.test_function_name == test_function_name: |
68 | | - original_runtimes.extend(runtimes) |
69 | | - |
70 | | - for invocation_id, runtimes in optimized_runtime_by_test.items(): |
71 | | - if invocation_id.test_function_name == test_function_name: |
72 | | - optimized_runtimes.extend(runtimes) |
73 | | - |
74 | | - if original_runtimes and optimized_runtimes: |
75 | | - # Use minimum timing like the generated tests function does |
76 | | - original_time = min(original_runtimes) |
77 | | - optimized_time = min(optimized_runtimes) |
78 | | - |
79 | | - from codeflash.code_utils.time_utils import format_time |
80 | | - |
81 | | - original_str = format_time(original_time) |
82 | | - optimized_str = format_time(optimized_time) |
83 | | - |
84 | | - test_case_details.append(f" - {test_function_name}: {original_str} -> {optimized_str}") |
85 | | - |
86 | | - if test_case_details: |
87 | | - file_line += "\n" + "\n".join(test_case_details) |
88 | | - |
89 | | - existing_tests_unique.add(file_line) |
90 | | - |
91 | | - return "\n".join(sorted(existing_tests_unique)) |
| 39 | + output = "" |
| 40 | + tests_root = test_cfg.tests_root |
| 41 | + module_root = test_cfg.project_root_path |
| 42 | + rel_tests_root = tests_root.relative_to(module_root) |
| 43 | + original_tests_to_runtimes = {} |
| 44 | + optimized_tests_to_runtimes = {} |
| 45 | + # TODO confirm that original and optimized have the same keys |
| 46 | + all_invocation_ids = original_runtimes_all.keys() | optimized_runtimes_all.keys() |
| 47 | + for invocation_id in all_invocation_ids: |
| 48 | + rel_path = ( |
| 49 | + Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").relative_to(rel_tests_root) |
| 50 | + ) |
| 51 | + if rel_path not in original_tests_to_runtimes: |
| 52 | + original_tests_to_runtimes[rel_path] = {} |
| 53 | + if rel_path not in optimized_tests_to_runtimes: |
| 54 | + optimized_tests_to_runtimes[rel_path] = {} |
| 55 | + qualified_name = ( |
| 56 | + invocation_id.test_class_name + "." + invocation_id.test_function_name |
| 57 | + if invocation_id.test_class_name |
| 58 | + else invocation_id.test_function_name |
| 59 | + ) |
| 60 | + if qualified_name not in original_tests_to_runtimes[rel_path]: |
| 61 | + original_tests_to_runtimes[rel_path][qualified_name] = 0 |
| 62 | + if qualified_name not in optimized_tests_to_runtimes[rel_path]: |
| 63 | + optimized_tests_to_runtimes[rel_path][qualified_name] = 0 |
| 64 | + if invocation_id in original_runtimes_all: |
| 65 | + original_tests_to_runtimes[rel_path][qualified_name] += min(original_runtimes_all[invocation_id]) |
| 66 | + if invocation_id in optimized_runtimes_all: |
| 67 | + optimized_tests_to_runtimes[rel_path][qualified_name] += min(optimized_runtimes_all[invocation_id]) |
| 68 | + # parse into string |
| 69 | + all_rel_paths = ( |
| 70 | + original_tests_to_runtimes.keys() |
| 71 | + ) # both will have the same keys as some default values are assigned in the previous loop |
| 72 | + for filename in sorted(all_rel_paths): |
| 73 | + output += f"- {filename}\n" |
| 74 | + all_qualified_names = original_tests_to_runtimes[ |
| 75 | + filename |
| 76 | + ].keys() # both will have the same keys as some default values are assigned in the previous loop |
| 77 | + for qualified_name in sorted(all_qualified_names): |
| 78 | + # if not present in optimized output nan |
| 79 | + if optimized_tests_to_runtimes[filename][qualified_name] == 0: |
| 80 | + print_optimized_runtime = "NaN" |
| 81 | + else: |
| 82 | + print_optimized_runtime = format_time(optimized_tests_to_runtimes[filename][qualified_name]) |
| 83 | + if original_tests_to_runtimes[filename][qualified_name] == 0: |
| 84 | + print_original_runtime = "NaN" |
| 85 | + else: |
| 86 | + print_original_runtime = format_time(original_tests_to_runtimes[filename][qualified_name]) |
| 87 | + output += f" - {qualified_name}: {print_original_runtime} -> {print_optimized_runtime}\n" |
| 88 | + output += "\n" |
| 89 | + return output |
92 | 90 |
|
93 | 91 |
|
94 | 92 | def check_create_pr( |
|
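For context, below is a minimal, self-contained sketch of the per-test aggregation and report format that the rewritten `existing_tests_source_for` appears to produce. The `Invocation` dataclass, the `fmt_time` helper, the `summarize` function, and the example runtimes are hypothetical stand-ins invented for this sketch (the real code uses codeflash's `InvocationId`, `TestConfig`, and `format_time`); only the grouping logic and the "original -> optimized" report shape mirror the diff above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Invocation:
    """Hypothetical stand-in for codeflash's InvocationId; only the fields used here."""
    test_module_path: str            # e.g. "tests.test_sorter"
    test_class_name: Optional[str]   # e.g. "TestSorter", or None for module-level tests
    test_function_name: str          # e.g. "test_sort"


def fmt_time(ns: int) -> str:
    # Crude nanosecond formatter for this sketch only (stand-in for format_time).
    return f"{ns / 1e6:.2f}ms" if ns >= 1_000_000 else f"{ns / 1e3:.0f}us"


def summarize(original: dict, optimized: dict) -> str:
    # Mirrors the diff's aggregation: take the minimum runtime per invocation,
    # sum it per qualified test name, then group the report lines by test file.
    per_file: dict = {}
    for inv in original.keys() | optimized.keys():
        test_file = inv.test_module_path.replace(".", "/") + ".py"
        name = (
            f"{inv.test_class_name}.{inv.test_function_name}"
            if inv.test_class_name
            else inv.test_function_name
        )
        totals = per_file.setdefault(test_file, {}).setdefault(name, [0, 0])
        totals[0] += min(original.get(inv, [0]))   # 0 means "no runtime recorded"
        totals[1] += min(optimized.get(inv, [0]))
    out = ""
    for test_file in sorted(per_file):
        out += f"- {test_file}\n"
        for name, (orig_ns, opt_ns) in sorted(per_file[test_file].items()):
            orig_s = fmt_time(orig_ns) if orig_ns else "NaN"
            opt_s = fmt_time(opt_ns) if opt_ns else "NaN"
            out += f"  - {name}: {orig_s} -> {opt_s}\n"
        out += "\n"
    return out


inv = Invocation("tests.test_sorter", "TestSorter", "test_sort")
print(summarize({inv: [1_200_000, 1_150_000]}, {inv: [310_000, 295_000]}))
# - tests/test_sorter.py
#   - TestSorter.test_sort: 1.15ms -> 295us
```

The actual implementation additionally resolves each test module path relative to `test_cfg.tests_root` before grouping, and likewise prints `NaN` for a side with no recorded runtime, as shown in the diff.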