Skip to content

Commit 83dff02

Browse files
best summed runtime helper
1 parent 95f22ee commit 83dff02

File tree

3 files changed

+9
-7
lines changed

3 files changed

+9
-7
lines changed

codeflash/models/models.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table
1111
from codeflash.lsp.lsp_message import LspMarkdownMessage
1212
from codeflash.models.test_type import TestType
13+
from codeflash.result.best_summed_runtime import calculate_best_summed_runtime
1314

1415
if TYPE_CHECKING:
1516
from collections.abc import Iterator
@@ -817,9 +818,7 @@ def total_passed_runtime(self) -> int:
817818
:return: The runtime in nanoseconds.
818819
"""
819820
# TODO this doesn't look at the intersection of tests of baseline and original
820-
return sum(
821-
[min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()]
822-
)
821+
return calculate_best_summed_runtime(self.usable_runtime_data_by_test_case())
823822

824823
def file_to_no_of_tests(self, test_functions_to_remove: list[str]) -> Counter[Path]:
825824
map_gen_test_file_to_no_of_tests = Counter()
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def calculate_best_summed_runtime(grouped_runtime_info: dict[any, list[int]]) -> int:
2+
return sum([min(usable_runtime_data) for _, usable_runtime_data in grouped_runtime_info.items()])

codeflash/verification/pytest_plugin.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import pytest
2121
from pluggy import HookspecMarker
2222

23+
from codeflash.result.best_summed_runtime import calculate_best_summed_runtime
24+
2325
if TYPE_CHECKING:
2426
from _pytest.config import Config, Parser
2527
from _pytest.main import Session
@@ -345,7 +347,8 @@ def dynamic_tolerance(self, avg_ns: float) -> float: # noqa: PLR0911
345347

346348
@pytest.hookimpl
347349
def pytest_runtest_logreport(self, report: pytest.TestReport) -> None:
348-
if report.when == "call" and report.passed and (duration_ns := get_runtime_from_stdout(report.capstdout)): # noqa: SIM102
350+
if report.when == "call" and report.passed:
351+
duration_ns = get_runtime_from_stdout(report.capstdout)
349352
if duration_ns:
350353
clean_id = re.sub(r"\s*\[\s*\d+\s*\]\s*$", "", report.nodeid)
351354
self.usable_runtime_data_by_test_case.setdefault(clean_id, []).append(duration_ns)
@@ -384,9 +387,7 @@ def pytest_runtestloop(self, session: Session) -> bool:
384387
if session.shouldstop:
385388
raise session.Interrupted(session.shouldstop)
386389

387-
best_runtime_until_now = sum(
388-
[min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case.items()]
389-
)
390+
best_runtime_until_now = calculate_best_summed_runtime(self.usable_runtime_data_by_test_case)
390391
runtimes.append(best_runtime_until_now)
391392

392393
if should_stop(runtimes):

0 commit comments

Comments
 (0)