11import tempfile
22from pathlib import Path
3+ import os
4+ import unittest .mock
35
46from codeflash .discovery .functions_to_optimize import (
57 filter_files_optimized ,
@@ -361,6 +363,12 @@ def traverse(node_id):
361363
362364 traverse(source_node_id)
363365 return modified_nodes
366+
367+ def vanilla_function():
368+ return "This is a vanilla function."
369+
370+ def not_in_checkpoint_function():
371+ return "This function is not in the checkpoint."
364372"""
365373 )
366374 f .flush ()
@@ -380,4 +388,151 @@ def traverse(node_id):
380388 )
381389 function_names = [fn .function_name for fn in filtered .get (file_path , [])]
382390 assert "propagate_attributes" in function_names
383- assert count == 1
391+ assert count == 3
392+
393+ tests_root_dir = codeflash_temp_dir .joinpath ("tests" )
394+ tests_root_dir .mkdir (exist_ok = True )
395+
396+ test_file_path = tests_root_dir .joinpath ("test_functions.py" )
397+ with test_file_path .open ("w" ) as f :
398+ f .write (
399+ """
400+ def test_function_in_tests_dir():
401+ return "This function is in a test directory and should be filtered out."
402+ """
403+ )
404+
405+ discovered_test_file = find_all_functions_in_file (test_file_path )
406+ modified_functions_test = {test_file_path : discovered_test_file .get (test_file_path , [])}
407+
408+ filtered_test_file , count_test_file = filter_functions (
409+ modified_functions_test ,
410+ tests_root = tests_root_dir ,
411+ ignore_paths = [],
412+ project_root = codeflash_temp_dir ,
413+ module_root = codeflash_temp_dir ,
414+ )
415+
416+ assert not filtered_test_file
417+ assert count_test_file == 0
418+
419+ with codeflash_temp_dir .joinpath ("ignored_dir" ).open ("w" ) as f :
420+ f .write ("def ignored_func(): return 1" )
421+
422+ ignored_file_path = codeflash_temp_dir .joinpath ("ignored_dir" )
423+ discovered_ignored = find_all_functions_in_file (ignored_file_path )
424+ modified_functions_ignored = {ignored_file_path : discovered_ignored .get (ignored_file_path , [])}
425+
426+ filtered_ignored , count_ignored = filter_functions (
427+ modified_functions_ignored ,
428+ tests_root = Path ("tests" ),
429+ ignore_paths = [ignored_file_path .parent ],
430+ project_root = file_path .parent ,
431+ module_root = file_path .parent ,
432+ )
433+ assert not filtered_ignored
434+ assert count_ignored == 0
435+
436+ with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.ignored_submodule_paths" , return_value = [str (codeflash_temp_dir .joinpath ("submodule_dir" ))]):
437+ with codeflash_temp_dir .joinpath ("submodule_dir" ).open ("w" ) as f :
438+ f .write ("def submodule_func(): return 1" )
439+
440+ submodule_file_path = codeflash_temp_dir .joinpath ("submodule_dir" )
441+ discovered_submodule = find_all_functions_in_file (submodule_file_path )
442+ modified_functions_submodule = {submodule_file_path : discovered_submodule .get (submodule_file_path , [])}
443+
444+ filtered_submodule , count_submodule = filter_functions (
445+ modified_functions_submodule ,
446+ tests_root = Path ("tests" ),
447+ ignore_paths = [],
448+ project_root = file_path .parent ,
449+ module_root = file_path .parent ,
450+ )
451+ assert not filtered_submodule
452+ assert count_submodule == 0
453+
454+ with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.path_belongs_to_site_packages" , return_value = True ):
455+ with codeflash_temp_dir .joinpath ("site_package_file.py" ).open ("w" ) as f :
456+ f .write ("def site_package_func(): return 1" )
457+
458+ site_package_file_path = codeflash_temp_dir .joinpath ("site_package_file.py" )
459+ discovered_site_package = find_all_functions_in_file (site_package_file_path )
460+ modified_functions_site_package = {site_package_file_path : discovered_site_package .get (site_package_file_path , [])}
461+
462+ filtered_site_package , count_site_package = filter_functions (
463+ modified_functions_site_package ,
464+ tests_root = Path ("tests" ),
465+ ignore_paths = [],
466+ project_root = file_path .parent ,
467+ module_root = file_path .parent ,
468+ )
469+ assert not filtered_site_package
470+ assert count_site_package == 0
471+
472+ outside_module_root_path = codeflash_temp_dir .parent .joinpath ("outside_module_root_file.py" )
473+ with outside_module_root_path .open ("w" ) as f :
474+ f .write ("def func_outside_module_root(): return 1" )
475+
476+ discovered_outside_module = find_all_functions_in_file (outside_module_root_path )
477+ modified_functions_outside_module = {outside_module_root_path : discovered_outside_module .get (outside_module_root_path , [])}
478+
479+ filtered_outside_module , count_outside_module = filter_functions (
480+ modified_functions_outside_module ,
481+ tests_root = Path ("tests" ),
482+ ignore_paths = [],
483+ project_root = file_path .parent ,
484+ module_root = file_path .parent ,
485+ )
486+ assert not filtered_outside_module
487+ assert count_outside_module == 0
488+ os .remove (outside_module_root_path )
489+
490+ invalid_module_file_path = codeflash_temp_dir .joinpath ("invalid-module-name.py" )
491+ with invalid_module_file_path .open ("w" ) as f :
492+ f .write ("def func_in_invalid_module(): return 1" )
493+
494+ discovered_invalid_module = find_all_functions_in_file (invalid_module_file_path )
495+ modified_functions_invalid_module = {invalid_module_file_path : discovered_invalid_module .get (invalid_module_file_path , [])}
496+
497+ filtered_invalid_module , count_invalid_module = filter_functions (
498+ modified_functions_invalid_module ,
499+ tests_root = Path ("tests" ),
500+ ignore_paths = [],
501+ project_root = file_path .parent ,
502+ module_root = file_path .parent ,
503+ )
504+ assert not filtered_invalid_module
505+ assert count_invalid_module == 0
506+
507+ with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.get_blocklisted_functions" , return_value = {file_path .name : {"propagate_attributes" , "other_blocklisted_function" }}):
508+ filtered_blocklisted , count_blocklisted = filter_functions (
509+ modified_functions ,
510+ tests_root = Path ("tests" ),
511+ ignore_paths = [],
512+ project_root = file_path .parent ,
513+ module_root = file_path .parent ,
514+ )
515+ assert "propagate_attributes" in [fn .function_name for fn in filtered_blocklisted .get (file_path , [])]
516+ assert count_blocklisted == 1
517+
518+
519+ module_name = "test_get_functions_to_optimize"
520+ qualified_name_for_checkpoint = f"{ module_name } .propagate_attributes"
521+ other_qualified_name_for_checkpoint = f"{ module_name } .vanilla_function"
522+
523+ with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.get_blocklisted_functions" , return_value = {}):
524+ filtered_checkpoint , count_checkpoint = filter_functions (
525+ modified_functions ,
526+ tests_root = Path ("tests" ),
527+ ignore_paths = [],
528+ project_root = file_path .parent ,
529+ module_root = file_path .parent ,
530+ previous_checkpoint_functions = {qualified_name_for_checkpoint : {"status" : "optimized" }, other_qualified_name_for_checkpoint : {}}
531+ )
532+ assert filtered_checkpoint .get (file_path )
533+ assert count_checkpoint == 1
534+
535+ remaining_functions = [fn .function_name for fn in filtered_checkpoint .get (file_path , [])]
536+ assert "not_in_checkpoint_function" in remaining_functions
537+ assert "propagate_attributes" not in remaining_functions
538+ assert "vanilla_function" not in remaining_functions
0 commit comments