Skip to content

Commit f7d8d6a

Browse files
committed
conn & windows
1 parent c5d73f8 commit f7d8d6a

File tree

2 files changed

+110
-78
lines changed

2 files changed

+110
-78
lines changed

tests/test_function_discovery.py

Lines changed: 105 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -21,40 +21,56 @@ def test_function_eligible_for_optimization() -> None:
2121
return a**2
2222
"""
2323
functions_found = {}
24-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
25-
f.write(function)
26-
f.flush()
27-
functions_found = find_all_functions_in_file(Path(f.name))
28-
assert functions_found[Path(f.name)][0].function_name == "test_function_eligible_for_optimization"
24+
with tempfile.TemporaryDirectory() as temp_dir:
25+
temp_dir_path = Path(temp_dir)
26+
file_path = temp_dir_path / "test_function.py"
27+
28+
with file_path.open("w") as f:
29+
f.write(function)
30+
31+
functions_found = find_all_functions_in_file(file_path)
32+
assert functions_found[file_path][0].function_name == "test_function_eligible_for_optimization"
2933

3034
# Has no return statement
3135
function = """def test_function_not_eligible_for_optimization():
3236
a = 5
3337
print(a)
3438
"""
3539
functions_found = {}
36-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
37-
f.write(function)
38-
f.flush()
39-
functions_found = find_all_functions_in_file(Path(f.name))
40-
assert len(functions_found[Path(f.name)]) == 0
40+
with tempfile.TemporaryDirectory() as temp_dir:
41+
temp_dir_path = Path(temp_dir)
42+
file_path = temp_dir_path / "test_function.py"
43+
44+
with file_path.open("w") as f:
45+
f.write(function)
46+
47+
functions_found = find_all_functions_in_file(file_path)
48+
assert len(functions_found[file_path]) == 0
4149

4250

4351
# we want to trigger an error in the function discovery
4452
function = """def test_invalid_code():"""
45-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
46-
f.write(function)
47-
f.flush()
48-
functions_found = find_all_functions_in_file(Path(f.name))
53+
with tempfile.TemporaryDirectory() as temp_dir:
54+
temp_dir_path = Path(temp_dir)
55+
file_path = temp_dir_path / "test_function.py"
56+
57+
with file_path.open("w") as f:
58+
f.write(function)
59+
60+
functions_found = find_all_functions_in_file(file_path)
4961
assert functions_found == {}
5062

5163

5264

5365

5466
def test_find_top_level_function_or_method():
55-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
56-
f.write(
57-
"""def functionA():
67+
with tempfile.TemporaryDirectory() as temp_dir:
68+
temp_dir_path = Path(temp_dir)
69+
file_path = temp_dir_path / "test_function.py"
70+
71+
with file_path.open("w") as f:
72+
f.write(
73+
"""def functionA():
5874
def functionB():
5975
return 5
6076
class E:
@@ -76,42 +92,48 @@ def functionE(cls, num):
7692
def non_classmethod_function(cls, name):
7793
return cls.name
7894
"""
79-
)
80-
f.flush()
81-
path_obj_name = Path(f.name)
82-
assert inspect_top_level_functions_or_methods(path_obj_name, "functionA").is_top_level
83-
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionB").is_top_level
84-
assert inspect_top_level_functions_or_methods(path_obj_name, "functionC", class_name="A").is_top_level
85-
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionD", class_name="A").is_top_level
86-
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionF", class_name="E").is_top_level
87-
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionA").has_args
95+
)
96+
97+
assert inspect_top_level_functions_or_methods(file_path, "functionA").is_top_level
98+
assert not inspect_top_level_functions_or_methods(file_path, "functionB").is_top_level
99+
assert inspect_top_level_functions_or_methods(file_path, "functionC", class_name="A").is_top_level
100+
assert not inspect_top_level_functions_or_methods(file_path, "functionD", class_name="A").is_top_level
101+
assert not inspect_top_level_functions_or_methods(file_path, "functionF", class_name="E").is_top_level
102+
assert not inspect_top_level_functions_or_methods(file_path, "functionA").has_args
88103
staticmethod_func = inspect_top_level_functions_or_methods(
89-
path_obj_name, "handle_record_counts", class_name=None, line_no=15
104+
file_path, "handle_record_counts", class_name=None, line_no=15
90105
)
91106
assert staticmethod_func.is_staticmethod
92107
assert staticmethod_func.staticmethod_class_name == "AirbyteEntrypoint"
93108
assert inspect_top_level_functions_or_methods(
94-
path_obj_name, "functionE", class_name="AirbyteEntrypoint"
109+
file_path, "functionE", class_name="AirbyteEntrypoint"
95110
).is_classmethod
96111
assert not inspect_top_level_functions_or_methods(
97-
path_obj_name, "non_classmethod_function", class_name="AirbyteEntrypoint"
112+
file_path, "non_classmethod_function", class_name="AirbyteEntrypoint"
98113
).is_top_level
99114
# needed because this will be traced with a class_name being passed
100115

101116
# we want to write invalid code to ensure that the function discovery does not crash
102-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
103-
f.write(
104-
"""def functionA():
117+
with tempfile.TemporaryDirectory() as temp_dir:
118+
temp_dir_path = Path(temp_dir)
119+
file_path = temp_dir_path / "test_function.py"
120+
121+
with file_path.open("w") as f:
122+
f.write(
123+
"""def functionA():
105124
"""
106-
)
107-
f.flush()
108-
path_obj_name = Path(f.name)
109-
assert not inspect_top_level_functions_or_methods(path_obj_name, "functionA")
125+
)
126+
127+
assert not inspect_top_level_functions_or_methods(file_path, "functionA")
110128

111129
def test_class_method_discovery():
112-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
113-
f.write(
114-
"""class A:
130+
with tempfile.TemporaryDirectory() as temp_dir:
131+
temp_dir_path = Path(temp_dir)
132+
file_path = temp_dir_path / "test_function.py"
133+
134+
with file_path.open("w") as f:
135+
f.write(
136+
"""class A:
115137
def functionA():
116138
return True
117139
def functionB():
@@ -123,21 +145,20 @@ def functionB():
123145
return False
124146
def functionA():
125147
return True"""
126-
)
127-
f.flush()
148+
)
149+
128150
test_config = TestConfig(
129151
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
130152
)
131-
path_obj_name = Path(f.name)
132153
functions, functions_count, _ = get_functions_to_optimize(
133154
optimize_all=None,
134155
replay_test=None,
135-
file=path_obj_name,
156+
file=file_path,
136157
only_get_this_function="A.functionA",
137158
test_cfg=test_config,
138159
ignore_paths=[Path("/bruh/")],
139-
project_root=path_obj_name.parent,
140-
module_root=path_obj_name.parent,
160+
project_root=file_path.parent,
161+
module_root=file_path.parent,
141162
)
142163
assert len(functions) == 1
143164
for file in functions:
@@ -148,12 +169,12 @@ def functionA():
148169
functions, functions_count, _ = get_functions_to_optimize(
149170
optimize_all=None,
150171
replay_test=None,
151-
file=path_obj_name,
172+
file=file_path,
152173
only_get_this_function="X.functionA",
153174
test_cfg=test_config,
154175
ignore_paths=[Path("/bruh/")],
155-
project_root=path_obj_name.parent,
156-
module_root=path_obj_name.parent,
176+
project_root=file_path.parent,
177+
module_root=file_path.parent,
157178
)
158179
assert len(functions) == 1
159180
for file in functions:
@@ -164,12 +185,12 @@ def functionA():
164185
functions, functions_count, _ = get_functions_to_optimize(
165186
optimize_all=None,
166187
replay_test=None,
167-
file=path_obj_name,
188+
file=file_path,
168189
only_get_this_function="functionA",
169190
test_cfg=test_config,
170191
ignore_paths=[Path("/bruh/")],
171-
project_root=path_obj_name.parent,
172-
module_root=path_obj_name.parent,
192+
project_root=file_path.parent,
193+
module_root=file_path.parent,
173194
)
174195
assert len(functions) == 1
175196
for file in functions:
@@ -178,8 +199,12 @@ def functionA():
178199

179200

180201
def test_nested_function():
181-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
182-
f.write(
202+
with tempfile.TemporaryDirectory() as temp_dir:
203+
temp_dir_path = Path(temp_dir)
204+
file_path = temp_dir_path / "test_function.py"
205+
206+
with file_path.open("w") as f:
207+
f.write(
183208
"""
184209
import copy
185210
@@ -223,57 +248,63 @@ def traverse(node_id):
223248
traverse(source_node_id)
224249
return modified_nodes
225250
"""
226-
)
227-
f.flush()
251+
)
252+
228253
test_config = TestConfig(
229254
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
230255
)
231-
path_obj_name = Path(f.name)
232256
functions, functions_count, _ = get_functions_to_optimize(
233257
optimize_all=None,
234258
replay_test=None,
235-
file=path_obj_name,
259+
file=file_path,
236260
test_cfg=test_config,
237261
only_get_this_function=None,
238262
ignore_paths=[Path("/bruh/")],
239-
project_root=path_obj_name.parent,
240-
module_root=path_obj_name.parent,
263+
project_root=file_path.parent,
264+
module_root=file_path.parent,
241265
)
242266

243267
assert len(functions) == 1
244268
assert functions_count == 1
245269

246-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
247-
f.write(
270+
with tempfile.TemporaryDirectory() as temp_dir:
271+
temp_dir_path = Path(temp_dir)
272+
file_path = temp_dir_path / "test_function.py"
273+
274+
with file_path.open("w") as f:
275+
f.write(
248276
"""
249277
def outer_function():
250278
def inner_function():
251279
pass
252280
253281
return inner_function
254282
"""
255-
)
256-
f.flush()
283+
)
284+
257285
test_config = TestConfig(
258286
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
259287
)
260-
path_obj_name = Path(f.name)
261288
functions, functions_count, _ = get_functions_to_optimize(
262289
optimize_all=None,
263290
replay_test=None,
264-
file=path_obj_name,
291+
file=file_path,
265292
test_cfg=test_config,
266293
only_get_this_function=None,
267294
ignore_paths=[Path("/bruh/")],
268-
project_root=path_obj_name.parent,
269-
module_root=path_obj_name.parent,
295+
project_root=file_path.parent,
296+
module_root=file_path.parent,
270297
)
271298

272299
assert len(functions) == 1
273300
assert functions_count == 1
274301

275-
with tempfile.NamedTemporaryFile(mode="w", suffix=".py") as f:
276-
f.write(
302+
with tempfile.TemporaryDirectory() as temp_dir:
303+
temp_dir_path = Path(temp_dir)
304+
file_path = temp_dir_path / "test_function.py"
305+
306+
with file_path.open("w") as f:
307+
f.write(
277308
"""
278309
def outer_function():
279310
def inner_function():
@@ -283,21 +314,20 @@ def another_inner_function():
283314
pass
284315
return inner_function, another_inner_function
285316
"""
286-
)
287-
f.flush()
317+
)
318+
288319
test_config = TestConfig(
289320
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
290321
)
291-
path_obj_name = Path(f.name)
292322
functions, functions_count, _ = get_functions_to_optimize(
293323
optimize_all=None,
294324
replay_test=None,
295-
file=path_obj_name,
325+
file=file_path,
296326
test_cfg=test_config,
297327
only_get_this_function=None,
298328
ignore_paths=[Path("/bruh/")],
299-
project_root=path_obj_name.parent,
300-
module_root=path_obj_name.parent,
329+
project_root=file_path.parent,
330+
module_root=file_path.parent,
301331
)
302332

303333
assert len(functions) == 1

tests/test_trace_benchmarks.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ def test_trace_multithreaded_benchmark() -> None:
196196
"SELECT function_name, class_name, module_name, file_path, benchmark_function_name, benchmark_module_path, benchmark_line_number FROM benchmark_function_timings ORDER BY benchmark_module_path, benchmark_function_name, function_name")
197197
function_calls = cursor.fetchall()
198198

199+
conn.close()
200+
199201
# Assert the length of function calls
200202
assert len(function_calls) == 10, f"Expected 10 function calls, but got {len(function_calls)}"
201203
function_benchmark_timings = codeflash_benchmark_plugin.get_function_benchmark_timings(output_file)
@@ -204,9 +206,9 @@ def test_trace_multithreaded_benchmark() -> None:
204206
assert "code_to_optimize.bubble_sort_codeflash_trace.sorter" in function_to_results
205207

206208
test_name, total_time, function_time, percent = function_to_results["code_to_optimize.bubble_sort_codeflash_trace.sorter"][0]
207-
assert total_time > 0.0
208-
assert function_time > 0.0
209-
assert percent > 0.0
209+
assert total_time >= 0.0
210+
assert function_time >= 0.0
211+
assert percent >= 0.0
210212

211213
bubble_sort_path = (project_root / "bubble_sort_codeflash_trace.py").as_posix()
212214
# Expected function calls

0 commit comments

Comments
 (0)