88 find_all_functions_in_file ,
99 get_functions_to_optimize ,
1010 inspect_top_level_functions_or_methods ,
11- filter_functions
11+ filter_functions ,
12+ get_all_files_and_functions
1213)
1314from codeflash .verification .verification_utils import TestConfig
1415from codeflash .code_utils .compat import codeflash_temp_dir
@@ -319,8 +320,13 @@ def test_filter_files_optimized():
319320 assert not filter_files_optimized (file_path_above_level , tests_root , ignore_paths , module_root )
320321
321322def test_filter_functions ():
322- with codeflash_temp_dir .joinpath ("test_get_functions_to_optimize.py" ).open ("w" ) as f :
323- f .write (
323+ with tempfile .TemporaryDirectory () as temp_dir_str :
324+ temp_dir = Path (temp_dir_str )
325+
326+ # Create a test file in the temporary directory
327+ test_file_path = temp_dir .joinpath ("test_get_functions_to_optimize.py" )
328+ with test_file_path .open ("w" ) as f :
329+ f .write (
324330"""
325331import copy
326332
@@ -370,169 +376,183 @@ def vanilla_function():
370376def not_in_checkpoint_function():
371377 return "This function is not in the checkpoint."
372378"""
373- )
374- f .flush ()
375- test_config = TestConfig (
376- tests_root = "tests" , project_root_path = "." , test_framework = "pytest" , tests_project_rootdir = Path ()
377- )
379+ )
380+
378381
379- file_path = codeflash_temp_dir .joinpath ("test_get_functions_to_optimize.py" )
380- discovered = find_all_functions_in_file (file_path )
381- modified_functions = {file_path : discovered [file_path ]}
382+ discovered = find_all_functions_in_file (test_file_path )
383+ modified_functions = {test_file_path : discovered [test_file_path ]}
382384 filtered , count = filter_functions (
383385 modified_functions ,
384386 tests_root = Path ("tests" ),
385387 ignore_paths = [],
386- project_root = file_path . parent ,
387- module_root = file_path . parent ,
388+ project_root = temp_dir ,
389+ module_root = temp_dir ,
388390 )
389- function_names = [fn .function_name for fn in filtered .get (file_path , [])]
391+ function_names = [fn .function_name for fn in filtered .get (test_file_path , [])]
390392 assert "propagate_attributes" in function_names
391393 assert count == 3
392394
393- tests_root_dir = codeflash_temp_dir .joinpath ("tests" )
394- tests_root_dir .mkdir (exist_ok = True )
395-
396- test_file_path = tests_root_dir .joinpath ("test_functions.py" )
397- with test_file_path .open ("w" ) as f :
398- f .write (
395+ # Create a tests directory inside our temp directory
396+ tests_root_dir = temp_dir .joinpath ("tests" )
397+ tests_root_dir .mkdir (exist_ok = True )
398+
399+ test_file_path = tests_root_dir .joinpath ("test_functions.py" )
400+ with test_file_path .open ("w" ) as f :
401+ f .write (
399402"""
400403def test_function_in_tests_dir():
401404 return "This function is in a test directory and should be filtered out."
402405"""
403- )
404-
405- discovered_test_file = find_all_functions_in_file (test_file_path )
406- modified_functions_test = {test_file_path : discovered_test_file .get (test_file_path , [])}
407-
408- filtered_test_file , count_test_file = filter_functions (
409- modified_functions_test ,
410- tests_root = tests_root_dir ,
411- ignore_paths = [],
412- project_root = codeflash_temp_dir ,
413- module_root = codeflash_temp_dir ,
414- )
415-
416- assert not filtered_test_file
417- assert count_test_file == 0
418-
419- with codeflash_temp_dir .joinpath ("ignored_dir" ).open ("w" ) as f :
420- f .write ("def ignored_func(): return 1" )
421-
422- ignored_file_path = codeflash_temp_dir .joinpath ("ignored_dir" )
423- discovered_ignored = find_all_functions_in_file (ignored_file_path )
424- modified_functions_ignored = {ignored_file_path : discovered_ignored .get (ignored_file_path , [])}
425-
426- filtered_ignored , count_ignored = filter_functions (
427- modified_functions_ignored ,
428- tests_root = Path ("tests" ),
429- ignore_paths = [ignored_file_path .parent ],
430- project_root = file_path .parent ,
431- module_root = file_path .parent ,
432- )
433- assert not filtered_ignored
434- assert count_ignored == 0
435-
436- with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.ignored_submodule_paths" , return_value = [str (codeflash_temp_dir .joinpath ("submodule_dir" ))]):
437- with codeflash_temp_dir .joinpath ("submodule_dir" ).open ("w" ) as f :
438- f .write ("def submodule_func(): return 1" )
406+ )
439407
440- submodule_file_path = codeflash_temp_dir .joinpath ("submodule_dir" )
441- discovered_submodule = find_all_functions_in_file (submodule_file_path )
442- modified_functions_submodule = {submodule_file_path : discovered_submodule .get (submodule_file_path , [])}
408+ discovered_test_file = find_all_functions_in_file (test_file_path )
409+ modified_functions_test = {test_file_path : discovered_test_file .get (test_file_path , [])}
443410
444- filtered_submodule , count_submodule = filter_functions (
445- modified_functions_submodule ,
446- tests_root = Path ( "tests" ) ,
411+ filtered_test_file , count_test_file = filter_functions (
412+ modified_functions_test ,
413+ tests_root = tests_root_dir ,
447414 ignore_paths = [],
448- project_root = file_path . parent ,
449- module_root = file_path . parent ,
415+ project_root = temp_dir ,
416+ module_root = temp_dir ,
450417 )
451- assert not filtered_submodule
452- assert count_submodule == 0
453-
454- with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.path_belongs_to_site_packages" , return_value = True ):
455- with codeflash_temp_dir .joinpath ("site_package_file.py" ).open ("w" ) as f :
456- f .write ("def site_package_func(): return 1" )
457-
458- site_package_file_path = codeflash_temp_dir .joinpath ("site_package_file.py" )
459- discovered_site_package = find_all_functions_in_file (site_package_file_path )
460- modified_functions_site_package = {site_package_file_path : discovered_site_package .get (site_package_file_path , [])}
418+
419+ assert not filtered_test_file
420+ assert count_test_file == 0
421+
422+ # Test ignored directory
423+ ignored_dir = temp_dir .joinpath ("ignored_dir" )
424+ ignored_dir .mkdir (exist_ok = True )
425+ ignored_file_path = ignored_dir .joinpath ("ignored_file.py" )
426+ with ignored_file_path .open ("w" ) as f :
427+ f .write ("def ignored_func(): return 1" )
428+
429+ discovered_ignored = find_all_functions_in_file (ignored_file_path )
430+ modified_functions_ignored = {ignored_file_path : discovered_ignored .get (ignored_file_path , [])}
461431
462- filtered_site_package , count_site_package = filter_functions (
463- modified_functions_site_package ,
432+ filtered_ignored , count_ignored = filter_functions (
433+ modified_functions_ignored ,
464434 tests_root = Path ("tests" ),
465- ignore_paths = [],
466- project_root = file_path . parent ,
467- module_root = file_path . parent ,
435+ ignore_paths = [ignored_dir ],
436+ project_root = temp_dir ,
437+ module_root = temp_dir ,
468438 )
469- assert not filtered_site_package
470- assert count_site_package == 0
471-
472- outside_module_root_path = codeflash_temp_dir .parent .joinpath ("outside_module_root_file.py" )
473- with outside_module_root_path .open ("w" ) as f :
474- f .write ("def func_outside_module_root(): return 1" )
475-
476- discovered_outside_module = find_all_functions_in_file (outside_module_root_path )
477- modified_functions_outside_module = {outside_module_root_path : discovered_outside_module .get (outside_module_root_path , [])}
478-
479- filtered_outside_module , count_outside_module = filter_functions (
480- modified_functions_outside_module ,
481- tests_root = Path ("tests" ),
482- ignore_paths = [],
483- project_root = file_path .parent ,
484- module_root = file_path .parent ,
485- )
486- assert not filtered_outside_module
487- assert count_outside_module == 0
488- os .remove (outside_module_root_path )
489-
490- invalid_module_file_path = codeflash_temp_dir .joinpath ("invalid-module-name.py" )
491- with invalid_module_file_path .open ("w" ) as f :
492- f .write ("def func_in_invalid_module(): return 1" )
493-
494- discovered_invalid_module = find_all_functions_in_file (invalid_module_file_path )
495- modified_functions_invalid_module = {invalid_module_file_path : discovered_invalid_module .get (invalid_module_file_path , [])}
496-
497- filtered_invalid_module , count_invalid_module = filter_functions (
498- modified_functions_invalid_module ,
499- tests_root = Path ("tests" ),
500- ignore_paths = [],
501- project_root = file_path .parent ,
502- module_root = file_path .parent ,
503- )
504- assert not filtered_invalid_module
505- assert count_invalid_module == 0
506-
507- with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.get_blocklisted_functions" , return_value = {file_path .name : {"propagate_attributes" , "other_blocklisted_function" }}):
508- filtered_funcs , count = filter_functions (
509- modified_functions ,
510- tests_root = Path ("tests" ),
511- ignore_paths = [],
512- project_root = file_path .parent ,
513- module_root = file_path .parent ,
514- )
515- assert "propagate_attributes" not in [fn .function_name for fn in filtered_funcs .get (file_path , [])]
516- assert count == 2
517-
518-
519- module_name = "test_get_functions_to_optimize"
520- qualified_name_for_checkpoint = f"{ module_name } .propagate_attributes"
521- other_qualified_name_for_checkpoint = f"{ module_name } .vanilla_function"
522-
523- with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.get_blocklisted_functions" , return_value = {}):
524- filtered_checkpoint , count_checkpoint = filter_functions (
525- modified_functions ,
439+ assert not filtered_ignored
440+ assert count_ignored == 0
441+
442+ # Test submodule paths
443+ with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.ignored_submodule_paths" ,
444+ return_value = [str (temp_dir .joinpath ("submodule_dir" ))]):
445+ submodule_dir = temp_dir .joinpath ("submodule_dir" )
446+ submodule_dir .mkdir (exist_ok = True )
447+ submodule_file_path = submodule_dir .joinpath ("submodule_file.py" )
448+ with submodule_file_path .open ("w" ) as f :
449+ f .write ("def submodule_func(): return 1" )
450+
451+ discovered_submodule = find_all_functions_in_file (submodule_file_path )
452+ modified_functions_submodule = {submodule_file_path : discovered_submodule .get (submodule_file_path , [])}
453+
454+ filtered_submodule , count_submodule = filter_functions (
455+ modified_functions_submodule ,
456+ tests_root = Path ("tests" ),
457+ ignore_paths = [],
458+ project_root = temp_dir ,
459+ module_root = temp_dir ,
460+ )
461+ assert not filtered_submodule
462+ assert count_submodule == 0
463+
464+ # Test site packages
465+ with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.path_belongs_to_site_packages" ,
466+ return_value = True ):
467+ site_package_file_path = temp_dir .joinpath ("site_package_file.py" )
468+ with site_package_file_path .open ("w" ) as f :
469+ f .write ("def site_package_func(): return 1" )
470+
471+ discovered_site_package = find_all_functions_in_file (site_package_file_path )
472+ modified_functions_site_package = {site_package_file_path : discovered_site_package .get (site_package_file_path , [])}
473+
474+ filtered_site_package , count_site_package = filter_functions (
475+ modified_functions_site_package ,
476+ tests_root = Path ("tests" ),
477+ ignore_paths = [],
478+ project_root = temp_dir ,
479+ module_root = temp_dir ,
480+ )
481+ assert not filtered_site_package
482+ assert count_site_package == 0
483+
484+ # Test outside module root
485+ parent_dir = temp_dir .parent
486+ outside_module_root_path = parent_dir .joinpath ("outside_module_root_file.py" )
487+ try :
488+ with outside_module_root_path .open ("w" ) as f :
489+ f .write ("def func_outside_module_root(): return 1" )
490+
491+ discovered_outside_module = find_all_functions_in_file (outside_module_root_path )
492+ modified_functions_outside_module = {outside_module_root_path : discovered_outside_module .get (outside_module_root_path , [])}
493+
494+ filtered_outside_module , count_outside_module = filter_functions (
495+ modified_functions_outside_module ,
496+ tests_root = Path ("tests" ),
497+ ignore_paths = [],
498+ project_root = temp_dir ,
499+ module_root = temp_dir ,
500+ )
501+ assert not filtered_outside_module
502+ assert count_outside_module == 0
503+ finally :
504+ outside_module_root_path .unlink (missing_ok = True )
505+
506+ # Test invalid module name
507+ invalid_module_file_path = temp_dir .joinpath ("invalid-module-name.py" )
508+ with invalid_module_file_path .open ("w" ) as f :
509+ f .write ("def func_in_invalid_module(): return 1" )
510+
511+ discovered_invalid_module = find_all_functions_in_file (invalid_module_file_path )
512+ modified_functions_invalid_module = {invalid_module_file_path : discovered_invalid_module .get (invalid_module_file_path , [])}
513+
514+ filtered_invalid_module , count_invalid_module = filter_functions (
515+ modified_functions_invalid_module ,
526516 tests_root = Path ("tests" ),
527517 ignore_paths = [],
528- project_root = file_path .parent ,
529- module_root = file_path .parent ,
530- previous_checkpoint_functions = {qualified_name_for_checkpoint : {"status" : "optimized" }, other_qualified_name_for_checkpoint : {}}
518+ project_root = temp_dir ,
519+ module_root = temp_dir ,
531520 )
532- assert filtered_checkpoint .get (file_path )
533- assert count_checkpoint == 1
534-
535- remaining_functions = [fn .function_name for fn in filtered_checkpoint .get (file_path , [])]
536- assert "not_in_checkpoint_function" in remaining_functions
537- assert "propagate_attributes" not in remaining_functions
538- assert "vanilla_function" not in remaining_functions
521+ assert not filtered_invalid_module
522+ assert count_invalid_module == 0
523+
524+ original_file_path = temp_dir .joinpath ("test_get_functions_to_optimize.py" )
525+ with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.get_blocklisted_functions" ,
526+ return_value = {original_file_path .name : {"propagate_attributes" , "other_blocklisted_function" }}):
527+ filtered_funcs , count = filter_functions (
528+ modified_functions ,
529+ tests_root = Path ("tests" ),
530+ ignore_paths = [],
531+ project_root = temp_dir ,
532+ module_root = temp_dir ,
533+ )
534+ assert "propagate_attributes" not in [fn .function_name for fn in filtered_funcs .get (original_file_path , [])]
535+ assert count == 2
536+
537+ module_name = "test_get_functions_to_optimize"
538+ qualified_name_for_checkpoint = f"{ module_name } .propagate_attributes"
539+ other_qualified_name_for_checkpoint = f"{ module_name } .vanilla_function"
540+
541+ with unittest .mock .patch ("codeflash.discovery.functions_to_optimize.get_blocklisted_functions" , return_value = {}):
542+ filtered_checkpoint , count_checkpoint = filter_functions (
543+ modified_functions ,
544+ tests_root = Path ("tests" ),
545+ ignore_paths = [],
546+ project_root = temp_dir ,
547+ module_root = temp_dir ,
548+ previous_checkpoint_functions = {qualified_name_for_checkpoint : {"status" : "optimized" }, other_qualified_name_for_checkpoint : {}}
549+ )
550+ assert filtered_checkpoint .get (original_file_path )
551+ assert count_checkpoint == 1
552+
553+ remaining_functions = [fn .function_name for fn in filtered_checkpoint .get (original_file_path , [])]
554+ assert "not_in_checkpoint_function" in remaining_functions
555+ assert "propagate_attributes" not in remaining_functions
556+ assert "vanilla_function" not in remaining_functions
557+ files_and_funcs = get_all_files_and_functions (module_root_path = temp_dir )
558+ assert len (files_and_funcs ) == 6
0 commit comments