@@ -47,6 +47,70 @@ def __init__(self, args: Namespace) -> None:
         self.functions_checkpoint: CodeflashRunCheckpoint | None = None
         self.current_function_optimizer: FunctionOptimizer | None = None

+    def run_benchmarks(
+        self, file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]], num_optimizable_functions: int
+    ) -> tuple[dict[str, dict[BenchmarkKey, float]], dict[BenchmarkKey, float]]:
+        """Run benchmarks for the functions to optimize and collect timing information."""
+        function_benchmark_timings: dict[str, dict[BenchmarkKey, float]] = {}
+        total_benchmark_timings: dict[BenchmarkKey, float] = {}
+
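+        # Benchmarking is opt-in: without --benchmark, or with nothing to optimize, return empty timings.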
+        if not (hasattr(self.args, "benchmark") and self.args.benchmark and num_optimizable_functions > 0):
+            return function_benchmark_timings, total_benchmark_timings
+
+        from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator
+        from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin
+        from codeflash.benchmarking.replay_test import generate_replay_test
+        from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
+        from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
+        from codeflash.code_utils.env_utils import get_pr_number
+
+        with progress_bar(
+            f"Running benchmarks in {self.args.benchmarks_root}", transient=True, revert_to_print=bool(get_pr_number())
+        ):
+            # Snapshot each file's original source, then insert the codeflash tracing decorator.
+            file_path_to_source_code = defaultdict(str)
+            for file in file_to_funcs_to_optimize:
+                with file.open("r", encoding="utf8") as f:
+                    file_path_to_source_code[file] = f.read()
+            try:
+                instrument_codeflash_trace_decorator(file_to_funcs_to_optimize)
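+                # Remove any stale trace file left over from a previous run.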
+                trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace"
+                if trace_file.exists():
+                    trace_file.unlink()
+
+                self.replay_tests_dir = Path(
+                    tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root)
+                )
+                trace_benchmarks_pytest(
+                    self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file
+                )  # Run all tests that use pytest-benchmark
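+                # Turn the traced benchmark calls into replay tests; a count of zero means nothing usable was captured.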
+                replay_count = generate_replay_test(trace_file, self.replay_tests_dir)
+                if replay_count == 0:
+                    logger.info(
+                        f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization"
+                    )
+                else:
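+                    # Pull per-function and total timings from the trace, then render the summary table.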
+                    function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
+                    total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
+                    function_to_results = validate_and_format_benchmark_table(
+                        function_benchmark_timings, total_benchmark_timings
+                    )
+                    print_benchmark_table(function_to_results)
+            except Exception as e:
+                logger.info(f"Error while tracing existing benchmarks: {e}")
+                logger.info("Information on existing benchmarks will not be available for this run.")
+            finally:
+                # Restore original source code
+                for file in file_path_to_source_code:
+                    with file.open("w", encoding="utf8") as f:
+                        f.write(file_path_to_source_code[file])
+
+        return function_benchmark_timings, total_benchmark_timings
+
     def create_function_optimizer(
         self,
         function_to_optimize: FunctionToOptimize,
@@ -108,58 +172,9 @@ def run(self) -> None:
             module_root=self.args.module_root,
             previous_checkpoint_functions=self.args.previous_checkpoint_functions,
         )
-        function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {}
-        total_benchmark_timings: dict[BenchmarkKey, int] = {}
-        if self.args.benchmark and num_optimizable_functions > 0:
-            from codeflash.benchmarking.instrument_codeflash_trace import instrument_codeflash_trace_decorator
-            from codeflash.benchmarking.plugin.plugin import CodeFlashBenchmarkPlugin
-            from codeflash.benchmarking.replay_test import generate_replay_test
-            from codeflash.benchmarking.trace_benchmarks import trace_benchmarks_pytest
-            from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
-
-            console.rule()
-            with progress_bar(
-                f"Running benchmarks in {self.args.benchmarks_root}",
-                transient=True,
-                revert_to_print=bool(get_pr_number()),
-            ):
-                # Insert decorator
-                file_path_to_source_code = defaultdict(str)
-                for file in file_to_funcs_to_optimize:
-                    with file.open("r", encoding="utf8") as f:
-                        file_path_to_source_code[file] = f.read()
-                try:
-                    instrument_codeflash_trace_decorator(file_to_funcs_to_optimize)
-                    trace_file = Path(self.args.benchmarks_root) / "benchmarks.trace"
-                    if trace_file.exists():
-                        trace_file.unlink()
-
-                    self.replay_tests_dir = Path(
-                        tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.tests_root)
-                    )
-                    trace_benchmarks_pytest(
-                        self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file
-                    )  # Run all tests that use pytest-benchmark
-                    replay_count = generate_replay_test(trace_file, self.replay_tests_dir)
-                    if replay_count == 0:
-                        logger.info(
-                            f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization"
-                        )
-                    else:
-                        function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
-                        total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
-                        function_to_results = validate_and_format_benchmark_table(
-                            function_benchmark_timings, total_benchmark_timings
-                        )
-                        print_benchmark_table(function_to_results)
-                except Exception as e:
-                    logger.info(f"Error while tracing existing benchmarks: {e}")
-                    logger.info("Information on existing benchmarks will not be available for this run.")
-                finally:
-                    # Restore original source code
-                    for file in file_path_to_source_code:
-                        with file.open("w", encoding="utf8") as f:
-                            f.write(file_path_to_source_code[file])
+        function_benchmark_timings, total_benchmark_timings = self.run_benchmarks(
+            file_to_funcs_to_optimize, num_optimizable_functions
+        )
         optimizations_found: int = 0
         function_iterator_count: int = 0
         if self.args.test_framework == "pytest":
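
For reference, a minimal sketch of the extracted helper's early-return contract. It assumes the surrounding class is `Optimizer` in `codeflash.optimization.optimizer` (the class name is not visible in this hunk), and it bypasses `__init__` since only `self.args` is touched on this path:

```python
from argparse import Namespace

from codeflash.optimization.optimizer import Optimizer  # assumed module path

# Hypothetical smoke check: with benchmarking disabled, run_benchmarks returns
# two empty dicts and never reaches the deferred benchmarking imports.
optimizer = Optimizer.__new__(Optimizer)  # skip __init__ for the sketch
optimizer.args = Namespace(benchmark=False)
timings, totals = optimizer.run_benchmarks({}, num_optimizable_functions=0)
assert timings == {} and totals == {}
```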