@@ -3,9 +3,10 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING
 
+from pydantic import BaseModel
 from rich.tree import Tree
 
-from codeflash.cli_cmds.console import DEBUG_MODE
+from codeflash.cli_cmds.console import DEBUG_MODE, logger
 
 if TYPE_CHECKING:
     from collections.abc import Iterator
@@ -19,10 +20,10 @@
 from typing import Annotated, Optional, cast
 
 from jedi.api.classes import Name
-from pydantic import AfterValidator, BaseModel, ConfigDict, Field
+from pydantic import AfterValidator, ConfigDict, Field
 from pydantic.dataclasses import dataclass
 
-from codeflash.cli_cmds.console import console, logger
+from codeflash.cli_cmds.console import console
 from codeflash.code_utils.code_utils import module_name_from_file_path, validate_python_code
 from codeflash.code_utils.env_utils import is_end_to_end
 from codeflash.verification.comparator import comparator
@@ -59,24 +60,29 @@ class FunctionSource:
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, FunctionSource):
             return False
-        return (self.file_path == other.file_path and
-                self.qualified_name == other.qualified_name and
-                self.fully_qualified_name == other.fully_qualified_name and
-                self.only_function_name == other.only_function_name and
-                self.source_code == other.source_code)
+        return (
+            self.file_path == other.file_path
+            and self.qualified_name == other.qualified_name
+            and self.fully_qualified_name == other.fully_qualified_name
+            and self.only_function_name == other.only_function_name
+            and self.source_code == other.source_code
+        )
 
     def __hash__(self) -> int:
-        return hash((self.file_path, self.qualified_name, self.fully_qualified_name,
-                     self.only_function_name, self.source_code))
+        return hash(
+            (self.file_path, self.qualified_name, self.fully_qualified_name, self.only_function_name, self.source_code)
+        )
+
 
 class BestOptimization(BaseModel):
     candidate: OptimizedCandidate
     helper_functions: list[FunctionSource]
     runtime: int
-    replay_performance_gain: Optional[dict[BenchmarkKey,float]] = None
+    replay_performance_gain: Optional[dict[BenchmarkKey, float]] = None
     winning_behavioral_test_results: TestResults
     winning_benchmarking_test_results: TestResults
-    winning_replay_benchmarking_test_results: Optional[TestResults] = None
+    winning_replay_benchmarking_test_results: Optional[TestResults] = None
+
 
 @dataclass(frozen=True)
 class BenchmarkKey:
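
A note on the pairing above: because `__eq__` is overridden, `__hash__` must be defined over the same fields so that equal FunctionSource values collapse to a single entry in sets and dict keys. A minimal, hypothetical sketch of that contract (a stand-in class, not the real FunctionSource, which has additional fields):

from dataclasses import dataclass
from pathlib import Path


@dataclass
class MiniSource:  # hypothetical stand-in for FunctionSource
    file_path: Path
    qualified_name: str
    source_code: str

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, MiniSource):
            return False
        # Compare the identity-defining fields, mirroring the diff above.
        return (
            self.file_path == other.file_path
            and self.qualified_name == other.qualified_name
            and self.source_code == other.source_code
        )

    def __hash__(self) -> int:
        # Hash the same field tuple used by __eq__ so equal objects deduplicate.
        return hash((self.file_path, self.qualified_name, self.source_code))


a = MiniSource(Path("m.py"), "f", "def f(): ...")
b = MiniSource(Path("m.py"), "f", "def f(): ...")
assert a == b and len({a, b}) == 1
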
@@ -86,6 +92,7 @@ class BenchmarkKey:
     def __str__(self) -> str:
         return f"{self.module_path}::{self.function_name}"
 
+
 @dataclass
 class BenchmarkDetail:
     benchmark_name: str
@@ -107,9 +114,10 @@ def to_dict(self) -> dict[str, any]:
107114 "test_function" : self .test_function ,
108115 "original_timing" : self .original_timing ,
109116 "expected_new_timing" : self .expected_new_timing ,
110- "speedup_percent" : self .speedup_percent
117+ "speedup_percent" : self .speedup_percent ,
111118 }
112119
120+
113121@dataclass
114122class ProcessedBenchmarkInfo :
115123 benchmark_details : list [BenchmarkDetail ]
@@ -124,9 +132,9 @@ def to_string(self) -> str:
         return result
 
     def to_dict(self) -> dict[str, list[dict[str, any]]]:
-        return {
-            "benchmark_details": [detail.to_dict() for detail in self.benchmark_details]
-        }
+        return {"benchmark_details": [detail.to_dict() for detail in self.benchmark_details]}
+
+
 class CodeString(BaseModel):
     code: Annotated[str, AfterValidator(validate_python_code)]
     file_path: Optional[Path] = None
@@ -151,7 +159,8 @@ class CodeOptimizationContext(BaseModel):
     read_writable_code: str = Field(min_length=1)
     read_only_context_code: str = ""
     helper_functions: list[FunctionSource]
-    preexisting_objects: set[tuple[str, tuple[FunctionParent,...]]]
+    preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]]
+
 
 class CodeContextType(str, Enum):
     READ_WRITABLE = "READ_WRITABLE"
@@ -347,6 +356,7 @@ def create_empty(cls, file_path: Path, function_name: str, code_context: CodeOpt
             status=CoverageStatus.NOT_FOUND,
         )
 
+
 @dataclass
 class FunctionCoverage:
     """Represents the coverage data for a specific function in a source file."""
@@ -364,7 +374,8 @@ class TestingMode(enum.Enum):
     PERFORMANCE = "performance"
     LINE_PROFILE = "line_profile"
 
-#TODO this class is duplicated in codeflash_capture
+
+# TODO this class is duplicated in codeflash_capture
 class VerificationType(str, Enum):
     FUNCTION_CALL = (
         "function_call"  # Correctness verification for a test function, checks input values and output values)
@@ -473,14 +484,20 @@ def merge(self, other: TestResults) -> None:
                 raise ValueError(msg)
             self.test_result_idx[k] = v + original_len
 
-    def group_by_benchmarks(self, benchmark_keys:list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path) -> dict[BenchmarkKey, TestResults]:
+    def group_by_benchmarks(
+        self, benchmark_keys: list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path
+    ) -> dict[BenchmarkKey, TestResults]:
         """Group TestResults by benchmark for calculating improvements for each benchmark."""
         test_results_by_benchmark = defaultdict(TestResults)
         benchmark_module_path = {}
         for benchmark_key in benchmark_keys:
-            benchmark_module_path[benchmark_key] = module_name_from_file_path(benchmark_replay_test_dir.resolve() / f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_", project_root)
+            benchmark_module_path[benchmark_key] = module_name_from_file_path(
+                benchmark_replay_test_dir.resolve()
+                / f"test_{benchmark_key.module_path.replace('.', '_')}__replay_test_",
+                project_root,
+            )
         for test_result in self.test_results:
-            if (test_result.test_type == TestType.REPLAY_TEST):
+            if test_result.test_type == TestType.REPLAY_TEST:
                 for benchmark_key, module_path in benchmark_module_path.items():
                     if test_result.id.test_module_path.startswith(module_path):
                         test_results_by_benchmark[benchmark_key].add(test_result)
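
Context on the grouping above: each benchmark key is mapped to the module-path prefix of its generated replay tests, and a replay test result is attributed to the benchmark whose prefix its test module path starts with. A simplified, self-contained sketch of that prefix-matching idea, using plain strings in place of the real BenchmarkKey/TestResults types (all names below are hypothetical):

from collections import defaultdict

# Hypothetical per-benchmark replay-test module prefixes.
benchmark_prefixes = {
    "bench_sort::test_sort_large": "tests.replay.test_bench_sort__replay_test_",
    "bench_load::test_load_csv": "tests.replay.test_bench_load__replay_test_",
}

# Hypothetical replay test results as (test_module_path, runtime_ns) pairs.
replay_results = [
    ("tests.replay.test_bench_sort__replay_test_0", 1200),
    ("tests.replay.test_bench_load__replay_test_0", 3400),
    ("tests.replay.test_bench_sort__replay_test_1", 1100),
]

grouped: dict[str, list[int]] = defaultdict(list)
for module_path, runtime in replay_results:
    for benchmark, prefix in benchmark_prefixes.items():
        if module_path.startswith(prefix):
            grouped[benchmark].append(runtime)

# grouped now maps each benchmark to the runtimes of its own replay tests.
assert grouped["bench_sort::test_sort_large"] == [1200, 1100]
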
@@ -537,22 +554,20 @@ def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree:
         return tree
 
     def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]:
+        # Efficient single traversal, directly accumulating into a dict.
+        by_id: dict[InvocationId, list[int]] = {}
         for result in self.test_results:
-            if result.did_pass and not result.runtime:
-                msg = (
-                    f"Ignoring test case that passed but had no runtime -> {result.id}, "
-                    f"Loop # {result.loop_index}, Test Type: {result.test_type}, "
-                    f"Verification Type: {result.verification_type}"
-                )
-                logger.debug(msg)
-
-        usable_runtimes = [
-            (result.id, result.runtime) for result in self.test_results if result.did_pass and result.runtime
-        ]
-        return {
-            usable_id: [runtime[1] for runtime in usable_runtimes if runtime[0] == usable_id]
-            for usable_id in {runtime[0] for runtime in usable_runtimes}
-        }
+            if result.did_pass:
+                if result.runtime:
+                    by_id.setdefault(result.id, []).append(result.runtime)
+                else:
+                    msg = (
+                        f"Ignoring test case that passed but had no runtime -> {result.id}, "
+                        f"Loop # {result.loop_index}, Test Type: {result.test_type}, "
+                        f"Verification Type: {result.verification_type}"
+                    )
+                    logger.debug(msg)
+        return by_id
 
     def total_passed_runtime(self) -> int:
         """Calculate the sum of runtimes of all test cases that passed.
@@ -561,10 +576,9 @@ def total_passed_runtime(self) -> int:
 
         :return: The runtime in nanoseconds.
         """
-        #TODO this doesn't look at the intersection of tests of baseline and original
-        return sum(
-            [min(usable_runtime_data) for _, usable_runtime_data in self.usable_runtime_data_by_test_case().items()]
-        )
+        # TODO this doesn't look at the intersection of tests of baseline and original
+        runtime_data = self.usable_runtime_data_by_test_case()
+        return sum(min(times) for times in runtime_data.values() if times)
 
     def __iter__(self) -> Iterator[FunctionTestInvocation]:
         return iter(self.test_results)
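
To make the two rewrites above concrete: passing results that have a runtime are accumulated into a dict keyed by invocation id in a single pass, and the total runtime is the sum of the per-invocation minimums. A standalone sketch with hypothetical (invocation_id, did_pass, runtime_ns) tuples standing in for FunctionTestInvocation objects:

# Hypothetical measurements; a runtime of 0 means "no usable timing was captured".
results = [
    ("test_a", True, 105),
    ("test_a", True, 98),   # same invocation measured in another loop -> minimum wins
    ("test_b", True, 0),    # passed but no usable runtime -> skipped (and logged in the real code)
    ("test_c", False, 50),  # failed -> ignored entirely
]

by_id: dict[str, list[int]] = {}
for invocation_id, did_pass, runtime in results:
    if did_pass and runtime:
        by_id.setdefault(invocation_id, []).append(runtime)

# Sum the best (minimum) runtime observed per invocation id.
total = sum(min(times) for times in by_id.values() if times)
assert by_id == {"test_a": [105, 98]} and total == 98
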
@@ -591,7 +605,7 @@ def __eq__(self, other: object) -> bool:
         if len(self) != len(other):
             return False
         original_recursion_limit = sys.getrecursionlimit()
-        cast(TestResults, other)
+        cast("TestResults", other)
         for test_result in self:
             other_test_result = other.get_by_unique_invocation_loop_id(test_result.unique_invocation_loop_id)
             if other_test_result is None: