Skip to content

Commit e58d877

Browse files
committed
Add tests
1 parent cad5a07 commit e58d877

File tree

2 files changed

+177
-1
lines changed

2 files changed

+177
-1
lines changed

codeflash/code_utils/checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
class CodeflashRunCheckpoint:
14-
def __init__(self, module_root: Path, checkpoint_dir: str = "/tmp") -> None:
14+
def __init__(self, module_root: Path, checkpoint_dir: Path = Path("/tmp")) -> None:
1515
self.module_root = module_root
1616
self.checkpoint_dir = Path(checkpoint_dir)
1717
# Create a unique checkpoint file name

tests/test_codeflash_checkpoint.py

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
import json
2+
import tempfile
3+
from pathlib import Path
4+
5+
import pytest
6+
from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint, get_all_historical_functions
7+
8+
9+
class TestCodeflashRunCheckpoint:
10+
@pytest.fixture
11+
def temp_dir(self):
12+
with tempfile.TemporaryDirectory() as temp_dir:
13+
yield Path(temp_dir)
14+
15+
def test_initialization(self, temp_dir):
16+
module_root = Path("/fake/module/root")
17+
checkpoint = CodeflashRunCheckpoint(module_root, checkpoint_dir=temp_dir)
18+
19+
# Check if checkpoint file was created
20+
assert checkpoint.checkpoint_path.exists()
21+
22+
# Check if metadata was written correctly
23+
with open(checkpoint.checkpoint_path) as f:
24+
metadata = json.loads(f.readline())
25+
assert metadata["type"] == "metadata"
26+
assert metadata["module_root"] == str(module_root)
27+
assert "created_at" in metadata
28+
assert "last_updated" in metadata
29+
30+
def test_add_function_to_checkpoint(self, temp_dir):
31+
module_root = Path("/fake/module/root")
32+
checkpoint = CodeflashRunCheckpoint(module_root, checkpoint_dir=temp_dir)
33+
34+
# Add a function to the checkpoint
35+
function_name = "module.submodule.function"
36+
checkpoint.add_function_to_checkpoint(function_name, status="optimized")
37+
38+
# Read the checkpoint file and verify
39+
with open(checkpoint.checkpoint_path) as f:
40+
lines = f.readlines()
41+
assert len(lines) == 2 # Metadata + function entry
42+
43+
function_data = json.loads(lines[1])
44+
assert function_data["type"] == "function"
45+
assert function_data["function_name"] == function_name
46+
assert function_data["status"] == "optimized"
47+
assert "timestamp" in function_data
48+
49+
def test_add_function_with_additional_info(self, temp_dir):
50+
module_root = Path("/fake/module/root")
51+
checkpoint = CodeflashRunCheckpoint(module_root, checkpoint_dir=temp_dir)
52+
53+
# Add a function with additional info
54+
function_name = "module.submodule.function"
55+
additional_info = {"execution_time": 1.5, "memory_usage": "10MB"}
56+
checkpoint.add_function_to_checkpoint(function_name, status="optimized", additional_info=additional_info)
57+
58+
# Read the checkpoint file and verify
59+
with open(checkpoint.checkpoint_path) as f:
60+
lines = f.readlines()
61+
function_data = json.loads(lines[1])
62+
assert function_data["execution_time"] == 1.5
63+
assert function_data["memory_usage"] == "10MB"
64+
65+
def test_update_metadata_timestamp(self, temp_dir):
66+
module_root = Path("/fake/module/root")
67+
checkpoint = CodeflashRunCheckpoint(module_root, checkpoint_dir=temp_dir)
68+
69+
# Get initial timestamp
70+
with open(checkpoint.checkpoint_path) as f:
71+
initial_metadata = json.loads(f.readline())
72+
initial_timestamp = initial_metadata["last_updated"]
73+
74+
# Wait a bit to ensure timestamp changes
75+
import time
76+
77+
time.sleep(0.01)
78+
79+
# Update timestamp
80+
checkpoint._update_metadata_timestamp()
81+
82+
# Check if timestamp was updated
83+
with open(checkpoint.checkpoint_path) as f:
84+
updated_metadata = json.loads(f.readline())
85+
updated_timestamp = updated_metadata["last_updated"]
86+
87+
assert updated_timestamp > initial_timestamp
88+
89+
def test_cleanup(self, temp_dir):
90+
module_root = Path("/fake/module/root")
91+
92+
# Create multiple checkpoint files
93+
checkpoint1 = CodeflashRunCheckpoint(module_root, checkpoint_dir=temp_dir)
94+
checkpoint2 = CodeflashRunCheckpoint(module_root, checkpoint_dir=temp_dir)
95+
96+
# Create a checkpoint for a different module
97+
different_module = Path("/different/module")
98+
checkpoint3 = CodeflashRunCheckpoint(different_module, checkpoint_dir=temp_dir)
99+
100+
# Verify all files exist
101+
assert checkpoint1.checkpoint_path.exists()
102+
assert checkpoint2.checkpoint_path.exists()
103+
assert checkpoint3.checkpoint_path.exists()
104+
105+
# Clean up files for module_root
106+
checkpoint1.cleanup()
107+
108+
# Check that only the files for module_root were deleted
109+
assert not checkpoint1.checkpoint_path.exists()
110+
assert not checkpoint2.checkpoint_path.exists()
111+
assert checkpoint3.checkpoint_path.exists()
112+
113+
114+
class TestGetAllHistoricalFunctions:
115+
@pytest.fixture
116+
def setup_checkpoint_files(self):
117+
with tempfile.TemporaryDirectory() as temp_dir:
118+
temp_dir_path = Path(temp_dir)
119+
module_root = Path("/fake/module/root")
120+
121+
# Create a checkpoint file with some functions
122+
checkpoint = CodeflashRunCheckpoint(module_root, checkpoint_dir=temp_dir_path)
123+
checkpoint.add_function_to_checkpoint("module.func1", status="optimized")
124+
checkpoint.add_function_to_checkpoint("module.func2", status="failed")
125+
126+
# Create an old checkpoint file (more than 7 days old)
127+
old_checkpoint_path = temp_dir_path / "codeflash_checkpoint_old.jsonl"
128+
with open(old_checkpoint_path, "w") as f:
129+
# Create metadata with old timestamp (8 days ago)
130+
import time
131+
132+
old_time = time.time() - (8 * 24 * 60 * 60)
133+
metadata = {
134+
"type": "metadata",
135+
"module_root": str(module_root),
136+
"created_at": old_time,
137+
"last_updated": old_time,
138+
}
139+
f.write(json.dumps(metadata) + "\n")
140+
141+
# Add a function entry
142+
function_data = {
143+
"type": "function",
144+
"function_name": "module.old_func",
145+
"status": "optimized",
146+
"timestamp": old_time,
147+
}
148+
f.write(json.dumps(function_data) + "\n")
149+
150+
# Create a checkpoint for a different module
151+
different_module = Path("/different/module")
152+
diff_checkpoint = CodeflashRunCheckpoint(different_module, checkpoint_dir=temp_dir_path)
153+
diff_checkpoint.add_function_to_checkpoint("different.func", status="optimized")
154+
155+
yield module_root, temp_dir_path
156+
157+
def test_get_all_historical_functions(self, setup_checkpoint_files):
158+
module_root, checkpoint_dir = setup_checkpoint_files
159+
160+
# Get historical functions
161+
functions = get_all_historical_functions(module_root, checkpoint_dir)
162+
163+
# Verify the functions from the current checkpoint are included
164+
assert "module.func1" in functions
165+
assert "module.func2" in functions
166+
assert functions["module.func1"]["status"] == "optimized"
167+
assert functions["module.func2"]["status"] == "failed"
168+
169+
# Verify the old function is not included (file should be deleted)
170+
assert "module.old_func" not in functions
171+
172+
# Verify the function from the different module is not included
173+
assert "different.func" not in functions
174+
175+
# Verify the old checkpoint file was deleted
176+
assert not (checkpoint_dir / "codeflash_checkpoint_old.jsonl").exists()

0 commit comments

Comments
 (0)