@@ -24,22 +24,25 @@ def get_next_arg_and_return(
2424 num_to_get : int = 256 ,
2525) -> Generator [Any ]:
2626 db = sqlite3 .connect (trace_file )
27- cur = db .cursor ()
28- limit = num_to_get
29-
30- if class_name is not None :
31- cursor = cur .execute (
32- "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?" ,
33- (benchmark_function_name , function_name , file_path , class_name , limit ),
34- )
35- else :
36- cursor = cur .execute (
37- "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?" ,
38- (benchmark_function_name , function_name , file_path , limit ),
39- )
27+ try :
28+ cur = db .cursor ()
29+ limit = num_to_get
30+
31+ if class_name is not None :
32+ cursor = cur .execute (
33+ "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?" ,
34+ (benchmark_function_name , function_name , file_path , class_name , limit ),
35+ )
36+ else :
37+ cursor = cur .execute (
38+ "SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?" ,
39+ (benchmark_function_name , function_name , file_path , limit ),
40+ )
4041
41- while (val := cursor .fetchone ()) is not None :
42- yield val [9 ], val [10 ] # pickled_args, pickled_kwargs
42+ while (val := cursor .fetchone ()) is not None :
43+ yield val [9 ], val [10 ] # pickled_args, pickled_kwargs
44+ finally :
45+ db .close ()
4346
4447
4548def get_function_alias (module : str , function_name : str ) -> str :
@@ -235,61 +238,69 @@ def generate_replay_test(
235238 try :
236239 # Connect to the database
237240 conn = sqlite3 .connect (trace_file_path .as_posix ())
238- cursor = conn .cursor ()
239-
240- # Get distinct benchmark file paths
241- cursor .execute ("SELECT DISTINCT benchmark_module_path FROM benchmark_function_timings" )
242- benchmark_files = cursor .fetchall ()
243-
244- # Generate a test for each benchmark file
245- for benchmark_file in benchmark_files :
246- benchmark_module_path = benchmark_file [0 ]
247- # Get all benchmarks and functions associated with this file path
248- cursor .execute (
249- "SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings "
250- "WHERE benchmark_module_path = ?" ,
251- (benchmark_module_path ,),
252- )
253-
254- functions_data = []
255- for row in cursor .fetchall ():
256- benchmark_function_name , function_name , class_name , module_name , file_path , benchmark_line_number = row
257- # Add this function to our list
258- functions_data .append (
259- {
260- "function_name" : function_name ,
261- "class_name" : class_name ,
262- "file_path" : file_path ,
263- "module_name" : module_name ,
264- "benchmark_function_name" : benchmark_function_name ,
265- "benchmark_module_path" : benchmark_module_path ,
266- "benchmark_line_number" : benchmark_line_number ,
267- "function_properties" : inspect_top_level_functions_or_methods (
268- file_name = Path (file_path ), function_or_method_name = function_name , class_name = class_name
269- ),
270- }
241+ try :
242+ cursor = conn .cursor ()
243+
244+ # Get distinct benchmark file paths
245+ cursor .execute ("SELECT DISTINCT benchmark_module_path FROM benchmark_function_timings" )
246+ benchmark_files = cursor .fetchall ()
247+
248+ # Generate a test for each benchmark file
249+ for benchmark_file in benchmark_files :
250+ benchmark_module_path = benchmark_file [0 ]
251+ # Get all benchmarks and functions associated with this file path
252+ cursor .execute (
253+ "SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings "
254+ "WHERE benchmark_module_path = ?" ,
255+ (benchmark_module_path ,),
271256 )
272257
273- if not functions_data :
274- logger .info (f"No benchmark test functions found in { benchmark_module_path } " )
275- continue
276- # Generate the test code for this benchmark
277- test_code = create_trace_replay_test_code (
278- trace_file = trace_file_path .as_posix (),
279- functions_data = functions_data ,
280- test_framework = test_framework ,
281- max_run_count = max_run_count ,
282- )
283- test_code = isort .code (test_code )
284- output_file = get_test_file_path (
285- test_dir = Path (output_dir ), function_name = benchmark_module_path , test_type = "replay"
286- )
287- # Write test code to file, parents = true
288- output_dir .mkdir (parents = True , exist_ok = True )
289- output_file .write_text (test_code , "utf-8" )
290- count += 1
291-
292- conn .close ()
258+ functions_data = []
259+ for row in cursor .fetchall ():
260+ (
261+ benchmark_function_name ,
262+ function_name ,
263+ class_name ,
264+ module_name ,
265+ file_path ,
266+ benchmark_line_number ,
267+ ) = row
268+ # Add this function to our list
269+ functions_data .append (
270+ {
271+ "function_name" : function_name ,
272+ "class_name" : class_name ,
273+ "file_path" : file_path ,
274+ "module_name" : module_name ,
275+ "benchmark_function_name" : benchmark_function_name ,
276+ "benchmark_module_path" : benchmark_module_path ,
277+ "benchmark_line_number" : benchmark_line_number ,
278+ "function_properties" : inspect_top_level_functions_or_methods (
279+ file_name = Path (file_path ), function_or_method_name = function_name , class_name = class_name
280+ ),
281+ }
282+ )
283+
284+ if not functions_data :
285+ logger .info (f"No benchmark test functions found in { benchmark_module_path } " )
286+ continue
287+ # Generate the test code for this benchmark
288+ test_code = create_trace_replay_test_code (
289+ trace_file = trace_file_path .as_posix (),
290+ functions_data = functions_data ,
291+ test_framework = test_framework ,
292+ max_run_count = max_run_count ,
293+ )
294+ test_code = isort .code (test_code )
295+ output_file = get_test_file_path (
296+ test_dir = Path (output_dir ), function_name = benchmark_module_path , test_type = "replay"
297+ )
298+ # Write test code to file, parents = true
299+ output_dir .mkdir (parents = True , exist_ok = True )
300+ output_file .write_text (test_code , "utf-8" )
301+ count += 1
302+ finally :
303+ conn .close ()
293304 except Exception as e :
294305 logger .info (f"Error generating replay tests: { e } " )
295306
0 commit comments