Skip to content

Commit 0b3be3f

Browse files
some enhancements from claude pr review
1 parent 83dff02 commit 0b3be3f

File tree

2 files changed

+25
-25
lines changed

2 files changed

+25
-25
lines changed

codeflash/code_utils/config_consts.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,19 @@
88
MAX_CUMULATIVE_TEST_RUNTIME_NANOSECONDS = 100e6 # 100ms
99
N_TESTS_TO_GENERATE = 2
1010
TOTAL_LOOPING_TIME = 10.0 # 10 second candidate benchmarking budget
11-
CONSISTENT_LOOP_COUNT = 3
1211
COVERAGE_THRESHOLD = 60.0
1312
MIN_TESTCASE_PASSED_THRESHOLD = 6
1413
REPEAT_OPTIMIZATION_PROBABILITY = 0.1
1514
DEFAULT_IMPORTANCE_THRESHOLD = 0.001
1615
N_CANDIDATES_LP = 6
1716

17+
# pytest loop stability
18+
STABILITY_WARMUP_LOOPS = 4
19+
STABILITY_WINDOW_SIZE = 6
20+
STABILITY_CENTER_TOLERANCE = 0.01 # ±1% around median
21+
STABILITY_SPREAD_TOLERANCE = 0.02 # 2% window spread
22+
STABILITY_SLOPE_TOLERANCE = 0.01 # 1% improvement allowed
23+
1824
# Refinement
1925
REFINE_ALL_THRESHOLD = 2 # when valid optimizations count is 2 or less, refine all optimizations
2026
REFINED_CANDIDATE_RANKING_WEIGHTS = (2, 1) # (runtime, diff), runtime is more important than diff by a factor of 2

codeflash/verification/pytest_plugin.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@
2020
import pytest
2121
from pluggy import HookspecMarker
2222

23+
from codeflash.code_utils.config_consts import (
24+
STABILITY_CENTER_TOLERANCE,
25+
STABILITY_SLOPE_TOLERANCE,
26+
STABILITY_SPREAD_TOLERANCE,
27+
STABILITY_WARMUP_LOOPS,
28+
STABILITY_WINDOW_SIZE,
29+
)
2330
from codeflash.result.best_summed_runtime import calculate_best_summed_runtime
2431

2532
if TYPE_CHECKING:
@@ -287,13 +294,16 @@ def get_runtime_from_stdout(stdout: str) -> Optional[int]:
287294
return int(payload[last_colon + 1 :])
288295

289296

297+
_NODEID_BRACKET_PATTERN = re.compile(r"\s*\[\s*\d+\s*\]\s*$")
298+
299+
290300
def should_stop(
291301
runtimes: list[int],
292-
warmup: int = 4,
293-
window: int = 6,
294-
center_rel_tol: float = 0.01, # ±1% around median
295-
spread_rel_tol: float = 0.02, # 2% window spread
296-
slope_rel_tol: float = 0.01, # 1% improvement allowed
302+
warmup: int = STABILITY_WARMUP_LOOPS,
303+
window: int = STABILITY_WINDOW_SIZE,
304+
center_rel_tol: float = STABILITY_CENTER_TOLERANCE,
305+
spread_rel_tol: float = STABILITY_SPREAD_TOLERANCE,
306+
slope_rel_tol: float = STABILITY_SLOPE_TOLERANCE,
297307
) -> bool:
298308
if len(runtimes) < warmup + window:
299309
return False
@@ -328,29 +338,12 @@ def __init__(self, config: Config) -> None:
328338
self.logger = logging.getLogger(self.name)
329339
self.usable_runtime_data_by_test_case: dict[str, list[int]] = {}
330340

331-
def dynamic_tolerance(self, avg_ns: float) -> float: # noqa: PLR0911
332-
if avg_ns < 200_000: # < 0.2 ms
333-
return 0.80 # 80%
334-
if avg_ns < 500_000: # < 0.5 ms
335-
return 0.60 # 60%
336-
if avg_ns < 1_000_000: # < 1 ms
337-
return 0.45 # 45%
338-
if avg_ns < 2_000_000: # < 2 ms
339-
return 0.30 # 30%
340-
if avg_ns < 5_000_000: # < 5 ms
341-
return 0.20 # 20%
342-
if avg_ns < 20_000_000: # < 20 ms
343-
return 0.12 # 12%
344-
if avg_ns < 100_000_000: # < 100 ms
345-
return 0.07 # 7%
346-
return 0.05 # ≥ 100 ms
347-
348341
@pytest.hookimpl
349342
def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
350343
if report.when == "call" and report.passed:
351344
duration_ns = get_runtime_from_stdout(report.capstdout)
352345
if duration_ns:
353-
clean_id = re.sub(r"\s*\[\s*\d+\s*\]\s*$", "", report.nodeid)
346+
clean_id = _NODEID_BRACKET_PATTERN.sub("", report.nodeid)
354347
self.usable_runtime_data_by_test_case.setdefault(clean_id, []).append(duration_ns)
355348

356349
@hookspec(firstresult=True)
@@ -388,7 +381,8 @@ def pytest_runtestloop(self, session: Session) -> bool:
388381
raise session.Interrupted(session.shouldstop)
389382

390383
best_runtime_until_now = calculate_best_summed_runtime(self.usable_runtime_data_by_test_case)
391-
runtimes.append(best_runtime_until_now)
384+
if best_runtime_until_now > 0:
385+
runtimes.append(best_runtime_until_now)
392386

393387
if should_stop(runtimes):
394388
break

0 commit comments

Comments
 (0)