Skip to content

Commit a6072b9

Browse files
committed
Merge branch 'granular-async-instrumentation' of https://github.com/codeflash-ai/codeflash into async-support-for
2 parents 52a5a0c + ce421d4 commit a6072b9

File tree

5 files changed

+7
-103
lines changed

5 files changed

+7
-103
lines changed

codeflash/code_utils/codeflash_wrap_decorator.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,6 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
299299
else:
300300
async_wrapper.index[test_id] = 0
301301

302-
# Initialize cumulative throughput tracking
303-
if not hasattr(async_wrapper, "start_time"):
304-
async_wrapper.start_time = time.perf_counter()
305-
if not hasattr(async_wrapper, "total_operations"):
306-
async_wrapper.total_operations = 0
307-
308302
codeflash_test_index = async_wrapper.index[test_id]
309303
invocation_id = f"{line_id}_{codeflash_test_index}"
310304
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
@@ -325,12 +319,7 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
325319
finally:
326320
gc.enable()
327321

328-
# Update cumulative throughput tracking
329-
async_wrapper.total_operations += 1
330-
elapsed_time = time.perf_counter() - async_wrapper.start_time
331-
throughput = async_wrapper.total_operations / elapsed_time if elapsed_time > 0 else 0
332-
333-
print(f"!######{test_stdout_tag}:{codeflash_duration}:throughput_{throughput:.2f}_ops_per_sec######!")
322+
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
334323

335324
if exception:
336325
raise exception

codeflash/models/models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,6 @@ class FunctionTestInvocation:
552552
timed_out: Optional[bool]
553553
verification_type: Optional[str] = VerificationType.FUNCTION_CALL
554554
stdout: Optional[str] = None
555-
throughput: Optional[float] = None # Operations per second
556555

557556
@property
558557
def unique_invocation_loop_id(self) -> str:

codeflash/optimization/function_optimizer.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -566,11 +566,7 @@ def determine_best_candidate(
566566
tree = Tree(f"Candidate #{candidate_index} - Runtime Information")
567567
benchmark_tree = None
568568
if speedup_critic(
569-
candidate_result,
570-
original_code_baseline.runtime,
571-
None,
572-
self.function_to_optimize,
573-
original_baseline_results=original_code_baseline,
569+
candidate_result, original_code_baseline.runtime, best_runtime_until_now=None
574570
) and quantity_of_tests_critic(candidate_result):
575571
tree.add("This candidate is faster than the original code. 🚀") # TODO: Change this description
576572
tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")

codeflash/result/critic.py

Lines changed: 2 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
from codeflash.models.models import TestType
1313

1414
if TYPE_CHECKING:
15-
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
16-
from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline, TestResults
15+
from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline
1716

1817

1918
def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) -> float:
@@ -30,29 +29,16 @@ def speedup_critic(
3029
candidate_result: OptimizedCandidateResult,
3130
original_code_runtime: int,
3231
best_runtime_until_now: int | None,
33-
function_to_optimize: FunctionToOptimize,
3432
*,
3533
disable_gh_action_noise: bool = False,
36-
original_baseline_results: OriginalCodeBaseline | None = None,
3734
) -> bool:
3835
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.
3936
40-
For async functions, dispatches to async_speedup_critic for specialized evaluation.
41-
For sync functions, uses traditional runtime-only evaluation.
42-
4337
Ensure that the optimization is actually faster than the original code, above the noise floor.
4438
The noise floor is a function of the original code runtime. Currently, the noise floor is 3xMIN_IMPROVEMENT_THRESHOLD
4539
when the original runtime is less than 10 microseconds, and becomes MIN_IMPROVEMENT_THRESHOLD for any higher runtime.
4640
The noise floor is doubled when benchmarking on a (noisy) GitHub Actions virtual instance, because we want to be more confident there.
4741
"""
48-
if function_to_optimize.is_async and original_baseline_results:
49-
return async_speedup_critic(
50-
candidate_result=candidate_result,
51-
original_baseline_results=original_baseline_results,
52-
best_runtime_until_now=best_runtime_until_now,
53-
disable_gh_action_noise=disable_gh_action_noise,
54-
)
55-
5642
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
5743
if not disable_gh_action_noise and env_utils.is_ci():
5844
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
@@ -61,64 +47,11 @@ def speedup_critic(
6147
original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime
6248
)
6349
if best_runtime_until_now is None:
50+
# No best runtime to compare against: accept every candidate whose gain exceeds the noise floor
6451
return bool(perf_gain > noise_floor)
6552
return bool(perf_gain > noise_floor and candidate_result.best_test_runtime < best_runtime_until_now)
6653

6754

68-
def async_speedup_critic(
69-
candidate_result: OptimizedCandidateResult,
70-
original_baseline_results: OriginalCodeBaseline,
71-
best_runtime_until_now: int | None,
72-
*,
73-
disable_gh_action_noise: bool = False,
74-
) -> bool:
75-
"""Simplified speedup evaluation for async functions with throughput-first approach.
76-
77-
For async functions:
78-
1. If throughput data exists and shows improvement, accept the optimization
79-
2. Otherwise, fall back to traditional runtime evaluation
80-
"""
81-
# Calculate noise floor with same logic as sync functions
82-
noise_floor = (
83-
3 * MIN_IMPROVEMENT_THRESHOLD if original_baseline_results.runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
84-
)
85-
if not disable_gh_action_noise and env_utils.is_ci():
86-
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
87-
88-
# Check for throughput improvement first
89-
candidate_throughput = _calculate_average_throughput(candidate_result.benchmarking_test_results)
90-
original_throughput = _calculate_average_throughput(original_baseline_results.benchmarking_test_results)
91-
92-
if original_throughput and original_throughput > 0 and candidate_throughput:
93-
throughput_gain = (candidate_throughput - original_throughput) / original_throughput
94-
if throughput_gain > noise_floor:
95-
# Throughput improved above noise floor - accept optimization
96-
return (
97-
True if best_runtime_until_now is None else candidate_result.best_test_runtime < best_runtime_until_now
98-
)
99-
100-
# Fall back to traditional runtime evaluation
101-
perf_gain = performance_gain(
102-
original_runtime_ns=original_baseline_results.runtime, optimized_runtime_ns=candidate_result.best_test_runtime
103-
)
104-
105-
if best_runtime_until_now is None:
106-
return bool(perf_gain > noise_floor)
107-
return bool(perf_gain > noise_floor and candidate_result.best_test_runtime < best_runtime_until_now)
108-
109-
110-
def _calculate_average_throughput(test_results: TestResults) -> float | None:
111-
"""Calculate average throughput from test results that have throughput data."""
112-
throughput_values = [
113-
result.throughput for result in test_results.test_results if result.throughput is not None and result.did_pass
114-
]
115-
116-
if not throughput_values:
117-
return None
118-
119-
return sum(throughput_values) / len(throughput_values)
120-
121-
12255
def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool:
12356
test_results = candidate_result.behavior_test_results
12457
report = test_results.get_test_pass_fail_report_by_type()

codeflash/verification/parse_test_output.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,7 @@ def parse_func(file_path: Path) -> XMLParser:
3737

3838

3939
matches_re_start = re.compile(r"!\$######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######\$!\n")
40-
matches_re_end = re.compile(
41-
r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)(?::throughput_([\d\.]+)_ops_per_sec)?######!"
42-
)
40+
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
4341

4442

4543
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
@@ -95,7 +93,6 @@ def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, tes
9593
return_value=test_pickle,
9694
timed_out=False,
9795
verification_type=VerificationType.FUNCTION_CALL,
98-
throughput=None,
9996
)
10097
)
10198
except Exception as e:
@@ -163,7 +160,6 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
163160
return_value=ret_val,
164161
timed_out=False,
165162
verification_type=VerificationType(verification_type) if verification_type else None,
166-
throughput=None,
167163
)
168164
)
169165
except Exception:
@@ -297,27 +293,22 @@ def parse_test_xml(
297293
return_value=None,
298294
timed_out=timed_out,
299295
stdout="",
300-
throughput=None,
301296
)
302297
)
303298

304299
else:
305300
for match_index, match in enumerate(begin_matches):
306301
groups = match.groups()
307302
end_match = end_matches.get(groups)
308-
iteration_id, runtime, throughput = groups[5], None, None
303+
iteration_id, runtime = groups[5], None
309304
if end_match:
310305
stdout = sys_stdout[match.end() : end_match.start()]
311-
end_groups = end_match.groups()
312-
split_val = end_groups[5].split(":")
306+
split_val = end_match.groups()[5].split(":")
313307
if len(split_val) > 1:
314308
iteration_id = split_val[0]
315309
runtime = int(split_val[1])
316310
else:
317311
iteration_id, runtime = split_val[0], None
318-
# Extract throughput if present (group 6 is the throughput capture group)
319-
if len(end_groups) > 6 and end_groups[6] is not None:
320-
throughput = float(end_groups[6])
321312
elif match_index == len(begin_matches) - 1:
322313
stdout = sys_stdout[match.end() :]
323314
else:
@@ -341,7 +332,6 @@ def parse_test_xml(
341332
return_value=None,
342333
timed_out=timed_out,
343334
stdout=stdout,
344-
throughput=throughput,
345335
)
346336
)
347337

@@ -430,7 +420,6 @@ def merge_test_results(
430420
if result_bin.verification_type
431421
else None,
432422
stdout=xml_result.stdout,
433-
throughput=None,
434423
)
435424
)
436425
elif xml_results.test_results[0].id.iteration_id is not None:
@@ -461,7 +450,6 @@ def merge_test_results(
461450
if bin_result.verification_type
462451
else None,
463452
stdout=xml_result.stdout,
464-
throughput=None,
465453
)
466454
)
467455
else:
@@ -489,7 +477,6 @@ def merge_test_results(
489477
if bin_result.verification_type
490478
else None,
491479
stdout=xml_result.stdout,
492-
throughput=None,
493480
)
494481
)
495482

0 commit comments

Comments
 (0)