 from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
 from codeflash.cli_cmds.console import console, logger, progress_bar
 from codeflash.code_utils import env_utils
+from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint, ask_should_use_checkpoint_get_functions
 from codeflash.code_utils.code_replacer import normalize_code, normalize_node
 from codeflash.code_utils.code_utils import get_run_tmp_file
 from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast
@@ -52,6 +53,8 @@ def __init__(self, args: Namespace) -> None:
         self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None)
         self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None
         self.replay_tests_dir = None
+        self.functions_checkpoint: CodeflashRunCheckpoint | None = None
+
     def create_function_optimizer(
         self,
         function_to_optimize: FunctionToOptimize,
@@ -71,7 +74,7 @@ def create_function_optimizer(
             args=self.args,
             function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None,
             total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None,
-            replay_tests_dir=self.replay_tests_dir
+            replay_tests_dir=self.replay_tests_dir,
         )

     def run(self) -> None:
@@ -83,7 +86,7 @@ def run(self) -> None:
         function_optimizer = None
         file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
         num_optimizable_functions: int
-
+        previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(self.args)
         # discover functions
         (file_to_funcs_to_optimize, num_optimizable_functions) = get_functions_to_optimize(
             optimize_all=self.args.all,
@@ -94,14 +97,12 @@ def run(self) -> None:
             ignore_paths=self.args.ignore_paths,
             project_root=self.args.project_root,
             module_root=self.args.module_root,
+            previous_checkpoint_functions=previous_checkpoint_functions,
         )
         function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {}
         total_benchmark_timings: dict[BenchmarkKey, int] = {}
         if self.args.benchmark and num_optimizable_functions > 0:
-            with progress_bar(
-                f"Running benchmarks in {self.args.benchmarks_root}",
-                transient=True,
-            ):
+            with progress_bar(f"Running benchmarks in {self.args.benchmarks_root}", transient=True):
                 # Insert decorator
                 file_path_to_source_code = defaultdict(str)
                 for file in file_to_funcs_to_optimize:
@@ -113,15 +114,23 @@ def run(self) -> None:
                     if trace_file.exists():
                         trace_file.unlink()

-                    self.replay_tests_dir = Path(tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root))
-                    trace_benchmarks_pytest(self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file)  # Run all tests that use pytest-benchmark
+                    self.replay_tests_dir = Path(
+                        tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root)
+                    )
+                    trace_benchmarks_pytest(
+                        self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file
+                    )  # Run all tests that use pytest-benchmark
                     replay_count = generate_replay_test(trace_file, self.replay_tests_dir)
                     if replay_count == 0:
-                        logger.info(f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization")
+                        logger.info(
+                            f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization"
+                        )
                     else:
                         function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
                         total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
-                        function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
+                        function_to_results = validate_and_format_benchmark_table(
+                            function_benchmark_timings, total_benchmark_timings
+                        )
                         print_benchmark_table(function_to_results)
             except Exception as e:
                 logger.info(f"Error while tracing existing benchmarks: {e}")
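
The hunk above is a pure reflow of the benchmark-tracing block, but the sequence it wraps is worth spelling out: delete any stale trace file, run the pytest-benchmark suites once under the tracer, then generate deterministic replay tests from the recorded calls, continuing gracefully when nothing usable was traced. A self-contained sketch of that structure follows; the two stand-in helpers only mimic the signatures used in the diff and are not the real codeflash implementations.

    import tempfile
    from pathlib import Path

    def trace_benchmarks_pytest(benchmarks_root: Path, tests_root: Path, project_root: Path, trace_file: Path) -> None:
        # Stand-in: the real helper runs every pytest-benchmark test under a tracer.
        trace_file.write_text("stand-in trace data")

    def generate_replay_test(trace_file: Path, replay_tests_dir: Path) -> int:
        # Stand-in: the real helper writes replay tests and returns how many it generated.
        return 0

    benchmarks_root = Path("tests/benchmarks")  # illustrative paths
    benchmarks_root.mkdir(parents=True, exist_ok=True)
    trace_file = benchmarks_root / "benchmarks.trace"
    if trace_file.exists():
        trace_file.unlink()  # a stale trace would mix old and new timings

    # As in the diff, the replay-test directory is a temp dir created under benchmarks_root.
    replay_tests_dir = Path(tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=benchmarks_root))
    trace_benchmarks_pytest(benchmarks_root, Path("tests"), Path("."), trace_file)
    if generate_replay_test(trace_file, replay_tests_dir) == 0:
        print("No valid benchmarks found; optimization continues without benchmark timings")
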
@@ -148,10 +157,13 @@ def run(self) -> None:
             function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
             num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
             console.rule()
-            logger.info(f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}")
+            logger.info(
+                f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"
+            )
             console.rule()
             ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests})
-
+            if self.args.all:
+                self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_path)

             for original_module_path in file_to_funcs_to_optimize:
                 logger.info(f"Examining file {original_module_path!s}…")
@@ -212,17 +224,33 @@ def run(self) -> None:
                     qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(
                         self.args.project_root
                     )
-                    if self.args.benchmark and function_benchmark_timings and qualified_name_w_module in function_benchmark_timings and total_benchmark_timings:
+                    if (
+                        self.args.benchmark
+                        and function_benchmark_timings
+                        and qualified_name_w_module in function_benchmark_timings
+                        and total_benchmark_timings
+                    ):
                         function_optimizer = self.create_function_optimizer(
-                            function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings[qualified_name_w_module], total_benchmark_timings
+                            function_to_optimize,
+                            function_to_optimize_ast,
+                            function_to_tests,
+                            validated_original_code[original_module_path].source_code,
+                            function_benchmark_timings[qualified_name_w_module],
+                            total_benchmark_timings,
                         )
                     else:
                         function_optimizer = self.create_function_optimizer(
-                            function_to_optimize, function_to_optimize_ast, function_to_tests,
-                            validated_original_code[original_module_path].source_code
+                            function_to_optimize,
+                            function_to_optimize_ast,
+                            function_to_tests,
+                            validated_original_code[original_module_path].source_code,
                         )

                     best_optimization = function_optimizer.optimize_function()
+                    if self.functions_checkpoint:
+                        self.functions_checkpoint.add_function_to_checkpoint(
+                            function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root)
+                        )
                     if is_successful(best_optimization):
                         optimizations_found += 1
                     else:
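
Taken together, the added lines implement a resume mechanism: a checkpoint is created only for --all runs, each function is appended to it after optimize_function() returns, and ask_should_use_checkpoint_get_functions lets the next run skip everything already recorded. Below is a minimal sketch of what such a checkpoint could look like; the JSON-lines format, file location, and prompt are assumptions for illustration, not codeflash's actual implementation — only the names and call sites come from the diff.

    from __future__ import annotations

    import json
    from argparse import Namespace
    from pathlib import Path

    CHECKPOINT_FILE = Path(".codeflash_checkpoint.jsonl")  # assumed location and format

    class CodeflashRunCheckpoint:
        def __init__(self, module_path: Path, checkpoint_path: Path = CHECKPOINT_FILE) -> None:
            self.module_path = module_path
            self.checkpoint_path = checkpoint_path

        def add_function_to_checkpoint(self, qualified_name: str) -> None:
            # Append-only writes keep the checkpoint usable even if the run dies mid-loop.
            with self.checkpoint_path.open("a", encoding="utf8") as f:
                f.write(json.dumps({"module": str(self.module_path), "function": qualified_name}) + "\n")

    def ask_should_use_checkpoint_get_functions(args: Namespace) -> dict[str, str] | None:
        # Only --all runs write a checkpoint, so only they are offered a resume.
        if not getattr(args, "all", False) or not CHECKPOINT_FILE.exists():
            return None
        answer = input("Found a previous run's checkpoint. Skip already-optimized functions? [y/N] ")
        if answer.strip().lower() != "y":
            CHECKPOINT_FILE.unlink()  # declined: discard the stale checkpoint
            return None
        entries = (json.loads(line) for line in CHECKPOINT_FILE.read_text(encoding="utf8").splitlines())
        return {e["function"]: e["module"] for e in entries}

Under that reading, get_functions_to_optimize only has to drop functions whose qualified names appear in the returned mapping, which would explain why the diff threads previous_checkpoint_functions through as a parameter rather than filtering inside run().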