Skip to content

Commit a364075

Browse files
committed
basic async critic - WIP
1 parent 2429801 commit a364075

File tree

5 files changed

+99
-7
lines changed

5 files changed

+99
-7
lines changed

codeflash/code_utils/codeflash_wrap_decorator.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,12 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
299299
else:
300300
async_wrapper.index[test_id] = 0
301301

302+
# Initialize cumulative throughput tracking
303+
if not hasattr(async_wrapper, "start_time"):
304+
async_wrapper.start_time = time.perf_counter()
305+
if not hasattr(async_wrapper, "total_operations"):
306+
async_wrapper.total_operations = 0
307+
302308
codeflash_test_index = async_wrapper.index[test_id]
303309
invocation_id = f"{line_id}_{codeflash_test_index}"
304310
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
@@ -319,7 +325,12 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
319325
finally:
320326
gc.enable()
321327

322-
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
328+
# Update cumulative throughput tracking
329+
async_wrapper.total_operations += 1
330+
elapsed_time = time.perf_counter() - async_wrapper.start_time
331+
throughput = async_wrapper.total_operations / elapsed_time if elapsed_time > 0 else 0
332+
333+
print(f"!######{test_stdout_tag}:{codeflash_duration}:throughput_{throughput:.2f}_ops_per_sec######!")
323334

324335
if exception:
325336
raise exception

codeflash/models/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,7 @@ class FunctionTestInvocation:
552552
timed_out: Optional[bool]
553553
verification_type: Optional[str] = VerificationType.FUNCTION_CALL
554554
stdout: Optional[str] = None
555+
throughput: Optional[float] = None # Operations per second
555556

556557
@property
557558
def unique_invocation_loop_id(self) -> str:

codeflash/optimization/function_optimizer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,11 @@ def determine_best_candidate(
566566
tree = Tree(f"Candidate #{candidate_index} - Runtime Information")
567567
benchmark_tree = None
568568
if speedup_critic(
569-
candidate_result, original_code_baseline.runtime, best_runtime_until_now=None
569+
candidate_result,
570+
original_code_baseline.runtime,
571+
None,
572+
self.function_to_optimize,
573+
original_baseline_results=original_code_baseline
570574
) and quantity_of_tests_critic(candidate_result):
571575
tree.add("This candidate is faster than the original code. 🚀") # TODO: Change this description
572576
tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")

codeflash/result/critic.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from codeflash.models.models import TestType
1313

1414
if TYPE_CHECKING:
15-
from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline
15+
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
16+
from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline, TestResults
1617

1718

1819
def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) -> float:
@@ -29,16 +30,29 @@ def speedup_critic(
2930
candidate_result: OptimizedCandidateResult,
3031
original_code_runtime: int,
3132
best_runtime_until_now: int | None,
33+
function_to_optimize: FunctionToOptimize,
3234
*,
3335
disable_gh_action_noise: bool = False,
36+
original_baseline_results: OriginalCodeBaseline | None = None,
3437
) -> bool:
3538
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.
3639
40+
For async functions, dispatches to async_speedup_critic for specialized evaluation.
41+
For sync functions, uses traditional runtime-only evaluation.
42+
3743
Ensure that the optimization is actually faster than the original code, above the noise floor.
3844
The noise floor is a function of the original code runtime. Currently, the noise floor is 2xMIN_IMPROVEMENT_THRESHOLD
3945
when the original runtime is less than 10 microseconds, and becomes MIN_IMPROVEMENT_THRESHOLD for any higher runtime.
4046
The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance, also we want to be more confident there.
4147
"""
48+
if function_to_optimize.is_async and original_baseline_results:
49+
return async_speedup_critic(
50+
candidate_result=candidate_result,
51+
original_baseline_results=original_baseline_results,
52+
best_runtime_until_now=best_runtime_until_now,
53+
disable_gh_action_noise=disable_gh_action_noise,
54+
)
55+
4256
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
4357
if not disable_gh_action_noise and env_utils.is_ci():
4458
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
@@ -47,11 +61,62 @@ def speedup_critic(
4761
original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime
4862
)
4963
if best_runtime_until_now is None:
50-
# collect all optimizations with this
5164
return bool(perf_gain > noise_floor)
5265
return bool(perf_gain > noise_floor and candidate_result.best_test_runtime < best_runtime_until_now)
5366

5467

68+
def async_speedup_critic(
69+
candidate_result: OptimizedCandidateResult,
70+
original_baseline_results: OriginalCodeBaseline,
71+
best_runtime_until_now: int | None,
72+
*,
73+
disable_gh_action_noise: bool = False,
74+
) -> bool:
75+
"""Simplified speedup evaluation for async functions with throughput-first approach.
76+
77+
For async functions:
78+
1. If throughput data exists and shows improvement, accept the optimization
79+
2. Otherwise, fall back to traditional runtime evaluation
80+
"""
81+
# Calculate noise floor with same logic as sync functions
82+
noise_floor = (
83+
3 * MIN_IMPROVEMENT_THRESHOLD if original_baseline_results.runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
84+
)
85+
if not disable_gh_action_noise and env_utils.is_ci():
86+
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
87+
88+
# Check for throughput improvement first
89+
candidate_throughput = _calculate_average_throughput(candidate_result.benchmarking_test_results)
90+
original_throughput = _calculate_average_throughput(original_baseline_results.benchmarking_test_results)
91+
92+
if original_throughput and original_throughput > 0 and candidate_throughput:
93+
throughput_gain = (candidate_throughput - original_throughput) / original_throughput
94+
if throughput_gain > noise_floor:
95+
# Throughput improved above noise floor - accept optimization
96+
return True if best_runtime_until_now is None else candidate_result.best_test_runtime < best_runtime_until_now
97+
98+
# Fall back to traditional runtime evaluation
99+
perf_gain = performance_gain(
100+
original_runtime_ns=original_baseline_results.runtime, optimized_runtime_ns=candidate_result.best_test_runtime
101+
)
102+
103+
if best_runtime_until_now is None:
104+
return bool(perf_gain > noise_floor)
105+
return bool(perf_gain > noise_floor and candidate_result.best_test_runtime < best_runtime_until_now)
106+
107+
108+
def _calculate_average_throughput(test_results: TestResults) -> float | None:
109+
"""Calculate average throughput from test results that have throughput data."""
110+
throughput_values = [
111+
result.throughput for result in test_results.test_results if result.throughput is not None and result.did_pass
112+
]
113+
114+
if not throughput_values:
115+
return None
116+
117+
return sum(throughput_values) / len(throughput_values)
118+
119+
55120
def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool:
56121
test_results = candidate_result.behavior_test_results
57122
report = test_results.get_test_pass_fail_report_by_type()

codeflash/verification/parse_test_output.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def parse_func(file_path: Path) -> XMLParser:
3737

3838

3939
matches_re_start = re.compile(r"!\$######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######\$!\n")
40-
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
40+
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)(?::throughput_([\d\.]+)_ops_per_sec)?######!")
4141

4242

4343
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
@@ -93,6 +93,7 @@ def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, tes
9393
return_value=test_pickle,
9494
timed_out=False,
9595
verification_type=VerificationType.FUNCTION_CALL,
96+
throughput=None,
9697
)
9798
)
9899
except Exception as e:
@@ -160,6 +161,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
160161
return_value=ret_val,
161162
timed_out=False,
162163
verification_type=VerificationType(verification_type) if verification_type else None,
164+
throughput=None,
163165
)
164166
)
165167
except Exception:
@@ -293,22 +295,27 @@ def parse_test_xml(
293295
return_value=None,
294296
timed_out=timed_out,
295297
stdout="",
298+
throughput=None,
296299
)
297300
)
298301

299302
else:
300303
for match_index, match in enumerate(begin_matches):
301304
groups = match.groups()
302305
end_match = end_matches.get(groups)
303-
iteration_id, runtime = groups[5], None
306+
iteration_id, runtime, throughput = groups[5], None, None
304307
if end_match:
305308
stdout = sys_stdout[match.end() : end_match.start()]
306-
split_val = end_match.groups()[5].split(":")
309+
end_groups = end_match.groups()
310+
split_val = end_groups[5].split(":")
307311
if len(split_val) > 1:
308312
iteration_id = split_val[0]
309313
runtime = int(split_val[1])
310314
else:
311315
iteration_id, runtime = split_val[0], None
316+
# Extract throughput if present (group 6 is the throughput capture group)
317+
if len(end_groups) > 6 and end_groups[6] is not None:
318+
throughput = float(end_groups[6])
312319
elif match_index == len(begin_matches) - 1:
313320
stdout = sys_stdout[match.end() :]
314321
else:
@@ -332,6 +339,7 @@ def parse_test_xml(
332339
return_value=None,
333340
timed_out=timed_out,
334341
stdout=stdout,
342+
throughput=throughput,
335343
)
336344
)
337345

@@ -420,6 +428,7 @@ def merge_test_results(
420428
if result_bin.verification_type
421429
else None,
422430
stdout=xml_result.stdout,
431+
throughput=None,
423432
)
424433
)
425434
elif xml_results.test_results[0].id.iteration_id is not None:
@@ -450,6 +459,7 @@ def merge_test_results(
450459
if bin_result.verification_type
451460
else None,
452461
stdout=xml_result.stdout,
462+
throughput=None,
453463
)
454464
)
455465
else:
@@ -477,6 +487,7 @@ def merge_test_results(
477487
if bin_result.verification_type
478488
else None,
479489
stdout=xml_result.stdout,
490+
throughput=None,
480491
)
481492
)
482493

0 commit comments

Comments
 (0)