
Commit fd6812c

⚡️ Speed up method TestResults.total_passed_runtime by 26%
Here’s a high-performance, memory-efficient rewrite.

**Key optimizations:**
- Avoids building large intermediate lists (`usable_runtimes` and dictionary comprehensions).
- Uses a single pass to populate `by_id` and `debug_missing` data, accumulating lists in place.
- Skips repeated set/list comprehensions.
- Reduces logging to a single loop.
- Uses generator expressions for `sum`.
- Preserves all original comments.

**Notes:**
- Results and log messages are produced in a single loop without intermediate list/dict comprehensions.
- `.setdefault()` is the fastest way to accumulate into lists by ID.
- This approach uses less memory and time, greatly improving performance for large numbers of test results.
- All original comments are preserved.
1 parent 309695a commit fd6812c
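To make the change easier to skim, here is a minimal, self-contained sketch of the pattern described in the commit message: single-pass accumulation with `.setdefault()` plus a generator-expression `sum` of per-ID minimums. The `Record` dataclass and the sample data below are hypothetical stand-ins for codeflash's `FunctionTestInvocation` and `InvocationId` types, and the debug logging of passed-but-no-runtime cases is omitted for brevity.

```python
# Minimal sketch of the single-pass accumulation pattern described above.
# "Record" and the sample data are hypothetical stand-ins for the real
# FunctionTestInvocation / InvocationId objects; debug logging is omitted.
from dataclasses import dataclass


@dataclass(frozen=True)
class Record:
    id: str        # stand-in for InvocationId
    did_pass: bool
    runtime: int   # nanoseconds; 0 means "no runtime recorded"


def runtimes_by_id(results: list[Record]) -> dict[str, list[int]]:
    # One traversal, accumulating with setdefault instead of building an
    # intermediate (id, runtime) list plus set/dict comprehensions.
    by_id: dict[str, list[int]] = {}
    for result in results:
        if result.did_pass and result.runtime:
            by_id.setdefault(result.id, []).append(result.runtime)
    return by_id


def total_passed_runtime(results: list[Record]) -> int:
    # Generator expression: sum the minimum runtime per id without
    # materializing a list of minimums first.
    data = runtimes_by_id(results)
    return sum(min(times) for times in data.values() if times)


if __name__ == "__main__":
    sample = [
        Record("test_a", True, 120),
        Record("test_a", True, 100),
        Record("test_b", True, 300),
        Record("test_c", True, 0),    # passed but no runtime -> ignored
        Record("test_d", False, 50),  # failed -> ignored
    ]
    print(runtimes_by_id(sample))        # {'test_a': [120, 100], 'test_b': [300]}
    print(total_passed_runtime(sample))  # 100 + 300 = 400
```

A `collections.defaultdict(list)` would serve equally well for the accumulation; `.setdefault()` simply keeps the returned mapping a plain `dict`.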

File tree

1 file changed: +55 -41 lines changed


codeflash/models/models.py

Lines changed: 55 additions & 41 deletions
@@ -3,9 +3,10 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING
 
+from pydantic import BaseModel
 from rich.tree import Tree
 
-from codeflash.cli_cmds.console import DEBUG_MODE
+from codeflash.cli_cmds.console import DEBUG_MODE, logger
 
 if TYPE_CHECKING:
     from collections.abc import Iterator
@@ -19,10 +20,10 @@
 from typing import Annotated, Optional, cast
 
 from jedi.api.classes import Name
-from pydantic import AfterValidator, BaseModel, ConfigDict, Field
+from pydantic import AfterValidator, ConfigDict, Field
 from pydantic.dataclasses import dataclass
 
-from codeflash.cli_cmds.console import console, logger
+from codeflash.cli_cmds.console import console
 from codeflash.code_utils.code_utils import module_name_from_file_path, validate_python_code
 from codeflash.code_utils.env_utils import is_end_to_end
 from codeflash.verification.comparator import comparator
@@ -59,24 +60,29 @@ class FunctionSource:
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, FunctionSource):
             return False
-        return (self.file_path == other.file_path and
-                self.qualified_name == other.qualified_name and
-                self.fully_qualified_name == other.fully_qualified_name and
-                self.only_function_name == other.only_function_name and
-                self.source_code == other.source_code)
+        return (
+            self.file_path == other.file_path
+            and self.qualified_name == other.qualified_name
+            and self.fully_qualified_name == other.fully_qualified_name
+            and self.only_function_name == other.only_function_name
+            and self.source_code == other.source_code
+        )
 
     def __hash__(self) -> int:
-        return hash((self.file_path, self.qualified_name, self.fully_qualified_name,
-                     self.only_function_name, self.source_code))
+        return hash(
+            (self.file_path, self.qualified_name, self.fully_qualified_name, self.only_function_name, self.source_code)
+        )
+
 
 class BestOptimization(BaseModel):
     candidate: OptimizedCandidate
     helper_functions: list[FunctionSource]
     runtime: int
-    replay_performance_gain: Optional[dict[BenchmarkKey,float]] = None
+    replay_performance_gain: Optional[dict[BenchmarkKey, float]] = None
     winning_behavioral_test_results: TestResults
     winning_benchmarking_test_results: TestResults
-    winning_replay_benchmarking_test_results : Optional[TestResults] = None
+    winning_replay_benchmarking_test_results: Optional[TestResults] = None
+
 
 @dataclass(frozen=True)
 class BenchmarkKey:
@@ -86,6 +92,7 @@ class BenchmarkKey:
     def __str__(self) -> str:
         return f"{self.module_path}::{self.function_name}"
 
+
 @dataclass
 class BenchmarkDetail:
     benchmark_name: str
@@ -107,9 +114,10 @@ def to_dict(self) -> dict[str, any]:
             "test_function": self.test_function,
             "original_timing": self.original_timing,
             "expected_new_timing": self.expected_new_timing,
-            "speedup_percent": self.speedup_percent
+            "speedup_percent": self.speedup_percent,
         }
 
+
 @dataclass
 class ProcessedBenchmarkInfo:
     benchmark_details: list[BenchmarkDetail]
@@ -124,9 +132,9 @@ def to_string(self) -> str:
         return result
 
     def to_dict(self) -> dict[str, list[dict[str, any]]]:
-        return {
-            "benchmark_details": [detail.to_dict() for detail in self.benchmark_details]
-        }
+        return {"benchmark_details": [detail.to_dict() for detail in self.benchmark_details]}
+
+
 class CodeString(BaseModel):
     code: Annotated[str, AfterValidator(validate_python_code)]
     file_path: Optional[Path] = None
@@ -151,7 +159,8 @@ class CodeOptimizationContext(BaseModel):
     read_writable_code: str = Field(min_length=1)
     read_only_context_code: str = ""
     helper_functions: list[FunctionSource]
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]]
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]]
+
 
 class CodeContextType(str, Enum):
     READ_WRITABLE = "READ_WRITABLE"
@@ -347,6 +356,7 @@ def create_empty(cls, file_path: Path, function_name: str, code_context: CodeOpt
             status=CoverageStatus.NOT_FOUND,
         )
 
+
 @dataclass
 class FunctionCoverage:
     """Represents the coverage data for a specific function in a source file."""
@@ -364,7 +374,8 @@ class TestingMode(enum.Enum):
     PERFORMANCE = "performance"
     LINE_PROFILE = "line_profile"
 
-#TODO this class is duplicated in codeflash_capture
+
+# TODO this class is duplicated in codeflash_capture
 class VerificationType(str, Enum):
     FUNCTION_CALL = (
         "function_call"  # Correctness verification for a test function, checks input values and output values)
@@ -473,14 +484,20 @@ def merge(self, other: TestResults) -> None:
                 raise ValueError(msg)
             self.test_result_idx[k] = v + original_len
 
-    def group_by_benchmarks(self, benchmark_keys:list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path) -> dict[BenchmarkKey, TestResults]:
+    def group_by_benchmarks(
+        self, benchmark_keys: list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path
+    ) -> dict[BenchmarkKey, TestResults]:
         """Group TestResults by benchmark for calculating improvements for each benchmark."""
         test_results_by_benchmark = defaultdict(TestResults)
         benchmark_module_path = {}
         for benchmark_key in benchmark_keys:
-            benchmark_module_path[benchmark_key] = module_name_from_file_path(benchmark_replay_test_dir.resolve() / f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_", project_root)
+            benchmark_module_path[benchmark_key] = module_name_from_file_path(
+                benchmark_replay_test_dir.resolve()
+                / f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_",
+                project_root,
+            )
         for test_result in self.test_results:
-            if (test_result.test_type == TestType.REPLAY_TEST):
+            if test_result.test_type == TestType.REPLAY_TEST:
                 for benchmark_key, module_path in benchmark_module_path.items():
                     if test_result.id.test_module_path.startswith(module_path):
                         test_results_by_benchmark[benchmark_key].add(test_result)
@@ -537,22 +554,20 @@ def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree:
         return tree
 
     def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
+        # Efficient single traversal, directly accumulating into a dict.
+        by_id: dict[InvocationId, list[int]] = {}
         for result in self.test_results:
-            if result.did_pass and not result.runtime:
-                msg = (
-                    f"Ignoring test case that passed but had no runtime -> {result.id}, "
-                    f"Loop # {result.loop_index}, Test Type: {result.test_type}, "
-                    f"Verification Type: {result.verification_type}"
-                )
-                logger.debug(msg)
-
-        usable_runtimes = [
-            (result.id, result.runtime) for result in self.test_results if result.did_pass and result.runtime
-        ]
-        return {
-            usable_id: [runtime[1] for runtime in usable_runtimes if runtime[0] == usable_id]
-            for usable_id in {runtime[0] for runtime in usable_runtimes}
-        }
+            if result.did_pass:
+                if result.runtime:
+                    by_id.setdefault(result.id, []).append(result.runtime)
+                else:
+                    msg = (
+                        f"Ignoring test case that passed but had no runtime -> {result.id}, "
+                        f"Loop # {result.loop_index}, Test Type: {result.test_type}, "
+                        f"Verification Type: {result.verification_type}"
+                    )
+                    logger.debug(msg)
+        return by_id
 
     def total_passed_runtime(self) -> int:
         """Calculate the sum of runtimes of all test cases that passed.
@@ -561,10 +576,9 @@ def total_passed_runtime(self) -> int:
 
         :return: The runtime in nanoseconds.
         """
-        #TODO this doesn't look at the intersection of tests of baseline and original
-        return sum(
-            [min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()]
-        )
+        # TODO this doesn't look at the intersection of tests of baseline and original
+        runtime_data = self.usable_runtime_data_by_test_case()
+        return sum(min(times) for times in runtime_data.values() if times)
 
     def __iter__(self) -> Iterator[FunctionTestInvocation]:
         return iter(self.test_results)
@@ -591,7 +605,7 @@ def __eq__(self, other: object) -> bool:
         if len(self) != len(other):
            return False
         original_recursion_limit = sys.getrecursionlimit()
-        cast(TestResults, other)
+        cast("TestResults", other)
         for test_result in self:
            other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id)
             if other_test_result is None:
