1212import time as _time_module
1313import warnings
1414from pathlib import Path
15- from typing import TYPE_CHECKING , Any , Callable
15+ from typing import TYPE_CHECKING , Any , Callable , Optional
1616from unittest import TestCase
1717
1818# PyTest Imports
1919import pytest
2020from pluggy import HookspecMarker
2121
22+ from codeflash .code_utils .config_consts import (
23+ STABILITY_CENTER_TOLERANCE ,
24+ STABILITY_SPREAD_TOLERANCE ,
25+ STABILITY_WINDOW_SIZE ,
26+ )
27+
2228if TYPE_CHECKING :
2329 from _pytest .config import Config , Parser
2430 from _pytest .main import Session
@@ -77,6 +83,7 @@ class UnexpectedError(Exception):
7783# Store references to original functions before any patching
7884_ORIGINAL_TIME_TIME = _time_module .time
7985_ORIGINAL_PERF_COUNTER = _time_module .perf_counter
86+ _ORIGINAL_PERF_COUNTER_NS = _time_module .perf_counter_ns
8087_ORIGINAL_TIME_SLEEP = _time_module .sleep
8188
8289
@@ -249,6 +256,14 @@ def pytest_addoption(parser: Parser) -> None:
249256 choices = ("function" , "class" , "module" , "session" ),
250257 help = "Scope for looping tests" ,
251258 )
259+ pytest_loops .addoption (
260+ "--codeflash_stability_check" ,
261+ action = "store" ,
262+ default = "false" ,
263+ type = str ,
264+ choices = ("true" , "false" ),
265+ help = "Enable stability checks for the loops" ,
266+ )
252267
253268
254269@pytest .hookimpl (trylast = True )
@@ -260,6 +275,70 @@ def pytest_configure(config: Config) -> None:
260275 _apply_deterministic_patches ()
261276
262277
278+ def get_runtime_from_stdout (stdout : str ) -> Optional [int ]:
279+ marker_start = "!######"
280+ marker_end = "######!"
281+
282+ if not stdout :
283+ return None
284+
285+ end = stdout .rfind (marker_end )
286+ if end == - 1 :
287+ return None
288+
289+ start = stdout .rfind (marker_start , 0 , end )
290+ if start == - 1 :
291+ return None
292+
293+ payload = stdout [start + len (marker_start ) : end ]
294+ last_colon = payload .rfind (":" )
295+ if last_colon == - 1 :
296+ return None
297+ try :
298+ return int (payload [last_colon + 1 :])
299+ except ValueError :
300+ return None
301+
302+
303+ _NODEID_BRACKET_PATTERN = re .compile (r"\s*\[\s*\d+\s*\]\s*$" )
304+
305+
306+ def should_stop (
307+ runtimes : list [int ],
308+ window : int ,
309+ min_window_size : int ,
310+ center_rel_tol : float = STABILITY_CENTER_TOLERANCE ,
311+ spread_rel_tol : float = STABILITY_SPREAD_TOLERANCE ,
312+ ) -> bool :
313+ if len (runtimes ) < window :
314+ return False
315+
316+ if len (runtimes ) < min_window_size :
317+ return False
318+
319+ recent = runtimes [- window :]
320+
321+ # Use sorted array for faster median and min/max operations
322+ recent_sorted = sorted (recent )
323+ mid = window // 2
324+ m = recent_sorted [mid ] if window % 2 else (recent_sorted [mid - 1 ] + recent_sorted [mid ]) / 2
325+
326+ # 1) All recent points close to the median
327+ centered = True
328+ for r in recent :
329+ if abs (r - m ) / m > center_rel_tol :
330+ centered = False
331+ break
332+
333+ # 2) Window spread is small
334+ r_min , r_max = recent_sorted [0 ], recent_sorted [- 1 ]
335+ if r_min == 0 :
336+ return False
337+ spread_ok = (r_max - r_min ) / r_min <= spread_rel_tol
338+
339+ return centered and spread_ok
340+
341+
263342class PytestLoops :
264343 name : str = "pytest-loops"
265344
@@ -268,6 +347,20 @@ def __init__(self, config: Config) -> None:
268347 level = logging .DEBUG if config .option .verbose > 1 else logging .INFO
269348 logging .basicConfig (level = level )
270349 self .logger = logging .getLogger (self .name )
350+ self .runtime_data_by_test_case : dict [str , list [int ]] = {}
351+ self .enable_stability_check : bool = (
352+ str (getattr (config .option , "codeflash_stability_check" , "false" )).lower () == "true"
353+ )
354+
355+ @pytest .hookimpl
356+ def pytest_runtest_logreport (self , report : pytest .TestReport ) -> None :
357+ if not self .enable_stability_check :
358+ return
359+ if report .when == "call" and report .passed :
360+ duration_ns = get_runtime_from_stdout (report .capstdout )
361+ if duration_ns :
362+ clean_id = _NODEID_BRACKET_PATTERN .sub ("" , report .nodeid )
363+ self .runtime_data_by_test_case .setdefault (clean_id , []).append (duration_ns )
271364
272365 @hookspec (firstresult = True )
273366 def pytest_runtestloop (self , session : Session ) -> bool :
@@ -283,11 +376,12 @@ def pytest_runtestloop(self, session: Session) -> bool:
283376 total_time : float = self ._get_total_time (session )
284377
285378 count : int = 0
379+ runtimes = []
380+ elapsed_ns = 0
286381
287382 while total_time >= SHORTEST_AMOUNT_OF_TIME : # need to run at least one for normal tests
288383 count += 1
289- total_time = self ._get_total_time (session )
290-
384+ loop_start = _ORIGINAL_PERF_COUNTER_NS ()
291385 for index , item in enumerate (session .items ):
292386 item : pytest .Item = item # noqa: PLW0127, PLW2901
293387 item ._report_sections .clear () # clear reports for new test # noqa: SLF001
@@ -304,8 +398,26 @@ def pytest_runtestloop(self, session: Session) -> bool:
304398 raise session .Failed (session .shouldfail )
305399 if session .shouldstop :
306400 raise session .Interrupted (session .shouldstop )
401+
402+ if self .enable_stability_check :
403+ elapsed_ns += _ORIGINAL_PERF_COUNTER_NS () - loop_start
404+ best_runtime_until_now = sum ([min (data ) for data in self .runtime_data_by_test_case .values ()])
405+ if best_runtime_until_now > 0 :
406+ runtimes .append (best_runtime_until_now )
407+
408+ estimated_total_loops = 0
409+ if elapsed_ns > 0 :
410+ rate = count / elapsed_ns
411+ total_time_ns = total_time * 1e9
412+ estimated_total_loops = int (rate * total_time_ns )
413+
414+ window_size = int (STABILITY_WINDOW_SIZE * estimated_total_loops + 0.5 )
415+ if should_stop (runtimes , window_size , session .config .option .codeflash_min_loops ):
416+ break
417+
307418 if self ._timed_out (session , start_time , count ):
308- break # exit loop
419+ break
420+
309421 _ORIGINAL_TIME_SLEEP (self ._get_delay_time (session ))
310422 return True
311423
0 commit comments