@@ -35,6 +35,17 @@ def test_function_eligible_for_optimization() -> None:
35
35
assert len (functions_found [Path (f .name )]) == 0
36
36
37
37
38
+ # we want to trigger an error in the function discovery
39
+ function = """def test_invalid_code():"""
40
+ with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" ) as f :
41
+ f .write (function )
42
+ f .flush ()
43
+ functions_found = find_all_functions_in_file (Path (f .name ))
44
+ assert functions_found == {}
45
+
46
+
47
+
48
+
38
49
def test_find_top_level_function_or_method ():
39
50
with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" ) as f :
40
51
f .write (
@@ -82,6 +93,15 @@ def non_classmethod_function(cls, name):
82
93
).is_top_level
83
94
# needed because this will be traced with a class_name being passed
84
95
96
+ # we want to write invalid code to ensure that the function discovery does not crash
97
+ with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" ) as f :
98
+ f .write (
99
+ """def functionA():
100
+ """
101
+ )
102
+ f .flush ()
103
+ path_obj_name = Path (f .name )
104
+ assert not inspect_top_level_functions_or_methods (path_obj_name , "functionA" )
85
105
86
106
def test_class_method_discovery ():
87
107
with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" ) as f :
@@ -152,6 +172,133 @@ def functionA():
152
172
assert functions [file ][0 ].function_name == "functionA"
153
173
154
174
175
+ def test_nested_function ():
176
+ with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" ) as f :
177
+ f .write (
178
+ """
179
+ import copy
180
+
181
+ def propagate_attributes(
182
+ nodes: dict[str, dict], edges: list[dict], source_node_id: str, attribute: str
183
+ ) -> dict[str, dict]:
184
+ modified_nodes = copy.deepcopy(nodes)
185
+
186
+ # Build an adjacency list for faster traversal
187
+ adjacency = {}
188
+ for edge in edges:
189
+ src = edge["source"]
190
+ tgt = edge["target"]
191
+ if src not in adjacency:
192
+ adjacency[src] = []
193
+ adjacency[src].append(tgt)
194
+
195
+ # Track visited nodes to avoid cycles
196
+ visited = set()
197
+
198
+ def traverse(node_id):
199
+ if node_id in visited:
200
+ return
201
+ visited.add(node_id)
202
+
203
+ # Propagate attribute from source node
204
+ if (
205
+ node_id != source_node_id
206
+ and source_node_id in modified_nodes
207
+ and attribute in modified_nodes[source_node_id]
208
+ ):
209
+ if node_id in modified_nodes:
210
+ modified_nodes[node_id][attribute] = modified_nodes[source_node_id][
211
+ attribute
212
+ ]
213
+
214
+ # Continue propagation to neighbors
215
+ for neighbor in adjacency.get(node_id, []):
216
+ traverse(neighbor)
217
+
218
+ traverse(source_node_id)
219
+ return modified_nodes
220
+ """
221
+ )
222
+ f .flush ()
223
+ test_config = TestConfig (
224
+ tests_root = "tests" , project_root_path = "." , test_framework = "pytest" , tests_project_rootdir = Path ()
225
+ )
226
+ path_obj_name = Path (f .name )
227
+ functions , functions_count = get_functions_to_optimize (
228
+ optimize_all = None ,
229
+ replay_test = None ,
230
+ file = path_obj_name ,
231
+ test_cfg = test_config ,
232
+ only_get_this_function = None ,
233
+ ignore_paths = [Path ("/bruh/" )],
234
+ project_root = path_obj_name .parent ,
235
+ module_root = path_obj_name .parent ,
236
+ )
237
+
238
+ assert len (functions ) == 1
239
+ assert functions_count == 1
240
+
241
+ with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" ) as f :
242
+ f .write (
243
+ """
244
+ def outer_function():
245
+ def inner_function():
246
+ pass
247
+
248
+ return inner_function
249
+ """
250
+ )
251
+ f .flush ()
252
+ test_config = TestConfig (
253
+ tests_root = "tests" , project_root_path = "." , test_framework = "pytest" , tests_project_rootdir = Path ()
254
+ )
255
+ path_obj_name = Path (f .name )
256
+ functions , functions_count = get_functions_to_optimize (
257
+ optimize_all = None ,
258
+ replay_test = None ,
259
+ file = path_obj_name ,
260
+ test_cfg = test_config ,
261
+ only_get_this_function = None ,
262
+ ignore_paths = [Path ("/bruh/" )],
263
+ project_root = path_obj_name .parent ,
264
+ module_root = path_obj_name .parent ,
265
+ )
266
+
267
+ assert len (functions ) == 1
268
+ assert functions_count == 1
269
+
270
+ with tempfile .NamedTemporaryFile (mode = "w" , suffix = ".py" ) as f :
271
+ f .write (
272
+ """
273
+ def outer_function():
274
+ def inner_function():
275
+ pass
276
+
277
+ def another_inner_function():
278
+ pass
279
+ return inner_function, another_inner_function
280
+ """
281
+ )
282
+ f .flush ()
283
+ test_config = TestConfig (
284
+ tests_root = "tests" , project_root_path = "." , test_framework = "pytest" , tests_project_rootdir = Path ()
285
+ )
286
+ path_obj_name = Path (f .name )
287
+ functions , functions_count = get_functions_to_optimize (
288
+ optimize_all = None ,
289
+ replay_test = None ,
290
+ file = path_obj_name ,
291
+ test_cfg = test_config ,
292
+ only_get_this_function = None ,
293
+ ignore_paths = [Path ("/bruh/" )],
294
+ project_root = path_obj_name .parent ,
295
+ module_root = path_obj_name .parent ,
296
+ )
297
+
298
+ assert len (functions ) == 1
299
+ assert functions_count == 1
300
+
301
+
155
302
def test_filter_files_optimized ():
156
303
tests_root = Path ("tests" ).resolve ()
157
304
module_root = Path ().resolve ()
0 commit comments