Skip to content

Commit cbfc714

Browse files
committed
cover get_all_files_and_functions
1 parent a6c456e commit cbfc714

File tree

1 file changed

+166
-146
lines changed

1 file changed

+166
-146
lines changed

tests/test_function_discovery.py

Lines changed: 166 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
find_all_functions_in_file,
99
get_functions_to_optimize,
1010
inspect_top_level_functions_or_methods,
11-
filter_functions
11+
filter_functions,
12+
get_all_files_and_functions
1213
)
1314
from codeflash.verification.verification_utils import TestConfig
1415
from codeflash.code_utils.compat import codeflash_temp_dir
@@ -319,8 +320,13 @@ def test_filter_files_optimized():
319320
assert not filter_files_optimized(file_path_above_level, tests_root, ignore_paths, module_root)
320321

321322
def test_filter_functions():
322-
with codeflash_temp_dir.joinpath("test_get_functions_to_optimize.py").open("w") as f:
323-
f.write(
323+
with tempfile.TemporaryDirectory() as temp_dir_str:
324+
temp_dir = Path(temp_dir_str)
325+
326+
# Create a test file in the temporary directory
327+
test_file_path = temp_dir.joinpath("test_get_functions_to_optimize.py")
328+
with test_file_path.open("w") as f:
329+
f.write(
324330
"""
325331
import copy
326332
@@ -370,169 +376,183 @@ def vanilla_function():
370376
def not_in_checkpoint_function():
371377
return "This function is not in the checkpoint."
372378
"""
373-
)
374-
f.flush()
375-
test_config = TestConfig(
376-
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
377-
)
379+
)
380+
378381

379-
file_path = codeflash_temp_dir.joinpath("test_get_functions_to_optimize.py")
380-
discovered = find_all_functions_in_file(file_path)
381-
modified_functions = {file_path: discovered[file_path]}
382+
discovered = find_all_functions_in_file(test_file_path)
383+
modified_functions = {test_file_path: discovered[test_file_path]}
382384
filtered, count = filter_functions(
383385
modified_functions,
384386
tests_root=Path("tests"),
385387
ignore_paths=[],
386-
project_root=file_path.parent,
387-
module_root=file_path.parent,
388+
project_root=temp_dir,
389+
module_root=temp_dir,
388390
)
389-
function_names = [fn.function_name for fn in filtered.get(file_path, [])]
391+
function_names = [fn.function_name for fn in filtered.get(test_file_path, [])]
390392
assert "propagate_attributes" in function_names
391393
assert count == 3
392394

393-
tests_root_dir = codeflash_temp_dir.joinpath("tests")
394-
tests_root_dir.mkdir(exist_ok=True)
395-
396-
test_file_path = tests_root_dir.joinpath("test_functions.py")
397-
with test_file_path.open("w") as f:
398-
f.write(
395+
# Create a tests directory inside our temp directory
396+
tests_root_dir = temp_dir.joinpath("tests")
397+
tests_root_dir.mkdir(exist_ok=True)
398+
399+
test_file_path = tests_root_dir.joinpath("test_functions.py")
400+
with test_file_path.open("w") as f:
401+
f.write(
399402
"""
400403
def test_function_in_tests_dir():
401404
return "This function is in a test directory and should be filtered out."
402405
"""
403-
)
404-
405-
discovered_test_file = find_all_functions_in_file(test_file_path)
406-
modified_functions_test = {test_file_path: discovered_test_file.get(test_file_path, [])}
407-
408-
filtered_test_file, count_test_file = filter_functions(
409-
modified_functions_test,
410-
tests_root=tests_root_dir,
411-
ignore_paths=[],
412-
project_root=codeflash_temp_dir,
413-
module_root=codeflash_temp_dir,
414-
)
415-
416-
assert not filtered_test_file
417-
assert count_test_file == 0
418-
419-
with codeflash_temp_dir.joinpath("ignored_dir").open("w") as f:
420-
f.write("def ignored_func(): return 1")
421-
422-
ignored_file_path = codeflash_temp_dir.joinpath("ignored_dir")
423-
discovered_ignored = find_all_functions_in_file(ignored_file_path)
424-
modified_functions_ignored = {ignored_file_path: discovered_ignored.get(ignored_file_path, [])}
425-
426-
filtered_ignored, count_ignored = filter_functions(
427-
modified_functions_ignored,
428-
tests_root=Path("tests"),
429-
ignore_paths=[ignored_file_path.parent],
430-
project_root=file_path.parent,
431-
module_root=file_path.parent,
432-
)
433-
assert not filtered_ignored
434-
assert count_ignored == 0
435-
436-
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.ignored_submodule_paths", return_value=[str(codeflash_temp_dir.joinpath("submodule_dir"))]):
437-
with codeflash_temp_dir.joinpath("submodule_dir").open("w") as f:
438-
f.write("def submodule_func(): return 1")
406+
)
439407

440-
submodule_file_path = codeflash_temp_dir.joinpath("submodule_dir")
441-
discovered_submodule = find_all_functions_in_file(submodule_file_path)
442-
modified_functions_submodule = {submodule_file_path: discovered_submodule.get(submodule_file_path, [])}
408+
discovered_test_file = find_all_functions_in_file(test_file_path)
409+
modified_functions_test = {test_file_path: discovered_test_file.get(test_file_path, [])}
443410

444-
filtered_submodule, count_submodule = filter_functions(
445-
modified_functions_submodule,
446-
tests_root=Path("tests"),
411+
filtered_test_file, count_test_file = filter_functions(
412+
modified_functions_test,
413+
tests_root=tests_root_dir,
447414
ignore_paths=[],
448-
project_root=file_path.parent,
449-
module_root=file_path.parent,
415+
project_root=temp_dir,
416+
module_root=temp_dir,
450417
)
451-
assert not filtered_submodule
452-
assert count_submodule == 0
453-
454-
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.path_belongs_to_site_packages", return_value=True):
455-
with codeflash_temp_dir.joinpath("site_package_file.py").open("w") as f:
456-
f.write("def site_package_func(): return 1")
457-
458-
site_package_file_path = codeflash_temp_dir.joinpath("site_package_file.py")
459-
discovered_site_package = find_all_functions_in_file(site_package_file_path)
460-
modified_functions_site_package = {site_package_file_path: discovered_site_package.get(site_package_file_path, [])}
418+
419+
assert not filtered_test_file
420+
assert count_test_file == 0
421+
422+
# Test ignored directory
423+
ignored_dir = temp_dir.joinpath("ignored_dir")
424+
ignored_dir.mkdir(exist_ok=True)
425+
ignored_file_path = ignored_dir.joinpath("ignored_file.py")
426+
with ignored_file_path.open("w") as f:
427+
f.write("def ignored_func(): return 1")
428+
429+
discovered_ignored = find_all_functions_in_file(ignored_file_path)
430+
modified_functions_ignored = {ignored_file_path: discovered_ignored.get(ignored_file_path, [])}
461431

462-
filtered_site_package, count_site_package = filter_functions(
463-
modified_functions_site_package,
432+
filtered_ignored, count_ignored = filter_functions(
433+
modified_functions_ignored,
464434
tests_root=Path("tests"),
465-
ignore_paths=[],
466-
project_root=file_path.parent,
467-
module_root=file_path.parent,
435+
ignore_paths=[ignored_dir],
436+
project_root=temp_dir,
437+
module_root=temp_dir,
468438
)
469-
assert not filtered_site_package
470-
assert count_site_package == 0
471-
472-
outside_module_root_path = codeflash_temp_dir.parent.joinpath("outside_module_root_file.py")
473-
with outside_module_root_path.open("w") as f:
474-
f.write("def func_outside_module_root(): return 1")
475-
476-
discovered_outside_module = find_all_functions_in_file(outside_module_root_path)
477-
modified_functions_outside_module = {outside_module_root_path: discovered_outside_module.get(outside_module_root_path, [])}
478-
479-
filtered_outside_module, count_outside_module = filter_functions(
480-
modified_functions_outside_module,
481-
tests_root=Path("tests"),
482-
ignore_paths=[],
483-
project_root=file_path.parent,
484-
module_root=file_path.parent,
485-
)
486-
assert not filtered_outside_module
487-
assert count_outside_module == 0
488-
os.remove(outside_module_root_path)
489-
490-
invalid_module_file_path = codeflash_temp_dir.joinpath("invalid-module-name.py")
491-
with invalid_module_file_path.open("w") as f:
492-
f.write("def func_in_invalid_module(): return 1")
493-
494-
discovered_invalid_module = find_all_functions_in_file(invalid_module_file_path)
495-
modified_functions_invalid_module = {invalid_module_file_path: discovered_invalid_module.get(invalid_module_file_path, [])}
496-
497-
filtered_invalid_module, count_invalid_module = filter_functions(
498-
modified_functions_invalid_module,
499-
tests_root=Path("tests"),
500-
ignore_paths=[],
501-
project_root=file_path.parent,
502-
module_root=file_path.parent,
503-
)
504-
assert not filtered_invalid_module
505-
assert count_invalid_module == 0
506-
507-
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={file_path.name: {"propagate_attributes", "other_blocklisted_function"}}):
508-
filtered_funcs, count= filter_functions(
509-
modified_functions,
510-
tests_root=Path("tests"),
511-
ignore_paths=[],
512-
project_root=file_path.parent,
513-
module_root=file_path.parent,
514-
)
515-
assert "propagate_attributes" not in [fn.function_name for fn in filtered_funcs.get(file_path, [])]
516-
assert count == 2
517-
518-
519-
module_name = "test_get_functions_to_optimize"
520-
qualified_name_for_checkpoint = f"{module_name}.propagate_attributes"
521-
other_qualified_name_for_checkpoint = f"{module_name}.vanilla_function"
522-
523-
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}):
524-
filtered_checkpoint, count_checkpoint = filter_functions(
525-
modified_functions,
439+
assert not filtered_ignored
440+
assert count_ignored == 0
441+
442+
# Test submodule paths
443+
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.ignored_submodule_paths",
444+
return_value=[str(temp_dir.joinpath("submodule_dir"))]):
445+
submodule_dir = temp_dir.joinpath("submodule_dir")
446+
submodule_dir.mkdir(exist_ok=True)
447+
submodule_file_path = submodule_dir.joinpath("submodule_file.py")
448+
with submodule_file_path.open("w") as f:
449+
f.write("def submodule_func(): return 1")
450+
451+
discovered_submodule = find_all_functions_in_file(submodule_file_path)
452+
modified_functions_submodule = {submodule_file_path: discovered_submodule.get(submodule_file_path, [])}
453+
454+
filtered_submodule, count_submodule = filter_functions(
455+
modified_functions_submodule,
456+
tests_root=Path("tests"),
457+
ignore_paths=[],
458+
project_root=temp_dir,
459+
module_root=temp_dir,
460+
)
461+
assert not filtered_submodule
462+
assert count_submodule == 0
463+
464+
# Test site packages
465+
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.path_belongs_to_site_packages",
466+
return_value=True):
467+
site_package_file_path = temp_dir.joinpath("site_package_file.py")
468+
with site_package_file_path.open("w") as f:
469+
f.write("def site_package_func(): return 1")
470+
471+
discovered_site_package = find_all_functions_in_file(site_package_file_path)
472+
modified_functions_site_package = {site_package_file_path: discovered_site_package.get(site_package_file_path, [])}
473+
474+
filtered_site_package, count_site_package = filter_functions(
475+
modified_functions_site_package,
476+
tests_root=Path("tests"),
477+
ignore_paths=[],
478+
project_root=temp_dir,
479+
module_root=temp_dir,
480+
)
481+
assert not filtered_site_package
482+
assert count_site_package == 0
483+
484+
# Test outside module root
485+
parent_dir = temp_dir.parent
486+
outside_module_root_path = parent_dir.joinpath("outside_module_root_file.py")
487+
try:
488+
with outside_module_root_path.open("w") as f:
489+
f.write("def func_outside_module_root(): return 1")
490+
491+
discovered_outside_module = find_all_functions_in_file(outside_module_root_path)
492+
modified_functions_outside_module = {outside_module_root_path: discovered_outside_module.get(outside_module_root_path, [])}
493+
494+
filtered_outside_module, count_outside_module = filter_functions(
495+
modified_functions_outside_module,
496+
tests_root=Path("tests"),
497+
ignore_paths=[],
498+
project_root=temp_dir,
499+
module_root=temp_dir,
500+
)
501+
assert not filtered_outside_module
502+
assert count_outside_module == 0
503+
finally:
504+
outside_module_root_path.unlink(missing_ok=True)
505+
506+
# Test invalid module name
507+
invalid_module_file_path = temp_dir.joinpath("invalid-module-name.py")
508+
with invalid_module_file_path.open("w") as f:
509+
f.write("def func_in_invalid_module(): return 1")
510+
511+
discovered_invalid_module = find_all_functions_in_file(invalid_module_file_path)
512+
modified_functions_invalid_module = {invalid_module_file_path: discovered_invalid_module.get(invalid_module_file_path, [])}
513+
514+
filtered_invalid_module, count_invalid_module = filter_functions(
515+
modified_functions_invalid_module,
526516
tests_root=Path("tests"),
527517
ignore_paths=[],
528-
project_root=file_path.parent,
529-
module_root=file_path.parent,
530-
previous_checkpoint_functions={qualified_name_for_checkpoint: {"status": "optimized"}, other_qualified_name_for_checkpoint: {}}
518+
project_root=temp_dir,
519+
module_root=temp_dir,
531520
)
532-
assert filtered_checkpoint.get(file_path)
533-
assert count_checkpoint == 1
534-
535-
remaining_functions = [fn.function_name for fn in filtered_checkpoint.get(file_path, [])]
536-
assert "not_in_checkpoint_function" in remaining_functions
537-
assert "propagate_attributes" not in remaining_functions
538-
assert "vanilla_function" not in remaining_functions
521+
assert not filtered_invalid_module
522+
assert count_invalid_module == 0
523+
524+
original_file_path = temp_dir.joinpath("test_get_functions_to_optimize.py")
525+
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions",
526+
return_value={original_file_path.name: {"propagate_attributes", "other_blocklisted_function"}}):
527+
filtered_funcs, count = filter_functions(
528+
modified_functions,
529+
tests_root=Path("tests"),
530+
ignore_paths=[],
531+
project_root=temp_dir,
532+
module_root=temp_dir,
533+
)
534+
assert "propagate_attributes" not in [fn.function_name for fn in filtered_funcs.get(original_file_path, [])]
535+
assert count == 2
536+
537+
module_name = "test_get_functions_to_optimize"
538+
qualified_name_for_checkpoint = f"{module_name}.propagate_attributes"
539+
other_qualified_name_for_checkpoint = f"{module_name}.vanilla_function"
540+
541+
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}):
542+
filtered_checkpoint, count_checkpoint = filter_functions(
543+
modified_functions,
544+
tests_root=Path("tests"),
545+
ignore_paths=[],
546+
project_root=temp_dir,
547+
module_root=temp_dir,
548+
previous_checkpoint_functions={qualified_name_for_checkpoint: {"status": "optimized"}, other_qualified_name_for_checkpoint: {}}
549+
)
550+
assert filtered_checkpoint.get(original_file_path)
551+
assert count_checkpoint == 1
552+
553+
remaining_functions = [fn.function_name for fn in filtered_checkpoint.get(original_file_path, [])]
554+
assert "not_in_checkpoint_function" in remaining_functions
555+
assert "propagate_attributes" not in remaining_functions
556+
assert "vanilla_function" not in remaining_functions
557+
files_and_funcs = get_all_files_and_functions(module_root_path=temp_dir)
558+
assert len(files_and_funcs) == 6

0 commit comments

Comments
 (0)