22
33import sqlite3
44import textwrap
5- from collections . abc import Generator
6- from typing import Any , Dict
5+ from pathlib import Path
6+ from typing import TYPE_CHECKING , Any
77
88import isort
99
1010from codeflash .cli_cmds .console import logger
1111from codeflash .discovery .functions_to_optimize import inspect_top_level_functions_or_methods
1212from codeflash .verification .verification_utils import get_test_file_path
13- from pathlib import Path
13+
14+ if TYPE_CHECKING :
15+ from collections .abc import Generator
16+
1417
1518def get_next_arg_and_return (
16- trace_file : str , function_name : str , file_name : str , class_name : str | None = None , num_to_get : int = 256
19+ trace_file : str , function_name : str , file_path : str , class_name : str | None = None , num_to_get : int = 256
1720) -> Generator [Any ]:
1821 db = sqlite3 .connect (trace_file )
1922 cur = db .cursor ()
2023 limit = num_to_get
2124
2225 if class_name is not None :
2326 cursor = cur .execute (
24- "SELECT * FROM benchmark_function_timings WHERE function_name = ? AND file_name = ? AND class_name = ? LIMIT ?" ,
25- (function_name , file_name , class_name , limit ),
27+ "SELECT * FROM benchmark_function_timings WHERE function_name = ? AND file_path = ? AND class_name = ? LIMIT ?" ,
28+ (function_name , file_path , class_name , limit ),
2629 )
2730 else :
2831 cursor = cur .execute (
29- "SELECT * FROM benchmark_function_timings WHERE function_name = ? AND file_name = ? AND class_name = '' LIMIT ?" ,
30- (function_name , file_name , limit ),
32+ "SELECT * FROM benchmark_function_timings WHERE function_name = ? AND file_path = ? AND class_name = '' LIMIT ?" ,
33+ (function_name , file_path , limit ),
3134 )
3235
3336 while (val := cursor .fetchone ()) is not None :
@@ -88,7 +91,7 @@ def create_trace_replay_test_code(
8891 # Templates for different types of tests
8992 test_function_body = textwrap .dedent (
9093 """\
91- for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name =r"{file_name }", num_to_get={max_run_count}):
94+ for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path =r"{file_path }", num_to_get={max_run_count}):
9295 args = pickle.loads(args_pkl)
9396 kwargs = pickle.loads(kwargs_pkl)
9497 ret = {function_name}(*args, **kwargs)
@@ -97,7 +100,7 @@ def create_trace_replay_test_code(
97100
98101 test_method_body = textwrap .dedent (
99102 """\
100- for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name =r"{file_name }", class_name="{class_name}", num_to_get={max_run_count}):
103+ for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path =r"{file_path }", class_name="{class_name}", num_to_get={max_run_count}):
101104 args = pickle.loads(args_pkl)
102105 kwargs = pickle.loads(kwargs_pkl){filter_variables}
103106 function_name = "{orig_function_name}"
@@ -112,7 +115,7 @@ def create_trace_replay_test_code(
112115
113116 test_class_method_body = textwrap .dedent (
114117 """\
115- for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name =r"{file_name }", class_name="{class_name}", num_to_get={max_run_count}):
118+ for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path =r"{file_path }", class_name="{class_name}", num_to_get={max_run_count}):
116119 args = pickle.loads(args_pkl)
117120 kwargs = pickle.loads(kwargs_pkl){filter_variables}
118121 if not args:
@@ -122,7 +125,7 @@ def create_trace_replay_test_code(
122125 )
123126 test_static_method_body = textwrap .dedent (
124127 """\
125- for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_name =r"{file_name }", class_name="{class_name}", num_to_get={max_run_count}):
128+ for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path =r"{file_path }", class_name="{class_name}", num_to_get={max_run_count}):
126129 args = pickle.loads(args_pkl)
127130 kwargs = pickle.loads(kwargs_pkl){filter_variables}
128131 ret = {class_name_alias}{method_name}(*args, **kwargs)
@@ -140,13 +143,13 @@ def create_trace_replay_test_code(
140143 module_name = func .get ("module_name" )
141144 function_name = func .get ("function_name" )
142145 class_name = func .get ("class_name" )
143- file_name = func .get ("file_name " )
146+ file_path = func .get ("file_path " )
144147 function_properties = func .get ("function_properties" )
145148 if not class_name :
146149 alias = get_function_alias (module_name , function_name )
147150 test_body = test_function_body .format (
148151 function_name = alias ,
149- file_name = file_name ,
152+ file_path = file_path ,
150153 orig_function_name = function_name ,
151154 max_run_count = max_run_count ,
152155 )
@@ -160,7 +163,7 @@ def create_trace_replay_test_code(
160163 if function_properties .is_classmethod :
161164 test_body = test_class_method_body .format (
162165 orig_function_name = function_name ,
163- file_name = file_name ,
166+ file_path = file_path ,
164167 class_name_alias = class_name_alias ,
165168 class_name = class_name ,
166169 method_name = method_name ,
@@ -170,7 +173,7 @@ def create_trace_replay_test_code(
170173 elif function_properties .is_staticmethod :
171174 test_body = test_static_method_body .format (
172175 orig_function_name = function_name ,
173- file_name = file_name ,
176+ file_path = file_path ,
174177 class_name_alias = class_name_alias ,
175178 class_name = class_name ,
176179 method_name = method_name ,
@@ -180,7 +183,7 @@ def create_trace_replay_test_code(
180183 else :
181184 test_body = test_method_body .format (
182185 orig_function_name = function_name ,
183- file_name = file_name ,
186+ file_path = file_path ,
184187 class_name_alias = class_name_alias ,
185188 class_name = class_name ,
186189 method_name = method_name ,
@@ -216,42 +219,41 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
216219
217220 # Get distinct benchmark names
218221 cursor .execute (
219- "SELECT DISTINCT benchmark_function_name, benchmark_file_name FROM benchmark_function_timings"
222+ "SELECT DISTINCT benchmark_function_name, benchmark_file_path FROM benchmark_function_timings"
220223 )
221224 benchmarks = cursor .fetchall ()
222225
223226 # Generate a test for each benchmark
224227 for benchmark in benchmarks :
225- benchmark_function_name , benchmark_file_name = benchmark
228+ benchmark_function_name , benchmark_file_path = benchmark
226229 # Get functions associated with this benchmark
227230 cursor .execute (
228- "SELECT DISTINCT function_name, class_name, module_name, file_name , benchmark_line_number FROM benchmark_function_timings "
229- "WHERE benchmark_function_name = ? AND benchmark_file_name = ?" ,
230- (benchmark_function_name , benchmark_file_name )
231+ "SELECT DISTINCT function_name, class_name, module_name, file_path , benchmark_line_number FROM benchmark_function_timings "
232+ "WHERE benchmark_function_name = ? AND benchmark_file_path = ?" ,
233+ (benchmark_function_name , benchmark_file_path )
231234 )
232235
233236 functions_data = []
234237 for func_row in cursor .fetchall ():
235- function_name , class_name , module_name , file_name , benchmark_line_number = func_row
236-
238+ function_name , class_name , module_name , file_path , benchmark_line_number = func_row
237239 # Add this function to our list
238240 functions_data .append ({
239241 "function_name" : function_name ,
240242 "class_name" : class_name ,
241- "file_name " : file_name ,
243+ "file_path " : file_path ,
242244 "module_name" : module_name ,
243245 "benchmark_function_name" : benchmark_function_name ,
244- "benchmark_file_name " : benchmark_file_name ,
246+ "benchmark_file_path " : benchmark_file_path ,
245247 "benchmark_line_number" : benchmark_line_number ,
246248 "function_properties" : inspect_top_level_functions_or_methods (
247- file_name = file_name ,
249+ file_name = Path ( file_path ) ,
248250 function_or_method_name = function_name ,
249251 class_name = class_name ,
250252 )
251253 })
252254
253255 if not functions_data :
254- logger .info (f"No functions found for benchmark { benchmark_function_name } in { benchmark_file_name } " )
256+ logger .info (f"No functions found for benchmark { benchmark_function_name } in { benchmark_file_path } " )
255257 continue
256258
257259 # Generate the test code for this benchmark
@@ -265,17 +267,19 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
265267
266268 # Write to file if requested
267269 if output_dir :
270+ name = Path (benchmark_file_path ).name .split ("." )[0 ][5 :] # remove "test_" from the name since we add it in later
268271 output_file = get_test_file_path (
269- test_dir = Path (output_dir ), function_name = f"{ benchmark_file_name } _{ benchmark_function_name } " , test_type = "replay"
272+ test_dir = Path (output_dir ), function_name = f"{ name } _{ benchmark_function_name } " , test_type = "replay"
270273 )
271274 # Write test code to file, parents = true
272275 output_dir .mkdir (parents = True , exist_ok = True )
273276 output_file .write_text (test_code , "utf-8" )
274277 count += 1
275- logger .info (f"Replay test for benchmark `{ benchmark_function_name } ` in { benchmark_file_name } written to { output_file } " )
278+ logger .info (f"Replay test for benchmark `{ benchmark_function_name } ` in { name } written to { output_file } " )
276279
277280 conn .close ()
278281
279282 except Exception as e :
280283 logger .info (f"Error generating replay tests: { e } " )
284+
281285 return count
0 commit comments