66 find_all_functions_in_file ,
77 get_functions_to_optimize ,
88 inspect_top_level_functions_or_methods ,
9+ filter_functions
910)
1011from codeflash .verification .verification_utils import TestConfig
12+ from codeflash .code_utils .compat import codeflash_temp_dir
1113
1214
1315def test_function_eligible_for_optimization () -> None :
@@ -313,3 +315,69 @@ def test_filter_files_optimized():
313315 assert filter_files_optimized (file_path_same_level , tests_root , ignore_paths , module_root )
314316 assert filter_files_optimized (file_path_different_level , tests_root , ignore_paths , module_root )
315317 assert not filter_files_optimized (file_path_above_level , tests_root , ignore_paths , module_root )
318+
319+ def test_filter_functions ():
320+ with codeflash_temp_dir .joinpath ("test_get_functions_to_optimize.py" ).open ("w" ) as f :
321+ f .write (
322+ """
323+ import copy
324+
325+ def propagate_attributes(
326+ nodes: dict[str, dict], edges: list[dict], source_node_id: str, attribute: str
327+ ) -> dict[str, dict]:
328+ modified_nodes = copy.deepcopy(nodes)
329+
330+ # Build an adjacency list for faster traversal
331+ adjacency = {}
332+ for edge in edges:
333+ src = edge["source"]
334+ tgt = edge["target"]
335+ if src not in adjacency:
336+ adjacency[src] = []
337+ adjacency[src].append(tgt)
338+
339+ # Track visited nodes to avoid cycles
340+ visited = set()
341+
342+ def traverse(node_id):
343+ if node_id in visited:
344+ return
345+ visited.add(node_id)
346+
347+ # Propagate attribute from source node
348+ if (
349+ node_id != source_node_id
350+ and source_node_id in modified_nodes
351+ and attribute in modified_nodes[source_node_id]
352+ ):
353+ if node_id in modified_nodes:
354+ modified_nodes[node_id][attribute] = modified_nodes[source_node_id][
355+ attribute
356+ ]
357+
358+ # Continue propagation to neighbors
359+ for neighbor in adjacency.get(node_id, []):
360+ traverse(neighbor)
361+
362+ traverse(source_node_id)
363+ return modified_nodes
364+ """
365+ )
366+ f .flush ()
367+ test_config = TestConfig (
368+ tests_root = "tests" , project_root_path = "." , test_framework = "pytest" , tests_project_rootdir = Path ()
369+ )
370+
371+ file_path = codeflash_temp_dir .joinpath ("test_get_functions_to_optimize.py" )
372+ discovered = find_all_functions_in_file (file_path )
373+ modified_functions = {file_path : discovered [file_path ]}
374+ filtered , count = filter_functions (
375+ modified_functions ,
376+ tests_root = Path ("tests" ),
377+ ignore_paths = [],
378+ project_root = file_path .parent ,
379+ module_root = file_path .parent ,
380+ )
381+ function_names = [fn .function_name for fn in filtered .get (file_path , [])]
382+ assert "propagate_attributes" in function_names
383+ assert count == 1
0 commit comments