Skip to content

Commit f3cba99

Browse files
committed
fix runtime calculations
1 parent f885467 commit f3cba99

File tree

2 files changed

+15
-20
lines changed

2 files changed

+15
-20
lines changed

codeflash/code_utils/edit_generated_tests.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
class CommentMapper(ast.NodeVisitor):
2323
def __init__(
24-
self, test: GeneratedTests, original_runtimes: dict[str, list[int]], optimized_runtimes: dict[str, list[int]]
24+
self, test: GeneratedTests, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]
2525
) -> None:
2626
self.results: dict[int, str] = {}
2727
self.test: GeneratedTests = test
@@ -56,8 +56,8 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
5656
match_key = key + "#" + inv_id
5757
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
5858
# calculate speedup and output comment
59-
original_time = min(self.original_runtimes[match_key])
60-
optimized_time = min(self.optimized_runtimes[match_key])
59+
original_time = self.original_runtimes[match_key]
60+
optimized_time = self.optimized_runtimes[match_key]
6161
perf_gain = format_perf(
6262
abs(
6363
performance_gain(
@@ -76,8 +76,8 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
7676
match_key = key + "#" + inv_id
7777
if match_key in self.original_runtimes and match_key in self.optimized_runtimes:
7878
# calculate speedup and output comment
79-
original_time = min(self.original_runtimes[match_key])
80-
optimized_time = min(self.optimized_runtimes[match_key])
79+
original_time = self.original_runtimes[match_key]
80+
optimized_time = self.optimized_runtimes[match_key]
8181
perf_gain = format_perf(
8282
abs(
8383
performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time)
@@ -96,7 +96,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
9696

9797

9898
def get_fn_call_linenos(
99-
test: GeneratedTests, original_runtimes: dict[str, list[int]], optimized_runtimes: dict[str, list[int]]
99+
test: GeneratedTests, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]
100100
) -> dict[int, str]:
101101
line_comment_ast_mapper = CommentMapper(test, original_runtimes, optimized_runtimes)
102102
source_code = test.generated_original_test_source
@@ -156,8 +156,8 @@ def leave_SimpleStatementSuite(
156156
return updated_node
157157

158158

159-
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]]) -> dict[str, list[int]]:
160-
unique_inv_ids: dict[str, list[int]] = {}
159+
def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]]) -> dict[str, int]:
160+
unique_inv_ids: dict[str, int] = {}
161161
for inv_id, runtimes in inv_id_runtimes.items():
162162
test_qualified_name = (
163163
inv_id.test_class_name + "." + inv_id.test_function_name # type: ignore[operator]
@@ -172,8 +172,8 @@ def unique_inv_id(inv_id_runtimes: dict[InvocationId, list[int]]) -> dict[str, l
172172
cur_invid = inv_id.iteration_id.split("_")[0] if parts < 3 else "_".join(inv_id.iteration_id.split("_")[:-1]) # type: ignore[union-attr]
173173
match_key = key + "#" + cur_invid
174174
if match_key not in unique_inv_ids:
175-
unique_inv_ids[match_key] = []
176-
unique_inv_ids[match_key].extend(runtimes)
175+
unique_inv_ids[match_key] = 0
176+
unique_inv_ids[match_key] += min(runtimes)
177177
return unique_inv_ids
178178

179179

tests/test_add_runtime_comments.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,7 +1619,7 @@ def test_runtime_comment_addition_for(self, test_config):
16191619
for i in range(3):
16201620
b = 3
16211621
b1 = 6
1622-
codeflash_output = bubble_sort([3, 1, 2]) # 500μs -> 300μs (66.7% faster)
1622+
codeflash_output = bubble_sort([3, 1, 2]) # 1.80ms -> 1.20ms (50.0% faster)
16231623
assert codeflash_output == [1, 2, 3]
16241624
c = 4
16251625
d = 5
@@ -1640,7 +1640,6 @@ def test_runtime_comment_addition_for(self, test_config):
16401640
# Add test invocations with different runtimes
16411641
original_invocation1 = self.create_test_invocation("test_bubble_sort", 500_000, iteration_id='1_2_0') # 500μs
16421642
optimized_invocation1 = self.create_test_invocation("test_bubble_sort", 300_000, iteration_id='1_2_0') # 300μs
1643-
# longer runtime than minimum, will not contribute
16441643
original_invocation2 = self.create_test_invocation("test_bubble_sort", 600_000, iteration_id='1_2_1') # 500μs
16451644
optimized_invocation2 = self.create_test_invocation("test_bubble_sort", 400_000, iteration_id='1_2_1') # 300μs
16461645
original_invocation3 = self.create_test_invocation("test_bubble_sort", 700_000, iteration_id='1_2_2') # 500μs
@@ -1680,7 +1679,7 @@ def test_runtime_comment_addition_while(self, test_config):
16801679
while i<3:
16811680
b = 3
16821681
b1 = 6
1683-
codeflash_output = bubble_sort([3, 1, 2]) # 500μs -> 300μs (66.7% faster)
1682+
codeflash_output = bubble_sort([3, 1, 2]) # 1.80ms -> 1.20ms (50.0% faster)
16841683
assert codeflash_output == [1, 2, 3]
16851684
i += 1
16861685
d = 5
@@ -1701,7 +1700,6 @@ def test_runtime_comment_addition_while(self, test_config):
17011700
# Add test invocations with different runtimes
17021701
original_invocation1 = self.create_test_invocation("test_bubble_sort", 500_000, iteration_id='1_2_0') # 500μs
17031702
optimized_invocation1 = self.create_test_invocation("test_bubble_sort", 300_000, iteration_id='1_2_0') # 300μs
1704-
# longer runtime than minimum, will not contribute
17051703
original_invocation2 = self.create_test_invocation("test_bubble_sort", 600_000, iteration_id='1_2_1') # 500μs
17061704
optimized_invocation2 = self.create_test_invocation("test_bubble_sort", 400_000, iteration_id='1_2_1') # 300μs
17071705
original_invocation3 = self.create_test_invocation("test_bubble_sort", 700_000, iteration_id='1_2_2') # 500μs
@@ -1741,7 +1739,7 @@ def test_runtime_comment_addition_with(self, test_config):
17411739
with open('a.txt','rb') as f:
17421740
b = 3
17431741
b1 = 6
1744-
codeflash_output = bubble_sort([3, 1, 2]) # 500μs -> 300μs (66.7% faster)
1742+
codeflash_output = bubble_sort([3, 1, 2]) # 1.80ms -> 1.20ms (50.0% faster)
17451743
assert codeflash_output == [1, 2, 5]
17461744
i += 1
17471745
d = 5
@@ -1762,7 +1760,6 @@ def test_runtime_comment_addition_with(self, test_config):
17621760
# Add test invocations with different runtimes
17631761
original_invocation1 = self.create_test_invocation("test_bubble_sort", 500_000, iteration_id='1_2_0') # 500μs
17641762
optimized_invocation1 = self.create_test_invocation("test_bubble_sort", 300_000, iteration_id='1_2_0') # 300μs
1765-
# longer runtime than minimum, will not contribute
17661763
original_invocation2 = self.create_test_invocation("test_bubble_sort", 600_000, iteration_id='1_2_1') # 500μs
17671764
optimized_invocation2 = self.create_test_invocation("test_bubble_sort", 400_000, iteration_id='1_2_1') # 300μs
17681765
original_invocation3 = self.create_test_invocation("test_bubble_sort", 700_000, iteration_id='1_2_2') # 500μs
@@ -1796,7 +1793,7 @@ def test_runtime_comment_addition_lc(self, test_config):
17961793
"""
17971794
expected = """def test_bubble_sort():
17981795
i = 0
1799-
codeflash_output = [bubble_sort([3, 1, 2]) for _ in range(3)] # 500μs -> 300μs (66.7% faster)
1796+
codeflash_output = [bubble_sort([3, 1, 2]) for _ in range(3)] # 1.80ms -> 1.20ms (50.0% faster)
18001797
assert codeflash_output == [[1,2,3],[1,2,3],[1,2,3]]
18011798
i += 1
18021799
d = 5
@@ -1817,7 +1814,6 @@ def test_runtime_comment_addition_lc(self, test_config):
18171814
# Add test invocations with different runtimes
18181815
original_invocation1 = self.create_test_invocation("test_bubble_sort", 500_000, iteration_id='1_0') # 500μs
18191816
optimized_invocation1 = self.create_test_invocation("test_bubble_sort", 300_000, iteration_id='1_0') # 300μs
1820-
# longer runtime than minimum, will not contribute
18211817
original_invocation2 = self.create_test_invocation("test_bubble_sort", 600_000, iteration_id='1_1') # 500μs
18221818
optimized_invocation2 = self.create_test_invocation("test_bubble_sort", 400_000, iteration_id='1_1') # 300μs
18231819
original_invocation3 = self.create_test_invocation("test_bubble_sort", 700_000, iteration_id='1_2') # 500μs
@@ -1867,7 +1863,7 @@ def test_bubble_sort(input, expected_output):
18671863
)
18681864
def test_bubble_sort(input, expected_output):
18691865
i = 0
1870-
codeflash_output = bubble_sort(input) # 500μs -> 300μs (66.7% faster)
1866+
codeflash_output = bubble_sort(input) # 1.80ms -> 1.20ms (50.0% faster)
18711867
assert codeflash_output == expected_output
18721868
i += 1
18731869
d = 5
@@ -1888,7 +1884,6 @@ def test_bubble_sort(input, expected_output):
18881884
# Add test invocations with different runtimes
18891885
original_invocation1 = self.create_test_invocation("test_bubble_sort", 500_000, iteration_id='1_0') # 500μs
18901886
optimized_invocation1 = self.create_test_invocation("test_bubble_sort", 300_000, iteration_id='1_0') # 300μs
1891-
# longer runtime than minimum, will not contribute
18921887
original_invocation2 = self.create_test_invocation("test_bubble_sort", 600_000, iteration_id='1_1') # 500μs
18931888
optimized_invocation2 = self.create_test_invocation("test_bubble_sort", 400_000, iteration_id='1_1') # 300μs
18941889
original_invocation3 = self.create_test_invocation("test_bubble_sort", 700_000, iteration_id='1_2') # 500μs

0 commit comments

Comments
 (0)