Skip to content

Commit d2d57fe

Browse files
Merge pull request #967 from codeflash-ai/exp/consistent-loop-break
[Enhancement] Stop looping when runtime is stable (CF-934)
2 parents 2f46cc7 + 270af89 commit d2d57fe

File tree

6 files changed

+126
-14
lines changed

6 files changed

+126
-14
lines changed

codeflash/code_utils/config_consts.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
DEFAULT_IMPORTANCE_THRESHOLD = 0.001
1515
N_CANDIDATES_LP = 6
1616

17+
# pytest loop stability
18+
# For now, we use strict thresholds (large windows and low tolerances), since this is still experimental.
19+
STABILITY_WINDOW_SIZE = 0.35 # 35% of total window
20+
STABILITY_CENTER_TOLERANCE = 0.0025 # ±0.25% around median
21+
STABILITY_SPREAD_TOLERANCE = 0.0025 # 0.25% window spread
22+
1723
# Refinement
1824
REFINE_ALL_THRESHOLD = 2 # when valid optimizations count is 2 or less, refine all optimizations
1925
REFINED_CANDIDATE_RANKING_WEIGHTS = (2, 1) # (runtime, diff), runtime is more important than diff by a factor of 2

codeflash/code_utils/env_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
def check_formatter_installed(formatter_cmds: list[str], exit_on_failure: bool = True) -> bool: # noqa
2020
if not formatter_cmds or formatter_cmds[0] == "disabled":
2121
return True
22-
2322
first_cmd = formatter_cmds[0]
2423
cmd_tokens = shlex.split(first_cmd) if isinstance(first_cmd, str) else [first_cmd]
2524

codeflash/code_utils/formatter.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,13 @@ def apply_formatter_cmds(
4646
print_status: bool, # noqa
4747
exit_on_failure: bool = True, # noqa
4848
) -> tuple[Path, str, bool]:
49-
should_make_copy = False
50-
file_path = path
51-
52-
if test_dir_str:
53-
should_make_copy = True
54-
file_path = Path(test_dir_str) / "temp.py"
55-
5649
if not path.exists():
5750
msg = f"File {path} does not exist. Cannot apply formatter commands."
5851
raise FileNotFoundError(msg)
5952

60-
if should_make_copy:
53+
file_path = path
54+
if test_dir_str:
55+
file_path = Path(test_dir_str) / "temp.py"
6156
shutil.copy2(path, file_path)
6257

6358
file_token = "$file" # noqa: S105

codeflash/optimization/function_optimizer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1897,7 +1897,6 @@ def establish_original_code_baseline(
18971897
benchmarking_results, self.function_to_optimize.function_name
18981898
)
18991899
logger.debug(f"Original async function throughput: {async_throughput} calls/second")
1900-
console.rule()
19011900

19021901
if self.args.benchmark:
19031902
replay_benchmarking_test_results = benchmarking_results.group_by_benchmarks(

codeflash/verification/pytest_plugin.py

Lines changed: 116 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,19 @@
1212
import time as _time_module
1313
import warnings
1414
from pathlib import Path
15-
from typing import TYPE_CHECKING, Any, Callable
15+
from typing import TYPE_CHECKING, Any, Callable, Optional
1616
from unittest import TestCase
1717

1818
# PyTest Imports
1919
import pytest
2020
from pluggy import HookspecMarker
2121

22+
from codeflash.code_utils.config_consts import (
23+
STABILITY_CENTER_TOLERANCE,
24+
STABILITY_SPREAD_TOLERANCE,
25+
STABILITY_WINDOW_SIZE,
26+
)
27+
2228
if TYPE_CHECKING:
2329
from _pytest.config import Config, Parser
2430
from _pytest.main import Session
@@ -77,6 +83,7 @@ class UnexpectedError(Exception):
7783
# Store references to original functions before any patching
7884
_ORIGINAL_TIME_TIME = _time_module.time
7985
_ORIGINAL_PERF_COUNTER = _time_module.perf_counter
86+
_ORIGINAL_PERF_COUNTER_NS = _time_module.perf_counter_ns
8087
_ORIGINAL_TIME_SLEEP = _time_module.sleep
8188

8289

@@ -249,6 +256,14 @@ def pytest_addoption(parser: Parser) -> None:
249256
choices=("function", "class", "module", "session"),
250257
help="Scope for looping tests",
251258
)
259+
pytest_loops.addoption(
260+
"--codeflash_stability_check",
261+
action="store",
262+
default="false",
263+
type=str,
264+
choices=("true", "false"),
265+
help="Enable stability checks for the loops",
266+
)
252267

253268

254269
@pytest.hookimpl(trylast=True)
@@ -260,6 +275,70 @@ def pytest_configure(config: Config) -> None:
260275
_apply_deterministic_patches()
261276

262277

278+
def get_runtime_from_stdout(stdout: str) -> Optional[int]:
279+
marker_start = "!######"
280+
marker_end = "######!"
281+
282+
if not stdout:
283+
return None
284+
285+
end = stdout.rfind(marker_end)
286+
if end == -1:
287+
return None
288+
289+
start = stdout.rfind(marker_start, 0, end)
290+
if start == -1:
291+
return None
292+
293+
payload = stdout[start + len(marker_start) : end]
294+
last_colon = payload.rfind(":")
295+
if last_colon == -1:
296+
return None
297+
try:
298+
return int(payload[last_colon + 1 :])
299+
except ValueError:
300+
return None
301+
302+
303+
_NODEID_BRACKET_PATTERN = re.compile(r"\s*\[\s*\d+\s*\]\s*$")
304+
305+
306+
def should_stop(
307+
runtimes: list[int],
308+
window: int,
309+
min_window_size: int,
310+
center_rel_tol: float = STABILITY_CENTER_TOLERANCE,
311+
spread_rel_tol: float = STABILITY_SPREAD_TOLERANCE,
312+
) -> bool:
313+
if len(runtimes) < window:
314+
return False
315+
316+
if len(runtimes) < min_window_size:
317+
return False
318+
319+
recent = runtimes[-window:]
320+
321+
# Use sorted array for faster median and min/max operations
322+
recent_sorted = sorted(recent)
323+
mid = window // 2
324+
m = recent_sorted[mid] if window % 2 else (recent_sorted[mid - 1] + recent_sorted[mid]) / 2
325+
326+
# 1) All recent points close to the median
327+
centered = True
328+
for r in recent:
329+
if abs(r - m) / m > center_rel_tol:
330+
centered = False
331+
break
332+
333+
# 2) Window spread is small
334+
r_min, r_max = recent_sorted[0], recent_sorted[-1]
335+
if r_min == 0:
336+
return False
337+
spread_ok = (r_max - r_min) / r_min <= spread_rel_tol
338+
339+
return centered and spread_ok
340+
341+
263342
class PytestLoops:
264343
name: str = "pytest-loops"
265344

@@ -268,6 +347,20 @@ def __init__(self, config: Config) -> None:
268347
level = logging.DEBUG if config.option.verbose > 1 else logging.INFO
269348
logging.basicConfig(level=level)
270349
self.logger = logging.getLogger(self.name)
350+
self.runtime_data_by_test_case: dict[str, list[int]] = {}
351+
self.enable_stability_check: bool = (
352+
str(getattr(config.option, "codeflash_stability_check", "false")).lower() == "true"
353+
)
354+
355+
@pytest.hookimpl
356+
def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
357+
if not self.enable_stability_check:
358+
return
359+
if report.when == "call" and report.passed:
360+
duration_ns = get_runtime_from_stdout(report.capstdout)
361+
if duration_ns:
362+
clean_id = _NODEID_BRACKET_PATTERN.sub("", report.nodeid)
363+
self.runtime_data_by_test_case.setdefault(clean_id, []).append(duration_ns)
271364

272365
@hookspec(firstresult=True)
273366
def pytest_runtestloop(self, session: Session) -> bool:
@@ -283,11 +376,12 @@ def pytest_runtestloop(self, session: Session) -> bool:
283376
total_time: float = self._get_total_time(session)
284377

285378
count: int = 0
379+
runtimes = []
380+
elapsed_ns = 0
286381

287382
while total_time >= SHORTEST_AMOUNT_OF_TIME: # need to run at least one for normal tests
288383
count += 1
289-
total_time = self._get_total_time(session)
290-
384+
loop_start = _ORIGINAL_PERF_COUNTER_NS()
291385
for index, item in enumerate(session.items):
292386
item: pytest.Item = item # noqa: PLW0127, PLW2901
293387
item._report_sections.clear() # clear reports for new test # noqa: SLF001
@@ -304,8 +398,26 @@ def pytest_runtestloop(self, session: Session) -> bool:
304398
raise session.Failed(session.shouldfail)
305399
if session.shouldstop:
306400
raise session.Interrupted(session.shouldstop)
401+
402+
if self.enable_stability_check:
403+
elapsed_ns += _ORIGINAL_PERF_COUNTER_NS() - loop_start
404+
best_runtime_until_now = sum([min(data) for data in self.runtime_data_by_test_case.values()])
405+
if best_runtime_until_now > 0:
406+
runtimes.append(best_runtime_until_now)
407+
408+
estimated_total_loops = 0
409+
if elapsed_ns > 0:
410+
rate = count / elapsed_ns
411+
total_time_ns = total_time * 1e9
412+
estimated_total_loops = int(rate * total_time_ns)
413+
414+
window_size = int(STABILITY_WINDOW_SIZE * estimated_total_loops + 0.5)
415+
if should_stop(runtimes, window_size, session.config.option.codeflash_min_loops):
416+
break
417+
307418
if self._timed_out(session, start_time, count):
308-
break # exit loop
419+
break
420+
309421
_ORIGINAL_TIME_SLEEP(self._get_delay_time(session))
310422
return True
311423

codeflash/verification/test_runner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def run_benchmarking_tests(
212212
f"--codeflash_min_loops={pytest_min_loops}",
213213
f"--codeflash_max_loops={pytest_max_loops}",
214214
f"--codeflash_seconds={pytest_target_runtime_seconds}",
215+
"--codeflash_stability_check=true",
215216
]
216217
if pytest_timeout is not None:
217218
pytest_args.append(f"--timeout={pytest_timeout}")

0 commit comments

Comments
 (0)