Skip to content

Commit ce421d4

Browse files
committed
Revert "basic async critic - WIP"
This reverts commit a364075.
1 parent 878699a commit ce421d4

File tree

5 files changed

+7
-99
lines changed

5 files changed

+7
-99
lines changed

codeflash/code_utils/codeflash_wrap_decorator.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,6 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
299299
else:
300300
async_wrapper.index[test_id] = 0
301301

302-
# Initialize cumulative throughput tracking
303-
if not hasattr(async_wrapper, "start_time"):
304-
async_wrapper.start_time = time.perf_counter()
305-
if not hasattr(async_wrapper, "total_operations"):
306-
async_wrapper.total_operations = 0
307-
308302
codeflash_test_index = async_wrapper.index[test_id]
309303
invocation_id = f"{line_id}_{codeflash_test_index}"
310304
test_stdout_tag = f"{test_module_name}:{(test_class_name + '.' if test_class_name else '')}{test_name}:{function_name}:{loop_index}:{invocation_id}"
@@ -325,12 +319,7 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: # noqa: ANN401
325319
finally:
326320
gc.enable()
327321

328-
# Update cumulative throughput tracking
329-
async_wrapper.total_operations += 1
330-
elapsed_time = time.perf_counter() - async_wrapper.start_time
331-
throughput = async_wrapper.total_operations / elapsed_time if elapsed_time > 0 else 0
332-
333-
print(f"!######{test_stdout_tag}:{codeflash_duration}:throughput_{throughput:.2f}_ops_per_sec######!")
322+
print(f"!######{test_stdout_tag}:{codeflash_duration}######!")
334323

335324
if exception:
336325
raise exception

codeflash/models/models.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,6 @@ class FunctionTestInvocation:
552552
timed_out: Optional[bool]
553553
verification_type: Optional[str] = VerificationType.FUNCTION_CALL
554554
stdout: Optional[str] = None
555-
throughput: Optional[float] = None # Operations per second
556555

557556
@property
558557
def unique_invocation_loop_id(self) -> str:

codeflash/optimization/function_optimizer.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -566,11 +566,7 @@ def determine_best_candidate(
566566
tree = Tree(f"Candidate #{candidate_index} - Runtime Information")
567567
benchmark_tree = None
568568
if speedup_critic(
569-
candidate_result,
570-
original_code_baseline.runtime,
571-
None,
572-
self.function_to_optimize,
573-
original_baseline_results=original_code_baseline
569+
candidate_result, original_code_baseline.runtime, best_runtime_until_now=None
574570
) and quantity_of_tests_critic(candidate_result):
575571
tree.add("This candidate is faster than the original code. 🚀") # TODO: Change this description
576572
tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}")

codeflash/result/critic.py

Lines changed: 2 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
from codeflash.models.models import TestType
1313

1414
if TYPE_CHECKING:
15-
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
16-
from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline, TestResults
15+
from codeflash.models.models import CoverageData, OptimizedCandidateResult, OriginalCodeBaseline
1716

1817

1918
def performance_gain(*, original_runtime_ns: int, optimized_runtime_ns: int) -> float:
@@ -30,29 +29,16 @@ def speedup_critic(
3029
candidate_result: OptimizedCandidateResult,
3130
original_code_runtime: int,
3231
best_runtime_until_now: int | None,
33-
function_to_optimize: FunctionToOptimize,
3432
*,
3533
disable_gh_action_noise: bool = False,
36-
original_baseline_results: OriginalCodeBaseline | None = None,
3734
) -> bool:
3835
"""Take in a correct optimized Test Result and decide if the optimization should actually be surfaced to the user.
3936
40-
For async functions, dispatches to async_speedup_critic for specialized evaluation.
41-
For sync functions, uses traditional runtime-only evaluation.
42-
4337
Ensure that the optimization is actually faster than the original code, above the noise floor.
4438
The noise floor is a function of the original code runtime. Currently, the noise floor is 2xMIN_IMPROVEMENT_THRESHOLD
4539
when the original runtime is less than 10 microseconds, and becomes MIN_IMPROVEMENT_THRESHOLD for any higher runtime.
4640
The noise floor is doubled when benchmarking on a (noisy) GitHub Action virtual instance, also we want to be more confident there.
4741
"""
48-
if function_to_optimize.is_async and original_baseline_results:
49-
return async_speedup_critic(
50-
candidate_result=candidate_result,
51-
original_baseline_results=original_baseline_results,
52-
best_runtime_until_now=best_runtime_until_now,
53-
disable_gh_action_noise=disable_gh_action_noise,
54-
)
55-
5642
noise_floor = 3 * MIN_IMPROVEMENT_THRESHOLD if original_code_runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
5743
if not disable_gh_action_noise and env_utils.is_ci():
5844
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
@@ -61,62 +47,11 @@ def speedup_critic(
6147
original_runtime_ns=original_code_runtime, optimized_runtime_ns=candidate_result.best_test_runtime
6248
)
6349
if best_runtime_until_now is None:
50+
# collect all optimizations with this
6451
return bool(perf_gain > noise_floor)
6552
return bool(perf_gain > noise_floor and candidate_result.best_test_runtime < best_runtime_until_now)
6653

6754

68-
def async_speedup_critic(
69-
candidate_result: OptimizedCandidateResult,
70-
original_baseline_results: OriginalCodeBaseline,
71-
best_runtime_until_now: int | None,
72-
*,
73-
disable_gh_action_noise: bool = False,
74-
) -> bool:
75-
"""Simplified speedup evaluation for async functions with throughput-first approach.
76-
77-
For async functions:
78-
1. If throughput data exists and shows improvement, accept the optimization
79-
2. Otherwise, fall back to traditional runtime evaluation
80-
"""
81-
# Calculate noise floor with same logic as sync functions
82-
noise_floor = (
83-
3 * MIN_IMPROVEMENT_THRESHOLD if original_baseline_results.runtime < 10000 else MIN_IMPROVEMENT_THRESHOLD
84-
)
85-
if not disable_gh_action_noise and env_utils.is_ci():
86-
noise_floor = noise_floor * 2 # Increase the noise floor in GitHub Actions mode
87-
88-
# Check for throughput improvement first
89-
candidate_throughput = _calculate_average_throughput(candidate_result.benchmarking_test_results)
90-
original_throughput = _calculate_average_throughput(original_baseline_results.benchmarking_test_results)
91-
92-
if original_throughput and original_throughput > 0 and candidate_throughput:
93-
throughput_gain = (candidate_throughput - original_throughput) / original_throughput
94-
if throughput_gain > noise_floor:
95-
# Throughput improved above noise floor - accept optimization
96-
return True if best_runtime_until_now is None else candidate_result.best_test_runtime < best_runtime_until_now
97-
98-
# Fall back to traditional runtime evaluation
99-
perf_gain = performance_gain(
100-
original_runtime_ns=original_baseline_results.runtime, optimized_runtime_ns=candidate_result.best_test_runtime
101-
)
102-
103-
if best_runtime_until_now is None:
104-
return bool(perf_gain > noise_floor)
105-
return bool(perf_gain > noise_floor and candidate_result.best_test_runtime < best_runtime_until_now)
106-
107-
108-
def _calculate_average_throughput(test_results: TestResults) -> float | None:
109-
"""Calculate average throughput from test results that have throughput data."""
110-
throughput_values = [
111-
result.throughput for result in test_results.test_results if result.throughput is not None and result.did_pass
112-
]
113-
114-
if not throughput_values:
115-
return None
116-
117-
return sum(throughput_values) / len(throughput_values)
118-
119-
12055
def quantity_of_tests_critic(candidate_result: OptimizedCandidateResult | OriginalCodeBaseline) -> bool:
12156
test_results = candidate_result.behavior_test_results
12257
report = test_results.get_test_pass_fail_report_by_type()

codeflash/verification/parse_test_output.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def parse_func(file_path: Path) -> XMLParser:
3737

3838

3939
matches_re_start = re.compile(r"!\$######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######\$!\n")
40-
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)(?::throughput_([\d\.]+)_ops_per_sec)?######!")
40+
matches_re_end = re.compile(r"!######(.*?):(.*?)([^\.:]*?):(.*?):(.*?):(.*?)######!")
4141

4242

4343
def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, test_config: TestConfig) -> TestResults:
@@ -93,7 +93,6 @@ def parse_test_return_values_bin(file_location: Path, test_files: TestFiles, tes
9393
return_value=test_pickle,
9494
timed_out=False,
9595
verification_type=VerificationType.FUNCTION_CALL,
96-
throughput=None,
9796
)
9897
)
9998
except Exception as e:
@@ -161,7 +160,6 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
161160
return_value=ret_val,
162161
timed_out=False,
163162
verification_type=VerificationType(verification_type) if verification_type else None,
164-
throughput=None,
165163
)
166164
)
167165
except Exception:
@@ -295,27 +293,22 @@ def parse_test_xml(
295293
return_value=None,
296294
timed_out=timed_out,
297295
stdout="",
298-
throughput=None,
299296
)
300297
)
301298

302299
else:
303300
for match_index, match in enumerate(begin_matches):
304301
groups = match.groups()
305302
end_match = end_matches.get(groups)
306-
iteration_id, runtime, throughput = groups[5], None, None
303+
iteration_id, runtime = groups[5], None
307304
if end_match:
308305
stdout = sys_stdout[match.end() : end_match.start()]
309-
end_groups = end_match.groups()
310-
split_val = end_groups[5].split(":")
306+
split_val = end_match.groups()[5].split(":")
311307
if len(split_val) > 1:
312308
iteration_id = split_val[0]
313309
runtime = int(split_val[1])
314310
else:
315311
iteration_id, runtime = split_val[0], None
316-
# Extract throughput if present (group 6 is the throughput capture group)
317-
if len(end_groups) > 6 and end_groups[6] is not None:
318-
throughput = float(end_groups[6])
319312
elif match_index == len(begin_matches) - 1:
320313
stdout = sys_stdout[match.end() :]
321314
else:
@@ -339,7 +332,6 @@ def parse_test_xml(
339332
return_value=None,
340333
timed_out=timed_out,
341334
stdout=stdout,
342-
throughput=throughput,
343335
)
344336
)
345337

@@ -428,7 +420,6 @@ def merge_test_results(
428420
if result_bin.verification_type
429421
else None,
430422
stdout=xml_result.stdout,
431-
throughput=None,
432423
)
433424
)
434425
elif xml_results.test_results[0].id.iteration_id is not None:
@@ -459,7 +450,6 @@ def merge_test_results(
459450
if bin_result.verification_type
460451
else None,
461452
stdout=xml_result.stdout,
462-
throughput=None,
463453
)
464454
)
465455
else:
@@ -487,7 +477,6 @@ def merge_test_results(
487477
if bin_result.verification_type
488478
else None,
489479
stdout=xml_result.stdout,
490-
throughput=None,
491480
)
492481
)
493482

0 commit comments

Comments
 (0)