Skip to content

Commit a388ee9

Browse files
committed
Add precautions to FD leaks2
1 parent bd05e72 commit a388ee9

File tree

2 files changed

+80
-70
lines changed

2 files changed

+80
-70
lines changed

codeflash/benchmarking/replay_test.py

Lines changed: 79 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,25 @@ def get_next_arg_and_return(
2424
num_to_get: int = 256,
2525
) -> Generator[Any]:
2626
db = sqlite3.connect(trace_file)
27-
cur = db.cursor()
28-
limit = num_to_get
29-
30-
if class_name is not None:
31-
cursor = cur.execute(
32-
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
33-
(benchmark_function_name, function_name, file_path, class_name, limit),
34-
)
35-
else:
36-
cursor = cur.execute(
37-
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
38-
(benchmark_function_name, function_name, file_path, limit),
39-
)
27+
try:
28+
cur = db.cursor()
29+
limit = num_to_get
30+
31+
if class_name is not None:
32+
cursor = cur.execute(
33+
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = ? LIMIT ?",
34+
(benchmark_function_name, function_name, file_path, class_name, limit),
35+
)
36+
else:
37+
cursor = cur.execute(
38+
"SELECT * FROM benchmark_function_timings WHERE benchmark_function_name = ? AND function_name = ? AND file_path = ? AND class_name = '' LIMIT ?",
39+
(benchmark_function_name, function_name, file_path, limit),
40+
)
4041

41-
while (val := cursor.fetchone()) is not None:
42-
yield val[9], val[10] # pickled_args, pickled_kwargs
42+
while (val := cursor.fetchone()) is not None:
43+
yield val[9], val[10] # pickled_args, pickled_kwargs
44+
finally:
45+
db.close()
4346

4447

4548
def get_function_alias(module: str, function_name: str) -> str:
@@ -235,61 +238,69 @@ def generate_replay_test(
235238
try:
236239
# Connect to the database
237240
conn = sqlite3.connect(trace_file_path.as_posix())
238-
cursor = conn.cursor()
239-
240-
# Get distinct benchmark file paths
241-
cursor.execute("SELECT DISTINCT benchmark_module_path FROM benchmark_function_timings")
242-
benchmark_files = cursor.fetchall()
243-
244-
# Generate a test for each benchmark file
245-
for benchmark_file in benchmark_files:
246-
benchmark_module_path = benchmark_file[0]
247-
# Get all benchmarks and functions associated with this file path
248-
cursor.execute(
249-
"SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings "
250-
"WHERE benchmark_module_path = ?",
251-
(benchmark_module_path,),
252-
)
253-
254-
functions_data = []
255-
for row in cursor.fetchall():
256-
benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number = row
257-
# Add this function to our list
258-
functions_data.append(
259-
{
260-
"function_name": function_name,
261-
"class_name": class_name,
262-
"file_path": file_path,
263-
"module_name": module_name,
264-
"benchmark_function_name": benchmark_function_name,
265-
"benchmark_module_path": benchmark_module_path,
266-
"benchmark_line_number": benchmark_line_number,
267-
"function_properties": inspect_top_level_functions_or_methods(
268-
file_name=Path(file_path), function_or_method_name=function_name, class_name=class_name
269-
),
270-
}
241+
try:
242+
cursor = conn.cursor()
243+
244+
# Get distinct benchmark file paths
245+
cursor.execute("SELECT DISTINCT benchmark_module_path FROM benchmark_function_timings")
246+
benchmark_files = cursor.fetchall()
247+
248+
# Generate a test for each benchmark file
249+
for benchmark_file in benchmark_files:
250+
benchmark_module_path = benchmark_file[0]
251+
# Get all benchmarks and functions associated with this file path
252+
cursor.execute(
253+
"SELECT DISTINCT benchmark_function_name, function_name, class_name, module_name, file_path, benchmark_line_number FROM benchmark_function_timings "
254+
"WHERE benchmark_module_path = ?",
255+
(benchmark_module_path,),
271256
)
272257

273-
if not functions_data:
274-
logger.info(f"No benchmark test functions found in {benchmark_module_path}")
275-
continue
276-
# Generate the test code for this benchmark
277-
test_code = create_trace_replay_test_code(
278-
trace_file=trace_file_path.as_posix(),
279-
functions_data=functions_data,
280-
test_framework=test_framework,
281-
max_run_count=max_run_count,
282-
)
283-
test_code = isort.code(test_code)
284-
output_file = get_test_file_path(
285-
test_dir=Path(output_dir), function_name=benchmark_module_path, test_type="replay"
286-
)
287-
# Write test code to file, parents = true
288-
output_dir.mkdir(parents=True, exist_ok=True)
289-
output_file.write_text(test_code, "utf-8")
290-
count += 1
291-
292-
conn.close()
258+
functions_data = []
259+
for row in cursor.fetchall():
260+
(
261+
benchmark_function_name,
262+
function_name,
263+
class_name,
264+
module_name,
265+
file_path,
266+
benchmark_line_number,
267+
) = row
268+
# Add this function to our list
269+
functions_data.append(
270+
{
271+
"function_name": function_name,
272+
"class_name": class_name,
273+
"file_path": file_path,
274+
"module_name": module_name,
275+
"benchmark_function_name": benchmark_function_name,
276+
"benchmark_module_path": benchmark_module_path,
277+
"benchmark_line_number": benchmark_line_number,
278+
"function_properties": inspect_top_level_functions_or_methods(
279+
file_name=Path(file_path), function_or_method_name=function_name, class_name=class_name
280+
),
281+
}
282+
)
283+
284+
if not functions_data:
285+
logger.info(f"No benchmark test functions found in {benchmark_module_path}")
286+
continue
287+
# Generate the test code for this benchmark
288+
test_code = create_trace_replay_test_code(
289+
trace_file=trace_file_path.as_posix(),
290+
functions_data=functions_data,
291+
test_framework=test_framework,
292+
max_run_count=max_run_count,
293+
)
294+
test_code = isort.code(test_code)
295+
output_file = get_test_file_path(
296+
test_dir=Path(output_dir), function_name=benchmark_module_path, test_type="replay"
297+
)
298+
# Write test code to file, parents = true
299+
output_dir.mkdir(parents=True, exist_ok=True)
300+
output_file.write_text(test_code, "utf-8")
301+
count += 1
302+
finally:
303+
conn.close()
293304
except Exception as e:
294305
logger.info(f"Error generating replay tests: {e}")
295306

codeflash/lsp/beta.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,7 @@ def discover_function_tests(server: CodeflashLanguageServer, params: FunctionOpt
109109
fto = server.optimizer.current_function_being_optimized
110110
optimizable_funcs = {fto.file_path: [fto]}
111111

112-
devnull_writer = open(os.devnull, "w") # noqa
113-
with contextlib.redirect_stdout(devnull_writer):
112+
with open(os.devnull, "w") as devnull_writer, contextlib.redirect_stdout(devnull_writer): # noqa: PTH123
114113
function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
115114

116115
server.optimizer.discovered_tests = function_to_tests

0 commit comments

Comments
 (0)