Skip to content

Commit 5115295

Browse files
authored
Merge branch 'main' into add-timing-info-to-generated-tests
2 parents e6272e8 + 7b770b1 commit 5115295

File tree

2 files changed

+278
-9
lines changed

2 files changed

+278
-9
lines changed

codeflash/code_utils/compat.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,47 @@
11
import os
22
import sys
3+
import tempfile
34
from pathlib import Path
5+
from typing import TYPE_CHECKING
46

57
from platformdirs import user_config_dir
68

7-
# os-independent newline
8-
# important for any user-facing output or files we write
9-
# make sure to use this in f-strings e.g. f"some string{LF}"
10-
# you can use "[^f]\".*\{LF\}\" to find any lines in your code that use this without the f-string
11-
LF: str = os.linesep
9+
if TYPE_CHECKING:
10+
codeflash_temp_dir: Path
11+
codeflash_cache_dir: Path
12+
codeflash_cache_db: Path
1213

1314

14-
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
15+
class Compat:
16+
# os-independent newline
17+
LF: str = os.linesep
1518

16-
IS_POSIX = os.name != "nt"
19+
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
1720

21+
IS_POSIX: bool = os.name != "nt"
1822

19-
codeflash_cache_dir = Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))
23+
@property
24+
def codeflash_cache_dir(self) -> Path:
25+
return Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))
2026

21-
codeflash_cache_db = codeflash_cache_dir / "codeflash_cache.db"
27+
@property
28+
def codeflash_temp_dir(self) -> Path:
29+
temp_dir = Path(tempfile.gettempdir()) / "codeflash"
30+
if not temp_dir.exists():
31+
temp_dir.mkdir(parents=True, exist_ok=True)
32+
return temp_dir
33+
34+
@property
35+
def codeflash_cache_db(self) -> Path:
36+
return self.codeflash_cache_dir / "codeflash_cache.db"
37+
38+
39+
_compat = Compat()
40+
41+
42+
codeflash_temp_dir = _compat.codeflash_temp_dir
43+
codeflash_cache_dir = _compat.codeflash_cache_dir
44+
codeflash_cache_db = _compat.codeflash_cache_db
45+
LF = _compat.LF
46+
SAFE_SYS_EXECUTABLE = _compat.SAFE_SYS_EXECUTABLE
47+
IS_POSIX = _compat.IS_POSIX

tests/test_function_discovery.py

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
import tempfile
22
from pathlib import Path
3+
import os
4+
import unittest.mock
35

46
from codeflash.discovery.functions_to_optimize import (
57
filter_files_optimized,
68
find_all_functions_in_file,
79
get_functions_to_optimize,
810
inspect_top_level_functions_or_methods,
11+
filter_functions,
12+
get_all_files_and_functions
913
)
1014
from codeflash.verification.verification_utils import TestConfig
15+
from codeflash.code_utils.compat import codeflash_temp_dir
1116

1217

1318
def test_function_eligible_for_optimization() -> None:
@@ -313,3 +318,241 @@ def test_filter_files_optimized():
313318
assert filter_files_optimized(file_path_same_level, tests_root, ignore_paths, module_root)
314319
assert filter_files_optimized(file_path_different_level, tests_root, ignore_paths, module_root)
315320
assert not filter_files_optimized(file_path_above_level, tests_root, ignore_paths, module_root)
321+
322+
def test_filter_functions():
323+
with tempfile.TemporaryDirectory() as temp_dir_str:
324+
temp_dir = Path(temp_dir_str)
325+
326+
# Create a test file in the temporary directory
327+
test_file_path = temp_dir.joinpath("test_get_functions_to_optimize.py")
328+
with test_file_path.open("w") as f:
329+
f.write(
330+
"""
331+
import copy
332+
333+
def propagate_attributes(
334+
nodes: dict[str, dict], edges: list[dict], source_node_id: str, attribute: str
335+
) -> dict[str, dict]:
336+
modified_nodes = copy.deepcopy(nodes)
337+
338+
# Build an adjacency list for faster traversal
339+
adjacency = {}
340+
for edge in edges:
341+
src = edge["source"]
342+
tgt = edge["target"]
343+
if src not in adjacency:
344+
adjacency[src] = []
345+
adjacency[src].append(tgt)
346+
347+
# Track visited nodes to avoid cycles
348+
visited = set()
349+
350+
def traverse(node_id):
351+
if node_id in visited:
352+
return
353+
visited.add(node_id)
354+
355+
# Propagate attribute from source node
356+
if (
357+
node_id != source_node_id
358+
and source_node_id in modified_nodes
359+
and attribute in modified_nodes[source_node_id]
360+
):
361+
if node_id in modified_nodes:
362+
modified_nodes[node_id][attribute] = modified_nodes[source_node_id][
363+
attribute
364+
]
365+
366+
# Continue propagation to neighbors
367+
for neighbor in adjacency.get(node_id, []):
368+
traverse(neighbor)
369+
370+
traverse(source_node_id)
371+
return modified_nodes
372+
373+
def vanilla_function():
374+
return "This is a vanilla function."
375+
376+
def not_in_checkpoint_function():
377+
return "This function is not in the checkpoint."
378+
"""
379+
)
380+
381+
382+
discovered = find_all_functions_in_file(test_file_path)
383+
modified_functions = {test_file_path: discovered[test_file_path]}
384+
filtered, count = filter_functions(
385+
modified_functions,
386+
tests_root=Path("tests"),
387+
ignore_paths=[],
388+
project_root=temp_dir,
389+
module_root=temp_dir,
390+
)
391+
function_names = [fn.function_name for fn in filtered.get(test_file_path, [])]
392+
assert "propagate_attributes" in function_names
393+
assert count == 3
394+
395+
# Create a tests directory inside our temp directory
396+
tests_root_dir = temp_dir.joinpath("tests")
397+
tests_root_dir.mkdir(exist_ok=True)
398+
399+
test_file_path = tests_root_dir.joinpath("test_functions.py")
400+
with test_file_path.open("w") as f:
401+
f.write(
402+
"""
403+
def test_function_in_tests_dir():
404+
return "This function is in a test directory and should be filtered out."
405+
"""
406+
)
407+
408+
discovered_test_file = find_all_functions_in_file(test_file_path)
409+
modified_functions_test = {test_file_path: discovered_test_file.get(test_file_path, [])}
410+
411+
filtered_test_file, count_test_file = filter_functions(
412+
modified_functions_test,
413+
tests_root=tests_root_dir,
414+
ignore_paths=[],
415+
project_root=temp_dir,
416+
module_root=temp_dir,
417+
)
418+
419+
assert not filtered_test_file
420+
assert count_test_file == 0
421+
422+
# Test ignored directory
423+
ignored_dir = temp_dir.joinpath("ignored_dir")
424+
ignored_dir.mkdir(exist_ok=True)
425+
ignored_file_path = ignored_dir.joinpath("ignored_file.py")
426+
with ignored_file_path.open("w") as f:
427+
f.write("def ignored_func(): return 1")
428+
429+
discovered_ignored = find_all_functions_in_file(ignored_file_path)
430+
modified_functions_ignored = {ignored_file_path: discovered_ignored.get(ignored_file_path, [])}
431+
432+
filtered_ignored, count_ignored = filter_functions(
433+
modified_functions_ignored,
434+
tests_root=Path("tests"),
435+
ignore_paths=[ignored_dir],
436+
project_root=temp_dir,
437+
module_root=temp_dir,
438+
)
439+
assert not filtered_ignored
440+
assert count_ignored == 0
441+
442+
# Test submodule paths
443+
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.ignored_submodule_paths",
444+
return_value=[str(temp_dir.joinpath("submodule_dir"))]):
445+
submodule_dir = temp_dir.joinpath("submodule_dir")
446+
submodule_dir.mkdir(exist_ok=True)
447+
submodule_file_path = submodule_dir.joinpath("submodule_file.py")
448+
with submodule_file_path.open("w") as f:
449+
f.write("def submodule_func(): return 1")
450+
451+
discovered_submodule = find_all_functions_in_file(submodule_file_path)
452+
modified_functions_submodule = {submodule_file_path: discovered_submodule.get(submodule_file_path, [])}
453+
454+
filtered_submodule, count_submodule = filter_functions(
455+
modified_functions_submodule,
456+
tests_root=Path("tests"),
457+
ignore_paths=[],
458+
project_root=temp_dir,
459+
module_root=temp_dir,
460+
)
461+
assert not filtered_submodule
462+
assert count_submodule == 0
463+
464+
# Test site packages
465+
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.path_belongs_to_site_packages",
466+
return_value=True):
467+
site_package_file_path = temp_dir.joinpath("site_package_file.py")
468+
with site_package_file_path.open("w") as f:
469+
f.write("def site_package_func(): return 1")
470+
471+
discovered_site_package = find_all_functions_in_file(site_package_file_path)
472+
modified_functions_site_package = {site_package_file_path: discovered_site_package.get(site_package_file_path, [])}
473+
474+
filtered_site_package, count_site_package = filter_functions(
475+
modified_functions_site_package,
476+
tests_root=Path("tests"),
477+
ignore_paths=[],
478+
project_root=temp_dir,
479+
module_root=temp_dir,
480+
)
481+
assert not filtered_site_package
482+
assert count_site_package == 0
483+
484+
# Test outside module root
485+
parent_dir = temp_dir.parent
486+
outside_module_root_path = parent_dir.joinpath("outside_module_root_file.py")
487+
try:
488+
with outside_module_root_path.open("w") as f:
489+
f.write("def func_outside_module_root(): return 1")
490+
491+
discovered_outside_module = find_all_functions_in_file(outside_module_root_path)
492+
modified_functions_outside_module = {outside_module_root_path: discovered_outside_module.get(outside_module_root_path, [])}
493+
494+
filtered_outside_module, count_outside_module = filter_functions(
495+
modified_functions_outside_module,
496+
tests_root=Path("tests"),
497+
ignore_paths=[],
498+
project_root=temp_dir,
499+
module_root=temp_dir,
500+
)
501+
assert not filtered_outside_module
502+
assert count_outside_module == 0
503+
finally:
504+
outside_module_root_path.unlink(missing_ok=True)
505+
506+
# Test invalid module name
507+
invalid_module_file_path = temp_dir.joinpath("invalid-module-name.py")
508+
with invalid_module_file_path.open("w") as f:
509+
f.write("def func_in_invalid_module(): return 1")
510+
511+
discovered_invalid_module = find_all_functions_in_file(invalid_module_file_path)
512+
modified_functions_invalid_module = {invalid_module_file_path: discovered_invalid_module.get(invalid_module_file_path, [])}
513+
514+
filtered_invalid_module, count_invalid_module = filter_functions(
515+
modified_functions_invalid_module,
516+
tests_root=Path("tests"),
517+
ignore_paths=[],
518+
project_root=temp_dir,
519+
module_root=temp_dir,
520+
)
521+
assert not filtered_invalid_module
522+
assert count_invalid_module == 0
523+
524+
original_file_path = temp_dir.joinpath("test_get_functions_to_optimize.py")
525+
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions",
526+
return_value={original_file_path.name: {"propagate_attributes", "other_blocklisted_function"}}):
527+
filtered_funcs, count = filter_functions(
528+
modified_functions,
529+
tests_root=Path("tests"),
530+
ignore_paths=[],
531+
project_root=temp_dir,
532+
module_root=temp_dir,
533+
)
534+
assert "propagate_attributes" not in [fn.function_name for fn in filtered_funcs.get(original_file_path, [])]
535+
assert count == 2
536+
537+
module_name = "test_get_functions_to_optimize"
538+
qualified_name_for_checkpoint = f"{module_name}.propagate_attributes"
539+
other_qualified_name_for_checkpoint = f"{module_name}.vanilla_function"
540+
541+
with unittest.mock.patch("codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}):
542+
filtered_checkpoint, count_checkpoint = filter_functions(
543+
modified_functions,
544+
tests_root=Path("tests"),
545+
ignore_paths=[],
546+
project_root=temp_dir,
547+
module_root=temp_dir,
548+
previous_checkpoint_functions={qualified_name_for_checkpoint: {"status": "optimized"}, other_qualified_name_for_checkpoint: {}}
549+
)
550+
assert filtered_checkpoint.get(original_file_path)
551+
assert count_checkpoint == 1
552+
553+
remaining_functions = [fn.function_name for fn in filtered_checkpoint.get(original_file_path, [])]
554+
assert "not_in_checkpoint_function" in remaining_functions
555+
assert "propagate_attributes" not in remaining_functions
556+
assert "vanilla_function" not in remaining_functions
557+
files_and_funcs = get_all_files_and_functions(module_root_path=temp_dir)
558+
assert len(files_and_funcs) == 6

0 commit comments

Comments
 (0)