Skip to content

Commit 6874dcb

Browse files
committed
linting mypy fixes
1 parent 66207b5 commit 6874dcb

File tree

4 files changed

+109
-94
lines changed

4 files changed

+109
-94
lines changed

codeflash/code_utils/edit_generated_tests.py

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -159,19 +159,10 @@ def leave_SimpleStatementLine(
159159
if invocation_id.test_class_name
160160
else invocation_id.test_function_name
161161
)
162-
rel_path = (
163-
Path(invocation_id.test_module_path.replace(".", os.sep))
164-
.with_suffix(".py")
165-
.resolve()
166-
.relative_to(self.tests_root)
167-
)
162+
abs_path = Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve()
168163
if (
169164
qualified_name == ".".join(self.context_stack)
170-
and rel_path
171-
in [
172-
self.test.behavior_file_path.relative_to(self.tests_root),
173-
self.test.perf_file_path.relative_to(self.tests_root),
174-
]
165+
and abs_path in [self.test.behavior_file_path, self.test.perf_file_path]
175166
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
176167
):
177168
matching_original_times.extend(runtimes)
@@ -183,19 +174,10 @@ def leave_SimpleStatementLine(
183174
if invocation_id.test_class_name
184175
else invocation_id.test_function_name
185176
)
186-
rel_path = (
187-
Path(invocation_id.test_module_path.replace(".", os.sep))
188-
.with_suffix(".py")
189-
.resolve()
190-
.relative_to(self.tests_root)
191-
)
177+
abs_path = Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve()
192178
if (
193179
qualified_name == ".".join(self.context_stack)
194-
and rel_path
195-
in [
196-
self.test.behavior_file_path.relative_to(self.tests_root),
197-
self.test.perf_file_path.relative_to(self.tests_root),
198-
]
180+
and abs_path in [self.test.behavior_file_path, self.test.perf_file_path]
199181
and int(invocation_id.iteration_id.split("_")[0]) == self.cfo_locs[self.cfo_idx_loc_to_look_at] # type:ignore[union-attr]
200182
):
201183
matching_optimized_times.extend(runtimes)

codeflash/result/create_pr.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -46,40 +46,35 @@ def existing_tests_source_for(
4646
optimized_tests_to_runtimes: dict[Path, dict[str, int]] = {}
4747
non_generated_tests = set()
4848
for test_file in test_files:
49-
non_generated_tests.add(Path(test_file.tests_in_file.test_file).relative_to(tests_root))
49+
non_generated_tests.add(test_file.tests_in_file.test_file)
5050
# TODO confirm that original and optimized have the same keys
5151
all_invocation_ids = original_runtimes_all.keys() | optimized_runtimes_all.keys()
5252
for invocation_id in all_invocation_ids:
53-
rel_path = (
54-
Path(invocation_id.test_module_path.replace(".", os.sep))
55-
.with_suffix(".py")
56-
.resolve()
57-
.relative_to(tests_root)
58-
)
59-
if rel_path not in non_generated_tests:
53+
abs_path = Path(invocation_id.test_module_path.replace(".", os.sep)).with_suffix(".py").resolve()
54+
if abs_path not in non_generated_tests:
6055
continue
61-
if rel_path not in original_tests_to_runtimes:
62-
original_tests_to_runtimes[rel_path] = {}
63-
if rel_path not in optimized_tests_to_runtimes:
64-
optimized_tests_to_runtimes[rel_path] = {}
56+
if abs_path not in original_tests_to_runtimes:
57+
original_tests_to_runtimes[abs_path] = {}
58+
if abs_path not in optimized_tests_to_runtimes:
59+
optimized_tests_to_runtimes[abs_path] = {}
6560
qualified_name = (
6661
invocation_id.test_class_name + "." + invocation_id.test_function_name # type: ignore[operator]
6762
if invocation_id.test_class_name
6863
else invocation_id.test_function_name
6964
)
70-
if qualified_name not in original_tests_to_runtimes[rel_path]:
71-
original_tests_to_runtimes[rel_path][qualified_name] = 0 # type: ignore[index]
72-
if qualified_name not in optimized_tests_to_runtimes[rel_path]:
73-
optimized_tests_to_runtimes[rel_path][qualified_name] = 0 # type: ignore[index]
65+
if qualified_name not in original_tests_to_runtimes[abs_path]:
66+
original_tests_to_runtimes[abs_path][qualified_name] = 0 # type: ignore[index]
67+
if qualified_name not in optimized_tests_to_runtimes[abs_path]:
68+
optimized_tests_to_runtimes[abs_path][qualified_name] = 0 # type: ignore[index]
7469
if invocation_id in original_runtimes_all:
75-
original_tests_to_runtimes[rel_path][qualified_name] += min(original_runtimes_all[invocation_id]) # type: ignore[index]
70+
original_tests_to_runtimes[abs_path][qualified_name] += min(original_runtimes_all[invocation_id]) # type: ignore[index]
7671
if invocation_id in optimized_runtimes_all:
77-
optimized_tests_to_runtimes[rel_path][qualified_name] += min(optimized_runtimes_all[invocation_id]) # type: ignore[index]
72+
optimized_tests_to_runtimes[abs_path][qualified_name] += min(optimized_runtimes_all[invocation_id]) # type: ignore[index]
7873
# parse into string
79-
all_rel_paths = (
74+
all_abs_paths = (
8075
original_tests_to_runtimes.keys()
8176
) # both will have the same keys as some default values are assigned in the previous loop
82-
for filename in sorted(all_rel_paths):
77+
for filename in sorted(all_abs_paths):
8378
all_qualified_names = original_tests_to_runtimes[
8479
filename
8580
].keys() # both will have the same keys as some default values are assigned in the previous loop
@@ -91,6 +86,7 @@ def existing_tests_source_for(
9186
):
9287
print_optimized_runtime = format_time(optimized_tests_to_runtimes[filename][qualified_name])
9388
print_original_runtime = format_time(original_tests_to_runtimes[filename][qualified_name])
89+
print_filename = filename.relative_to(tests_root)
9490
greater = (
9591
optimized_tests_to_runtimes[filename][qualified_name]
9692
> original_tests_to_runtimes[filename][qualified_name]
@@ -105,7 +101,7 @@ def existing_tests_source_for(
105101
if greater:
106102
rows.append(
107103
[
108-
f"`{filename}::{qualified_name}`",
104+
f"`{print_filename}::{qualified_name}`",
109105
f"{print_original_runtime}",
110106
f"{print_optimized_runtime}",
111107
f"⚠️{perf_gain}%",
@@ -114,7 +110,7 @@ def existing_tests_source_for(
114110
else:
115111
rows.append(
116112
[
117-
f"`{filename}::{qualified_name}`",
113+
f"`{print_filename}::{qualified_name}`",
118114
f"{print_original_runtime}",
119115
f"{print_optimized_runtime}",
120116
f"✅{perf_gain}%",

0 commit comments

Comments
 (0)