@@ -16,21 +16,21 @@


 def get_next_arg_and_return(
-    trace_file: str, function_name: str, file_path: str, class_name: str | None = None, num_to_get: int = 256
+    trace_file: str, benchmark_function_name: str, function_name: str, file_path: str, class_name: str | None = None, num_to_get: int = 256
 ) -> Generator[Any]:
     db = sqlite3.connect(trace_file)
     cur = db.cursor()
     limit = num_to_get

     if class_name is not None:
         cursor = cur.execute(
-            "SELECT * FROM benchmark_function_timings WHERE function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
-            (function_name, file_path, class_name, limit),
+            "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
+            (benchmark_function_name, function_name, file_path, class_name, limit),
         )
     else:
         cursor = cur.execute(
-            "SELECT * FROM benchmark_function_timings WHERE function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
-            (function_name, file_path, limit),
+            "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
+            (benchmark_function_name, function_name, file_path, limit),
         )

     while (val := cursor.fetchone()) is not None:
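With `benchmark_function_name` added to the signature and to both queries, replay is scoped to the invocations recorded under one specific benchmark rather than every trace row for the function. A minimal sketch of the updated call, per the generated-test templates further down (the trace path, benchmark name, function name, and file path are hypothetical):

```python
import dill as pickle  # the trace DB stores arguments as dill pickles

from codeflash.benchmarking.replay_test import get_next_arg_and_return

# Hypothetical trace and names, for illustration only.
for args_pkl, kwargs_pkl in get_next_arg_and_return(
    trace_file="benchmarks.trace",
    benchmark_function_name="test_sort",  # new: restrict rows to one benchmark
    function_name="sorter",
    file_path="/path/to/bubble_sort.py",
    num_to_get=100,
):
    args = pickle.loads(args_pkl)
    kwargs = pickle.loads(kwargs_pkl)
```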
@@ -61,6 +61,7 @@ def create_trace_replay_test_code(
     """
     assert test_framework in ["pytest", "unittest"]

+    # Create Imports
     imports = f"""import dill as pickle
 {"import unittest" if test_framework == "unittest" else ""}
 from codeflash.benchmarking.replay_test import get_next_arg_and_return
@@ -82,16 +83,15 @@ def create_trace_replay_test_code(

     imports += "\n".join(function_imports)

-    functions_to_optimize = [func.get("function_name") for func in functions_data
-                             if func.get("function_name") != "__init__"]
+    functions_to_optimize = sorted({func.get("function_name") for func in functions_data
+                                    if func.get("function_name") != "__init__"})
     metadata = f"""functions = {functions_to_optimize}
 trace_file_path = r"{trace_file}"
 """
-
     # Templates for different types of tests
     test_function_body = textwrap.dedent(
         """\
-        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):
+        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", num_to_get={max_run_count}):
             args = pickle.loads(args_pkl)
             kwargs = pickle.loads(kwargs_pkl)
             ret = {function_name}(*args, **kwargs)
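Rendered through `.format()`, the template above becomes a plain loop inside a generated test function. A sketch of what one rendered pytest test might look like, assuming a traced top-level function `sorter`; the test name, import alias, benchmark name, and paths are all hypothetical:

```python
import dill as pickle

from codeflash.benchmarking.replay_test import get_next_arg_and_return
# Hypothetical alias import produced by get_function_alias for module
# code_to_optimize.bubble_sort and function sorter.
from code_to_optimize.bubble_sort import sorter as code_to_optimize_bubble_sort_sorter

functions = ['sorter']
trace_file_path = r"/tmp/benchmarks.trace"


def test_code_to_optimize_bubble_sort_sorter():
    for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="test_sort", function_name="sorter", file_path=r"/path/to/bubble_sort.py", num_to_get=100):
        args = pickle.loads(args_pkl)
        kwargs = pickle.loads(kwargs_pkl)
        ret = code_to_optimize_bubble_sort_sorter(*args, **kwargs)
```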
@@ -100,7 +100,7 @@ def create_trace_replay_test_code(

     test_method_body = textwrap.dedent(
         """\
-        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
+        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
             args = pickle.loads(args_pkl)
             kwargs = pickle.loads(kwargs_pkl){filter_variables}
             function_name = "{orig_function_name}"
@@ -115,7 +115,7 @@ def create_trace_replay_test_code(

     test_class_method_body = textwrap.dedent(
         """\
-        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
+        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
             args = pickle.loads(args_pkl)
             kwargs = pickle.loads(kwargs_pkl){filter_variables}
             if not args:
@@ -125,13 +125,15 @@ def create_trace_replay_test_code(
     )
     test_static_method_body = textwrap.dedent(
         """\
-        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
+        for args_pkl, kwargs_pkl in get_next_arg_and_return(trace_file=trace_file_path, benchmark_function_name="{benchmark_function_name}", function_name="{orig_function_name}", file_path=r"{file_path}", class_name="{class_name}", num_to_get={max_run_count}):
             args = pickle.loads(args_pkl)
             kwargs = pickle.loads(kwargs_pkl){filter_variables}
             ret = {class_name_alias}{method_name}(*args, **kwargs)
         """
     )

+    # Create main body
+
     if test_framework == "unittest":
         self = "self"
         test_template = "\nclass TestTracedFunctions(unittest.TestCase):\n"
@@ -140,17 +142,20 @@ def create_trace_replay_test_code(
         self = ""

     for func in functions_data:
+
         module_name = func.get("module_name")
         function_name = func.get("function_name")
         class_name = func.get("class_name")
         file_path = func.get("file_path")
+        benchmark_function_name = func.get("benchmark_function_name")
         function_properties = func.get("function_properties")
         if not class_name:
             alias = get_function_alias(module_name, function_name)
             test_body = test_function_body.format(
+                benchmark_function_name=benchmark_function_name,
+                orig_function_name=function_name,
                 function_name=alias,
                 file_path=file_path,
-                orig_function_name=function_name,
                 max_run_count=max_run_count,
             )
         else:
@@ -162,6 +167,7 @@ def create_trace_replay_test_code(
             method_name = "." + function_name if function_name != "__init__" else ""
             if function_properties.is_classmethod:
                 test_body = test_class_method_body.format(
+                    benchmark_function_name=benchmark_function_name,
                     orig_function_name=function_name,
                     file_path=file_path,
                     class_name_alias=class_name_alias,
@@ -172,6 +178,7 @@ def create_trace_replay_test_code(
                 )
             elif function_properties.is_staticmethod:
                 test_body = test_static_method_body.format(
+                    benchmark_function_name=benchmark_function_name,
                     orig_function_name=function_name,
                     file_path=file_path,
                     class_name_alias=class_name_alias,
@@ -182,6 +189,7 @@ def create_trace_replay_test_code(
                 )
             else:
                 test_body = test_method_body.format(
+                    benchmark_function_name=benchmark_function_name,
                     orig_function_name=function_name,
                     file_path=file_path,
                     class_name_alias=class_name_alias,
@@ -217,25 +225,25 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
     conn = sqlite3.connect(trace_file_path.as_posix())
     cursor = conn.cursor()

-    # Get distinct benchmark names
+    # Get distinct benchmark file paths
     cursor.execute(
-        "SELECT DISTINCT benchmark_function_name, benchmark_file_path FROM benchmark_function_timings"
+        "SELECT DISTINCT benchmark_file_path FROM benchmark_function_timings"
     )
-    benchmarks = cursor.fetchall()
+    benchmark_files = cursor.fetchall()

-    # Generate a test for each benchmark
-    for benchmark in benchmarks:
-        benchmark_function_name, benchmark_file_path = benchmark
-        # Get functions associated with this benchmark
+    # Generate a test for each benchmark file
+    for benchmark_file in benchmark_files:
+        benchmark_file_path = benchmark_file[0]
+        # Get all benchmarks and functions associated with this file path
         cursor.execute(
-            "SELECT DISTINCT function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings "
-            "WHERE benchmark_function_name = ? AND benchmark_file_path = ?",
-            (benchmark_function_name, benchmark_file_path)
+            "SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings "
+            "WHERE benchmark_file_path = ?",
+            (benchmark_file_path,)
         )

         functions_data = []
-        for func_row in cursor.fetchall():
-            function_name, class_name, module_name, file_path, benchmark_line_number = func_row
+        for row in cursor.fetchall():
+            benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number = row
             # Add this function to our list
             functions_data.append({
                 "function_name": function_name,
@@ -246,16 +254,15 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
                 "benchmark_file_path": benchmark_file_path,
                 "benchmark_line_number": benchmark_line_number,
                 "function_properties": inspect_top_level_functions_or_methods(
-                    file_name=Path(file_path),
-                    function_or_method_name=function_name,
-                    class_name=class_name,
-                )
+                        file_name=Path(file_path),
+                        function_or_method_name=function_name,
+                        class_name=class_name,
+                        )
             })

         if not functions_data:
-            logger.info(f"No functions found for benchmark {benchmark_function_name} in {benchmark_file_path}")
+            logger.info(f"No benchmark test functions found in {benchmark_file_path}")
             continue
-
         # Generate the test code for this benchmark
         test_code = create_trace_replay_test_code(
             trace_file=trace_file_path.as_posix(),
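Because `benchmark_function_name` is now part of each SELECTed row, every entry in `functions_data` carries its own benchmark alongside the function identity. Roughly what one entry looks like; all values below are illustrative:

```python
# Illustrative shape of a single functions_data entry; values are examples.
{
    "function_name": "sorter",
    "class_name": "",
    "module_name": "code_to_optimize.bubble_sort",
    "file_path": "/path/to/bubble_sort.py",
    "benchmark_function_name": "test_sort",  # now stored per entry
    "benchmark_file_path": "tests/benchmarks/test_benchmark_sort.py",
    "benchmark_line_number": 7,
    "function_properties": ...,  # result of inspect_top_level_functions_or_methods(...)
}
```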
@@ -265,17 +272,15 @@ def generate_replay_test(trace_file_path: Path, output_dir: Path, test_framework
         )
         test_code = isort.code(test_code)

-        # Write to file if requested
-        if output_dir:
-            name = Path(benchmark_file_path).name.split(".")[0][5:]  # remove "test_" from the name since we add it in later
-            output_file = get_test_file_path(
-                test_dir=Path(output_dir), function_name=f"{name}_{benchmark_function_name}", test_type="replay"
-            )
-            # Write test code to file, parents = true
-            output_dir.mkdir(parents=True, exist_ok=True)
-            output_file.write_text(test_code, "utf-8")
-            count += 1
-            logger.info(f"Replay test for benchmark `{benchmark_function_name}` in {name} written to {output_file}")
+        name = Path(benchmark_file_path).name.split(".")[0][5:]  # remove "test_" from the name since we add it in later
+        output_file = get_test_file_path(
+            test_dir=Path(output_dir), function_name=f"{name}", test_type="replay"
+        )
+        # Write test code to file, parents = true
+        output_dir.mkdir(parents=True, exist_ok=True)
+        output_file.write_text(test_code, "utf-8")
+        count += 1
+        logger.info(f"Replay test for benchmark file `{benchmark_file_path}` in {name} written to {output_file}")

     conn.close()
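Since the outer loop now iterates over distinct benchmark file paths, one replay test file is emitted per benchmark file rather than one per (benchmark function, file) pair. A usage sketch; the paths are hypothetical and any remaining parameters are assumed to keep their defaults:

```python
from pathlib import Path

# Hypothetical invocation; trace file and output directory are examples only.
# A trace whose rows all come from tests/benchmarks/test_benchmark_sort.py now
# yields a single replay test file (named via get_test_file_path), no matter
# how many benchmark functions that file contains.
generate_replay_test(
    trace_file_path=Path("/tmp/benchmarks.trace"),
    output_dir=Path("tests/replay"),
    test_framework="pytest",
)
```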