11import tempfile
22from pathlib import Path
3+ import os
4+ import unittest .mock
35
46from codeflash .discovery .functions_to_optimize import (
57 filter_files_optimized ,
68 find_all_functions_in_file ,
79 get_functions_to_optimize ,
810 inspect_top_level_functions_or_methods ,
11+ filter_functions ,
12+ get_all_files_and_functions
913)
1014from codeflash .verification .verification_utils import TestConfig
15+ from codeflash .code_utils .compat import codeflash_temp_dir
1116
1217
1318def test_function_eligible_for_optimization () -> None :
@@ -313,3 +318,241 @@ def test_filter_files_optimized():
313318 assert filter_files_optimized (file_path_same_level , tests_root , ignore_paths , module_root )
314319 assert filter_files_optimized (file_path_different_level , tests_root , ignore_paths , module_root )
315320 assert not filter_files_optimized (file_path_above_level , tests_root , ignore_paths , module_root )
321+
322+ def test_filter_functions ():
323+ with tempfile .TemporaryDirectory () as temp_dir_str :
324+ temp_dir = Path (temp_dir_str )
325+
326+ # Create a test file in the temporary directory
327+ test_file_path = temp_dir .joinpath ("test_get_functions_to_optimize.py" )
328+ with test_file_path .open ("w" ) as f :
329+ f .write (
330+ """
331+ import copy
332+
333+ def propagate_attributes(
334+ nodes: dict[str, dict], edges: list[dict], source_node_id: str, attribute: str
335+ ) -> dict[str, dict]:
336+ modified_nodes = copy.deepcopy(nodes)
337+
338+ # Build an adjacency list for faster traversal
339+ adjacency = {}
340+ for edge in edges:
341+ src = edge["source"]
342+ tgt = edge["target"]
343+ if src not in adjacency:
344+ adjacency[src] = []
345+ adjacency[src].append(tgt)
346+
347+ # Track visited nodes to avoid cycles
348+ visited = set()
349+
350+ def traverse(node_id):
351+ if node_id in visited:
352+ return
353+ visited.add(node_id)
354+
355+ # Propagate attribute from source node
356+ if (
357+ node_id != source_node_id
358+ and source_node_id in modified_nodes
359+ and attribute in modified_nodes[source_node_id]
360+ ):
361+ if node_id in modified_nodes:
362+ modified_nodes[node_id][attribute] = modified_nodes[source_node_id][
363+ attribute
364+ ]
365+
366+ # Continue propagation to neighbors
367+ for neighbor in adjacency.get(node_id, []):
368+ traverse(neighbor)
369+
370+ traverse(source_node_id)
371+ return modified_nodes
372+
373+ def vanilla_function():
374+ return "This is a vanilla function."
375+
376+ def not_in_checkpoint_function():
377+ return "This function is not in the checkpoint."
378+ """
379+ )
380+
381+
382+ discovered = find_all_functions_in_file (test_file_path )
383+ modified_functions = {test_file_path : discovered [test_file_path ]}
384+ filtered , count = filter_functions (
385+ modified_functions ,
386+ tests_root = Path ("tests" ),
387+ ignore_paths = [],
388+ project_root = temp_dir ,
389+ module_root = temp_dir ,
390+ )
391+ function_names = [fn .function_name for fn in filtered .get (test_file_path , [])]
392+ assert "propagate_attributes" in function_names
393+ assert count == 3
394+
395+ # Create a tests directory inside our temp directory
396+ tests_root_dir = temp_dir .joinpath ("tests" )
397+ tests_root_dir .mkdir (exist_ok = True )
398+
399+ test_file_path = tests_root_dir .joinpath ("test_functions.py" )
400+ with test_file_path .open ("w" ) as f :
401+ f .write (
402+ """
403+ def test_function_in_tests_dir():
404+ return "This function is in a test directory and should be filtered out."
405+ """
406+ )
407+
408+ discovered_test_file = find_all_functions_in_file (test_file_path )
409+ modified_functions_test = {test_file_path : discovered_test_file .get (test_file_path , [])}
410+
411+ filtered_test_file , count_test_file = filter_functions (
412+ modified_functions_test ,
413+ tests_root = tests_root_dir ,
414+ ignore_paths = [],
415+ project_root = temp_dir ,
416+ module_root = temp_dir ,
417+ )
418+
419+ assert not filtered_test_file
420+ assert count_test_file == 0
421+
422+ # Test ignored directory
423+ ignored_dir = temp_dir .joinpath ("ignored_dir" )
424+ ignored_dir .mkdir (exist_ok = True )
425+ ignored_file_path = ignored_dir .joinpath ("ignored_file.py" )
426+ with ignored_file_path .open ("w" ) as f :
427+ f .write ("def ignored_func(): return 1" )
428+
429+ discovered_ignored = find_all_functions_in_file (ignored_file_path )
430+ modified_functions_ignored = {ignored_file_path : discovered_ignored .get (ignored_file_path , [])}
431+
432+ filtered_ignored , count_ignored = filter_functions (
433+ modified_functions_ignored ,
434+ tests_root = Path ("tests" ),
435+ ignore_paths = [ignored_dir ],
436+ project_root = temp_dir ,
437+ module_root = temp_dir ,
438+ )
439+ assert not filtered_ignored
440+ assert count_ignored == 0
441+
442+ # Test submodule paths
443+ with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.ignored_submodule_paths" ,
444+ return_value = [str (temp_dir .joinpath ("submodule_dir" ))]):
445+ submodule_dir = temp_dir .joinpath ("submodule_dir" )
446+ submodule_dir .mkdir (exist_ok = True )
447+ submodule_file_path = submodule_dir .joinpath ("submodule_file.py" )
448+ with submodule_file_path .open ("w" ) as f :
449+ f .write ("def submodule_func(): return 1" )
450+
451+ discovered_submodule = find_all_functions_in_file (submodule_file_path )
452+ modified_functions_submodule = {submodule_file_path : discovered_submodule .get (submodule_file_path , [])}
453+
454+ filtered_submodule , count_submodule = filter_functions (
455+ modified_functions_submodule ,
456+ tests_root = Path ("tests" ),
457+ ignore_paths = [],
458+ project_root = temp_dir ,
459+ module_root = temp_dir ,
460+ )
461+ assert not filtered_submodule
462+ assert count_submodule == 0
463+
464+ # Test site packages
465+ with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.path_belongs_to_site_packages" ,
466+ return_value = True ):
467+ site_package_file_path = temp_dir .joinpath ("site_package_file.py" )
468+ with site_package_file_path .open ("w" ) as f :
469+ f .write ("def site_package_func(): return 1" )
470+
471+ discovered_site_package = find_all_functions_in_file (site_package_file_path )
472+ modified_functions_site_package = {site_package_file_path : discovered_site_package .get (site_package_file_path , [])}
473+
474+ filtered_site_package , count_site_package = filter_functions (
475+ modified_functions_site_package ,
476+ tests_root = Path ("tests" ),
477+ ignore_paths = [],
478+ project_root = temp_dir ,
479+ module_root = temp_dir ,
480+ )
481+ assert not filtered_site_package
482+ assert count_site_package == 0
483+
484+ # Test outside module root
485+ parent_dir = temp_dir .parent
486+ outside_module_root_path = parent_dir .joinpath ("outside_module_root_file.py" )
487+ try :
488+ with outside_module_root_path .open ("w" ) as f :
489+ f .write ("def func_outside_module_root(): return 1" )
490+
491+ discovered_outside_module = find_all_functions_in_file (outside_module_root_path )
492+ modified_functions_outside_module = {outside_module_root_path : discovered_outside_module .get (outside_module_root_path , [])}
493+
494+ filtered_outside_module , count_outside_module = filter_functions (
495+ modified_functions_outside_module ,
496+ tests_root = Path ("tests" ),
497+ ignore_paths = [],
498+ project_root = temp_dir ,
499+ module_root = temp_dir ,
500+ )
501+ assert not filtered_outside_module
502+ assert count_outside_module == 0
503+ finally :
504+ outside_module_root_path .unlink (missing_ok = True )
505+
506+ # Test invalid module name
507+ invalid_module_file_path = temp_dir .joinpath ("invalid-module-name.py" )
508+ with invalid_module_file_path .open ("w" ) as f :
509+ f .write ("def func_in_invalid_module(): return 1" )
510+
511+ discovered_invalid_module = find_all_functions_in_file (invalid_module_file_path )
512+ modified_functions_invalid_module = {invalid_module_file_path : discovered_invalid_module .get (invalid_module_file_path , [])}
513+
514+ filtered_invalid_module , count_invalid_module = filter_functions (
515+ modified_functions_invalid_module ,
516+ tests_root = Path ("tests" ),
517+ ignore_paths = [],
518+ project_root = temp_dir ,
519+ module_root = temp_dir ,
520+ )
521+ assert not filtered_invalid_module
522+ assert count_invalid_module == 0
523+
524+ original_file_path = temp_dir .joinpath ("test_get_functions_to_optimize.py" )
525+ with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.get_blocklisted_functions" ,
526+ return_value = {original_file_path .name : {"propagate_attributes" , "other_blocklisted_function" }}):
527+ filtered_funcs , count = filter_functions (
528+ modified_functions ,
529+ tests_root = Path ("tests" ),
530+ ignore_paths = [],
531+ project_root = temp_dir ,
532+ module_root = temp_dir ,
533+ )
534+ assert "propagate_attributes" not in [fn .function_name for fn in filtered_funcs .get (original_file_path , [])]
535+ assert count == 2
536+
537+ module_name = "test_get_functions_to_optimize"
538+ qualified_name_for_checkpoint = f"{ module_name } .propagate_attributes"
539+ other_qualified_name_for_checkpoint = f"{ module_name } .vanilla_function"
540+
541+ with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.get_blocklisted_functions" , return_value = {}):
542+ filtered_checkpoint , count_checkpoint = filter_functions (
543+ modified_functions ,
544+ tests_root = Path ("tests" ),
545+ ignore_paths = [],
546+ project_root = temp_dir ,
547+ module_root = temp_dir ,
548+ previous_checkpoint_functions = {qualified_name_for_checkpoint : {"status" : "optimized" }, other_qualified_name_for_checkpoint : {}}
549+ )
550+ assert filtered_checkpoint .get (original_file_path )
551+ assert count_checkpoint == 1
552+
553+ remaining_functions = [fn .function_name for fn in filtered_checkpoint .get (original_file_path , [])]
554+ assert "not_in_checkpoint_function" in remaining_functions
555+ assert "propagate_attributes" not in remaining_functions
556+ assert "vanilla_function" not in remaining_functions
557+ files_and_funcs = get_all_files_and_functions (module_root_path = temp_dir )
558+ assert len (files_and_funcs ) == 6
0 commit comments