Skip to content

Commit 93801ac

Browse files
committed
cover more cases
1 parent 27c5791 commit 93801ac

File tree

1 file changed

+156
-1
lines changed

1 file changed

+156
-1
lines changed

tests/test_function_discovery.py

Lines changed: 156 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import tempfile
22
from pathlib import Path
3+
import os
4+
import unittest.mock
35

46
from codeflash.discovery.functions_to_optimize import (
57
filter_files_optimized,
@@ -361,6 +363,12 @@ def traverse(node_id):
361363
362364
traverse(source_node_id)
363365
return modified_nodes
366+
367+
def vanilla_function():
368+
return "This is a vanilla function."
369+
370+
def not_in_checkpoint_function():
371+
return "This function is not in the checkpoint."
364372
"""
365373
)
366374
f.flush()
@@ -380,4 +388,151 @@ def traverse(node_id):
380388
)
381389
function_names = [fn.function_name for fn in filtered.get(file_path, [])]
382390
assert "propagate_attributes" in function_names
383-
assert count == 1
391+
assert count == 3
392+
393+
tests_root_dir = codeflash_temp_dir.joinpath("tests")
394+
tests_root_dir.mkdir(exist_ok=True)
395+
396+
test_file_path = tests_root_dir.joinpath("test_functions.py")
397+
with test_file_path.open("w") as f:
398+
f.write(
399+
"""
400+
def test_function_in_tests_dir():
401+
return "This function is in a test directory and should be filtered out."
402+
"""
403+
)
404+
405+
discovered_test_file = find_all_functions_in_file(test_file_path)
406+
modified_functions_test = {test_file_path: discovered_test_file.get(test_file_path, [])}
407+
408+
filtered_test_file, count_test_file = filter_functions(
409+
modified_functions_test,
410+
tests_root=tests_root_dir,
411+
ignore_paths=[],
412+
project_root=codeflash_temp_dir,
413+
module_root=codeflash_temp_dir,
414+
)
415+
416+
assert not filtered_test_file
417+
assert count_test_file == 0
418+
419+
with codeflash_temp_dir.joinpath("ignored_dir").open("w") as f:
420+
f.write("def ignored_func(): return 1")
421+
422+
ignored_file_path = codeflash_temp_dir.joinpath("ignored_dir")
423+
discovered_ignored = find_all_functions_in_file(ignored_file_path)
424+
modified_functions_ignored = {ignored_file_path: discovered_ignored.get(ignored_file_path, [])}
425+
426+
filtered_ignored, count_ignored = filter_functions(
427+
modified_functions_ignored,
428+
tests_root=Path("tests"),
429+
ignore_paths=[ignored_file_path.parent],
430+
project_root=file_path.parent,
431+
module_root=file_path.parent,
432+
)
433+
assert not filtered_ignored
434+
assert count_ignored == 0
435+
436+
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.ignored_submodule_paths", return_value=[str(codeflash_temp_dir.joinpath("submodule_dir"))]):
437+
with codeflash_temp_dir.joinpath("submodule_dir").open("w") as f:
438+
f.write("def submodule_func(): return 1")
439+
440+
submodule_file_path = codeflash_temp_dir.joinpath("submodule_dir")
441+
discovered_submodule = find_all_functions_in_file(submodule_file_path)
442+
modified_functions_submodule = {submodule_file_path: discovered_submodule.get(submodule_file_path, [])}
443+
444+
filtered_submodule, count_submodule = filter_functions(
445+
modified_functions_submodule,
446+
tests_root=Path("tests"),
447+
ignore_paths=[],
448+
project_root=file_path.parent,
449+
module_root=file_path.parent,
450+
)
451+
assert not filtered_submodule
452+
assert count_submodule == 0
453+
454+
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.path_belongs_to_site_packages", return_value=True):
455+
with codeflash_temp_dir.joinpath("site_package_file.py").open("w") as f:
456+
f.write("def site_package_func(): return 1")
457+
458+
site_package_file_path = codeflash_temp_dir.joinpath("site_package_file.py")
459+
discovered_site_package = find_all_functions_in_file(site_package_file_path)
460+
modified_functions_site_package = {site_package_file_path: discovered_site_package.get(site_package_file_path, [])}
461+
462+
filtered_site_package, count_site_package = filter_functions(
463+
modified_functions_site_package,
464+
tests_root=Path("tests"),
465+
ignore_paths=[],
466+
project_root=file_path.parent,
467+
module_root=file_path.parent,
468+
)
469+
assert not filtered_site_package
470+
assert count_site_package == 0
471+
472+
outside_module_root_path = codeflash_temp_dir.parent.joinpath("outside_module_root_file.py")
473+
with outside_module_root_path.open("w") as f:
474+
f.write("def func_outside_module_root(): return 1")
475+
476+
discovered_outside_module = find_all_functions_in_file(outside_module_root_path)
477+
modified_functions_outside_module = {outside_module_root_path: discovered_outside_module.get(outside_module_root_path, [])}
478+
479+
filtered_outside_module, count_outside_module = filter_functions(
480+
modified_functions_outside_module,
481+
tests_root=Path("tests"),
482+
ignore_paths=[],
483+
project_root=file_path.parent,
484+
module_root=file_path.parent,
485+
)
486+
assert not filtered_outside_module
487+
assert count_outside_module == 0
488+
os.remove(outside_module_root_path)
489+
490+
invalid_module_file_path = codeflash_temp_dir.joinpath("invalid-module-name.py")
491+
with invalid_module_file_path.open("w") as f:
492+
f.write("def func_in_invalid_module(): return 1")
493+
494+
discovered_invalid_module = find_all_functions_in_file(invalid_module_file_path)
495+
modified_functions_invalid_module = {invalid_module_file_path: discovered_invalid_module.get(invalid_module_file_path, [])}
496+
497+
filtered_invalid_module, count_invalid_module = filter_functions(
498+
modified_functions_invalid_module,
499+
tests_root=Path("tests"),
500+
ignore_paths=[],
501+
project_root=file_path.parent,
502+
module_root=file_path.parent,
503+
)
504+
assert not filtered_invalid_module
505+
assert count_invalid_module == 0
506+
507+
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={file_path.name: {"propagate_attributes", "other_blocklisted_function"}}):
508+
filtered_blocklisted, count_blocklisted = filter_functions(
509+
modified_functions,
510+
tests_root=Path("tests"),
511+
ignore_paths=[],
512+
project_root=file_path.parent,
513+
module_root=file_path.parent,
514+
)
515+
assert "propagate_attributes" in [fn.function_name for fn in filtered_blocklisted.get(file_path, [])]
516+
assert count_blocklisted == 1
517+
518+
519+
module_name = "test_get_functions_to_optimize"
520+
qualified_name_for_checkpoint = f"{module_name}.propagate_attributes"
521+
other_qualified_name_for_checkpoint = f"{module_name}.vanilla_function"
522+
523+
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}):
524+
filtered_checkpoint, count_checkpoint = filter_functions(
525+
modified_functions,
526+
tests_root=Path("tests"),
527+
ignore_paths=[],
528+
project_root=file_path.parent,
529+
module_root=file_path.parent,
530+
previous_checkpoint_functions={qualified_name_for_checkpoint: {"status": "optimized"}, other_qualified_name_for_checkpoint: {}}
531+
)
532+
assert filtered_checkpoint.get(file_path)
533+
assert count_checkpoint == 1
534+
535+
remaining_functions = [fn.function_name for fn in filtered_checkpoint.get(file_path, [])]
536+
assert "not_in_checkpoint_function" in remaining_functions
537+
assert "propagate_attributes" not in remaining_functions
538+
assert "vanilla_function" not in remaining_functions

0 commit comments

Comments
 (0)