Skip to content

Commit 223375f

Browse files
committed
first pass
1 parent 19dcbfb commit 223375f

File tree

2 files changed

+103
-9
lines changed

2 files changed

+103
-9
lines changed

codeflash/code_utils/compat.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,47 @@
11
import os
22
import sys
3+
import tempfile
34
from pathlib import Path
5+
from typing import TYPE_CHECKING
46

57
from platformdirs import user_config_dir
68

7-
# os-independent newline
8-
# important for any user-facing output or files we write
9-
# make sure to use this in f-strings e.g. f"some string{LF}"
10-
# you can use "[^f]\".*\{LF\}\" to find any lines in your code that use this without the f-string
11-
LF: str = os.linesep
9+
if TYPE_CHECKING:
10+
codeflash_temp_dir: Path
11+
codeflash_cache_dir: Path
12+
codeflash_cache_db: Path
1213

1314

14-
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
15+
class Compat:
16+
# os-independent newline
17+
LF: str = os.linesep
1518

16-
IS_POSIX = os.name != "nt"
19+
SAFE_SYS_EXECUTABLE: str = Path(sys.executable).as_posix()
1720

21+
IS_POSIX: bool = os.name != "nt"
1822

19-
codeflash_cache_dir = Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))
23+
@property
24+
def codeflash_cache_dir(self) -> Path:
25+
return Path(user_config_dir(appname="codeflash", appauthor="codeflash-ai", ensure_exists=True))
2026

21-
codeflash_cache_db = codeflash_cache_dir / "codeflash_cache.db"
27+
@property
28+
def codeflash_temp_dir(self) -> Path:
29+
temp_dir = Path(tempfile.gettempdir()) / "codeflash"
30+
if not temp_dir.exists():
31+
temp_dir.mkdir(parents=True, exist_ok=True)
32+
return temp_dir
33+
34+
@property
35+
def codeflash_cache_db(self) -> Path:
36+
return self.codeflash_cache_dir / "codeflash_cache.db"
37+
38+
39+
_compat = Compat()
40+
41+
42+
codeflash_temp_dir = _compat.codeflash_temp_dir
43+
codeflash_cache_dir = _compat.codeflash_cache_dir
44+
codeflash_cache_db = _compat.codeflash_cache_db
45+
LF = _compat.LF
46+
SAFE_SYS_EXECUTABLE = _compat.SAFE_SYS_EXECUTABLE
47+
IS_POSIX = _compat.IS_POSIX

tests/test_function_discovery.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
find_all_functions_in_file,
77
get_functions_to_optimize,
88
inspect_top_level_functions_or_methods,
9+
filter_functions
910
)
1011
from codeflash.verification.verification_utils import TestConfig
12+
from codeflash.code_utils.compat import codeflash_temp_dir
1113

1214

1315
def test_function_eligible_for_optimization() -> None:
@@ -313,3 +315,69 @@ def test_filter_files_optimized():
313315
assert filter_files_optimized(file_path_same_level, tests_root, ignore_paths, module_root)
314316
assert filter_files_optimized(file_path_different_level, tests_root, ignore_paths, module_root)
315317
assert not filter_files_optimized(file_path_above_level, tests_root, ignore_paths, module_root)
318+
319+
def test_filter_functions():
320+
with codeflash_temp_dir.joinpath("test_get_functions_to_optimize.py").open("w") as f:
321+
f.write(
322+
"""
323+
import copy
324+
325+
def propagate_attributes(
326+
nodes: dict[str, dict], edges: list[dict], source_node_id: str, attribute: str
327+
) -> dict[str, dict]:
328+
modified_nodes = copy.deepcopy(nodes)
329+
330+
# Build an adjacency list for faster traversal
331+
adjacency = {}
332+
for edge in edges:
333+
src = edge["source"]
334+
tgt = edge["target"]
335+
if src not in adjacency:
336+
adjacency[src] = []
337+
adjacency[src].append(tgt)
338+
339+
# Track visited nodes to avoid cycles
340+
visited = set()
341+
342+
def traverse(node_id):
343+
if node_id in visited:
344+
return
345+
visited.add(node_id)
346+
347+
# Propagate attribute from source node
348+
if (
349+
node_id != source_node_id
350+
and source_node_id in modified_nodes
351+
and attribute in modified_nodes[source_node_id]
352+
):
353+
if node_id in modified_nodes:
354+
modified_nodes[node_id][attribute] = modified_nodes[source_node_id][
355+
attribute
356+
]
357+
358+
# Continue propagation to neighbors
359+
for neighbor in adjacency.get(node_id, []):
360+
traverse(neighbor)
361+
362+
traverse(source_node_id)
363+
return modified_nodes
364+
"""
365+
)
366+
f.flush()
367+
test_config = TestConfig(
368+
tests_root="tests", project_root_path=".", test_framework="pytest", tests_project_rootdir=Path()
369+
)
370+
371+
file_path = codeflash_temp_dir.joinpath("test_get_functions_to_optimize.py")
372+
discovered = find_all_functions_in_file(file_path)
373+
modified_functions = {file_path: discovered[file_path]}
374+
filtered, count = filter_functions(
375+
modified_functions,
376+
tests_root=Path("tests"),
377+
ignore_paths=[],
378+
project_root=file_path.parent,
379+
module_root=file_path.parent,
380+
)
381+
function_names = [fn.function_name for fn in filtered.get(file_path, [])]
382+
assert "propagate_attributes" in function_names
383+
assert count == 1

0 commit comments

Comments
 (0)