```diff
@@ -20,6 +20,13 @@
 import pytest
 from pluggy import HookspecMarker
 
+from codeflash.code_utils.config_consts import (
+    STABILITY_CENTER_TOLERANCE,
+    STABILITY_SLOPE_TOLERANCE,
+    STABILITY_SPREAD_TOLERANCE,
+    STABILITY_WARMUP_LOOPS,
+    STABILITY_WINDOW_SIZE,
+)
 from codeflash.result.best_summed_runtime import calculate_best_summed_runtime
 
 if TYPE_CHECKING:
```
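The new constants live in `config_consts`, outside this diff; presumably they preserve the inline defaults removed from `should_stop` below. A minimal sketch of what that module likely defines, with values assumed from the old defaults:

```python
# codeflash/code_utils/config_consts.py (assumed contents; the values
# mirror the inline defaults this diff removes from should_stop below)
STABILITY_WARMUP_LOOPS = 4          # loops discarded before checking stability
STABILITY_WINDOW_SIZE = 6           # recent runtimes examined per check
STABILITY_CENTER_TOLERANCE = 0.01   # ±1% around the window median
STABILITY_SPREAD_TOLERANCE = 0.02   # 2% max window spread
STABILITY_SLOPE_TOLERANCE = 0.01    # 1% improvement still allowed
```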
```diff
@@ -287,13 +294,16 @@ def get_runtime_from_stdout(stdout: str) -> Optional[int]:
     return int(payload[last_colon + 1 :])
 
 
+_NODEID_BRACKET_PATTERN = re.compile(r"\s*\[\s*\d+\s*\]\s*$")
+
+
 def should_stop(
     runtimes: list[int],
-    warmup: int = 4,
-    window: int = 6,
-    center_rel_tol: float = 0.01,  # ±1% around median
-    spread_rel_tol: float = 0.02,  # 2% window spread
-    slope_rel_tol: float = 0.01,  # 1% improvement allowed
+    warmup: int = STABILITY_WARMUP_LOOPS,
+    window: int = STABILITY_WINDOW_SIZE,
+    center_rel_tol: float = STABILITY_CENTER_TOLERANCE,
+    spread_rel_tol: float = STABILITY_SPREAD_TOLERANCE,
+    slope_rel_tol: float = STABILITY_SLOPE_TOLERANCE,
 ) -> bool:
     if len(runtimes) < warmup + window:
         return False
```
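Only the length guard of `should_stop` appears in this hunk. Going by the parameter names and the old inline comments, the body presumably checks that a trailing window of runtimes is centered on its median, tightly spread, and no longer improving. A sketch of one plausible reading, using a hypothetical `should_stop_sketch` (the real body may differ):

```python
import statistics

def should_stop_sketch(runtimes: list[int], warmup: int = 4, window: int = 6,
                       center_rel_tol: float = 0.01, spread_rel_tol: float = 0.02,
                       slope_rel_tol: float = 0.01) -> bool:
    """Illustration only; not the code under review."""
    if len(runtimes) < warmup + window:
        return False
    recent = runtimes[-window:]
    center = statistics.median(recent)
    # Every sample sits within ±1% (center_rel_tol) of the window median.
    centered = all(abs(r - center) <= center_rel_tol * center for r in recent)
    # The min-to-max spread stays within 2% (spread_rel_tol) of the median.
    tight = max(recent) - min(recent) <= spread_rel_tol * center
    # Runtimes have stopped improving by more than 1% (slope_rel_tol) per window.
    flat = recent[0] - recent[-1] <= slope_rel_tol * center
    return centered and tight and flat
```

Whatever the exact checks, hoisting the defaults into shared constants keeps the plugin and any other callers of the stability heuristic in agreement.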
```diff
@@ -328,29 +338,12 @@ def __init__(self, config: Config) -> None:
         self.logger = logging.getLogger(self.name)
         self.usable_runtime_data_by_test_case: dict[str, list[int]] = {}
 
-    def dynamic_tolerance(self, avg_ns: float) -> float:  # noqa: PLR0911
-        if avg_ns < 200_000:  # < 0.2 ms
-            return 0.80  # 80%
-        if avg_ns < 500_000:  # < 0.5 ms
-            return 0.60  # 60%
-        if avg_ns < 1_000_000:  # < 1 ms
-            return 0.45  # 45%
-        if avg_ns < 2_000_000:  # < 2 ms
-            return 0.30  # 30%
-        if avg_ns < 5_000_000:  # < 5 ms
-            return 0.20  # 20%
-        if avg_ns < 20_000_000:  # < 20 ms
-            return 0.12  # 12%
-        if avg_ns < 100_000_000:  # < 100 ms
-            return 0.07  # 7%
-        return 0.05  # ≥ 100 ms
-
     @pytest.hookimpl
     def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
         if report.when == "call" and report.passed:
             duration_ns = get_runtime_from_stdout(report.capstdout)
             if duration_ns:
-                clean_id = re.sub(r"\s*\[\s*\d+\s*\]\s*$", "", report.nodeid)
+                clean_id = _NODEID_BRACKET_PATTERN.sub("", report.nodeid)
                 self.usable_runtime_data_by_test_case.setdefault(clean_id, []).append(duration_ns)
 
     @hookspec(firstresult=True)
```
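Hoisting the pattern to module scope compiles it once rather than on every passing test report; the matching behavior is unchanged. For illustration, what it strips from hypothetical parametrized node IDs:

```python
import re

_NODEID_BRACKET_PATTERN = re.compile(r"\s*\[\s*\d+\s*\]\s*$")

# A trailing numeric bracket suffix is removed, so repeated invocations of
# the same case aggregate under one key:
print(_NODEID_BRACKET_PATTERN.sub("", "tests/test_foo.py::test_bar[3]"))
# tests/test_foo.py::test_bar

# Non-numeric parametrization is left alone, since the pattern requires \d+:
print(_NODEID_BRACKET_PATTERN.sub("", "tests/test_foo.py::test_bar[alpha]"))
# tests/test_foo.py::test_bar[alpha]
```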
```diff
@@ -388,7 +381,8 @@ def pytest_runtestloop(self, session: Session) -> bool:
                 raise session.Interrupted(session.shouldstop)
 
             best_runtime_until_now = calculate_best_summed_runtime(self.usable_runtime_data_by_test_case)
-            runtimes.append(best_runtime_until_now)
+            if best_runtime_until_now > 0:
+                runtimes.append(best_runtime_until_now)
 
             if should_stop(runtimes):
                 break
```
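The new guard keeps non-positive aggregates out of `runtimes`, presumably the value `calculate_best_summed_runtime` returns while no test case has produced a usable timing yet (an assumption; that function is outside this diff). Without the guard, a run of zero sentinels could read as a perfectly stable window and trigger `should_stop` early.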