Skip to content

Commit 1d2bbf7

Browse files
Merge pull request #177 from codeflash-ai/checkpoint-for-codeflash-all-runs
Checkpoint for codeflash --all runs
2 parents 4e6bdc6 + 82e01af commit 1d2bbf7

File tree

5 files changed

+427
-39
lines changed

5 files changed

+427
-39
lines changed

codeflash/code_utils/checkpoint.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
import argparse
2+
import datetime
3+
import json
4+
import sys
5+
import time
6+
import uuid
7+
from pathlib import Path
8+
from typing import Any, Optional
9+
10+
import click
11+
12+
13+
class CodeflashRunCheckpoint:
    """Append-only JSONL checkpoint of functions processed during a codeflash --all run.

    File layout: the first line is a metadata record (module root plus
    created_at/last_updated timestamps); every subsequent line records one
    processed function. The file lives in ``checkpoint_dir`` (default /tmp)
    under a per-run unique name so concurrent runs do not clobber each other.
    """

    def __init__(self, module_root: Path, checkpoint_dir: Path = Path("/tmp")) -> None:
        self.module_root = module_root
        self.checkpoint_dir = Path(checkpoint_dir)
        # Unique suffix so simultaneous runs each get their own checkpoint file.
        unique_id = str(uuid.uuid4())[:8]
        checkpoint_filename = f"codeflash_checkpoint_{unique_id}.jsonl"
        self.checkpoint_path = self.checkpoint_dir / checkpoint_filename

        # Initialize the checkpoint file with metadata
        self._initialize_checkpoint_file()

    def _initialize_checkpoint_file(self) -> None:
        """Create a new checkpoint file with metadata."""
        metadata = {
            "type": "metadata",
            "module_root": str(self.module_root),
            "created_at": time.time(),
            "last_updated": time.time(),
        }

        with open(self.checkpoint_path, "w") as f:
            f.write(json.dumps(metadata) + "\n")

    def add_function_to_checkpoint(
        self,
        function_fully_qualified_name: str,
        status: str = "optimized",
        additional_info: Optional[dict[str, Any]] = None,
    ) -> None:
        """Add a function to the checkpoint after it has been processed.

        Args:
            function_fully_qualified_name: The fully qualified name of the function
            status: Status of optimization (e.g., "optimized", "failed", "skipped")
            additional_info: Any additional information to store about the function

        """
        if additional_info is None:
            additional_info = {}

        function_data = {
            "type": "function",
            "function_name": function_fully_qualified_name,
            "status": status,
            "timestamp": time.time(),
            **additional_info,
        }

        with open(self.checkpoint_path, "a") as f:
            f.write(json.dumps(function_data) + "\n")

        # Update the metadata last_updated timestamp
        self._update_metadata_timestamp()

    def _update_metadata_timestamp(self) -> None:
        """Update the last_updated timestamp in the metadata (first line of the file)."""
        # NOTE(review): this rewrites the whole file on every processed function,
        # i.e. O(file size) per update — acceptable for typical run sizes.
        with self.checkpoint_path.open() as f:
            metadata = json.loads(f.readline())
            rest_content = f.read()

        # Update the timestamp
        metadata["last_updated"] = time.time()

        with self.checkpoint_path.open("w") as f:
            f.write(json.dumps(metadata) + "\n")
            f.write(rest_content)

    def cleanup(self) -> None:
        """Unlink all the checkpoint files for this module_root.

        Empty or corrupt checkpoint files (e.g. left by a killed concurrent run)
        are skipped instead of crashing cleanup.
        """
        to_delete = []
        self.checkpoint_path.unlink(missing_ok=True)

        for file in self.checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"):
            try:
                with file.open() as f:
                    # The first line is the metadata record.
                    metadata = json.loads(next(f))
            except (OSError, StopIteration, json.JSONDecodeError):
                # Empty, unreadable, or malformed file — leave it alone rather
                # than abort cleanup of the remaining checkpoints.
                continue
            # Files whose metadata lacks a module_root default to matching ours
            # (the .get fallback) and are removed as well.
            if metadata.get("module_root", str(self.module_root)) == str(self.module_root):
                to_delete.append(file)
        for file in to_delete:
            file.unlink(missing_ok=True)
98+
99+
100+
def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dict[str, dict[str, str]]:
    """Get information about all processed functions, regardless of status.

    Scans ``checkpoint_dir`` for codeflash_checkpoint_*.jsonl files. As a side
    effect, checkpoint files whose metadata is at least 7 days old are deleted
    (whichever module_root they belong to). Empty or corrupt files are skipped
    instead of crashing checkpoint recovery.

    Args:
        module_root: Only checkpoints whose metadata matches this root are read.
        checkpoint_dir: Directory scanned for checkpoint files.

    Returns:
        Dictionary mapping function names to their processing information

    """
    processed_functions: dict[str, dict[str, str]] = {}
    to_delete = []

    for file in checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"):
        try:
            with file.open() as f:
                # The first line is the metadata record.
                metadata = json.loads(next(f))
                if metadata.get("last_updated"):
                    last_updated = datetime.datetime.fromtimestamp(metadata["last_updated"])
                    # Garbage-collect stale checkpoints (>= 7 days old).
                    if datetime.datetime.now() - last_updated >= datetime.timedelta(days=7):
                        to_delete.append(file)
                        continue
                if metadata.get("module_root") != str(module_root):
                    continue

                for line in f:
                    entry = json.loads(line)
                    if entry.get("type") == "function":
                        processed_functions[entry["function_name"]] = entry
        except (OSError, StopIteration, json.JSONDecodeError):
            # Empty, unreadable, or malformed file (possibly left by a killed
            # concurrent run) — skip it rather than abort recovery.
            continue
    for file in to_delete:
        file.unlink(missing_ok=True)
    return processed_functions
130+
131+
132+
def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> Optional[dict[str, dict[str, str]]]:
    """Offer to resume a --all run from a previously saved checkpoint.

    Checkpoints are only looked up on Linux/macOS where /tmp is available.

    Args:
        args: Parsed CLI arguments; reads ``args.all`` and ``args.module_root``.

    Returns:
        The mapping of previously processed functions to skip, or None when no
        usable checkpoint exists or the user declines to resume.
    """
    previous_checkpoint_functions = None
    if args.all and sys.platform in ("linux", "darwin") and Path("/tmp").is_dir():
        previous_checkpoint_functions = get_all_historical_functions(args.module_root, Path("/tmp"))
    # Short-circuit keeps the prompt from appearing when nothing was found;
    # an empty result is normalized to None.
    if not (
        previous_checkpoint_functions
        and click.confirm(
            "Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?",
            default=True,
        )
    ):
        previous_checkpoint_functions = None
    return previous_checkpoint_functions

codeflash/discovery/functions_to_optimize.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections import defaultdict
99
from functools import cache
1010
from pathlib import Path
11-
from typing import TYPE_CHECKING, Optional
11+
from typing import TYPE_CHECKING, Any, Optional
1212

1313
import git
1414
import libcst as cst
@@ -145,6 +145,7 @@ def qualified_name(self) -> str:
145145
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
146146
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
147147

148+
148149
def get_functions_to_optimize(
149150
optimize_all: str | None,
150151
replay_test: str | None,
@@ -154,10 +155,11 @@ def get_functions_to_optimize(
154155
ignore_paths: list[Path],
155156
project_root: Path,
156157
module_root: Path,
158+
previous_checkpoint_functions: dict[str, dict[str, str]] | None = None,
157159
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
158-
assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
159-
"Only one of optimize_all, replay_test, or file should be provided"
160-
)
160+
assert (
161+
sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1
162+
), "Only one of optimize_all, replay_test, or file should be provided"
161163
functions: dict[str, list[FunctionToOptimize]]
162164
with warnings.catch_warnings():
163165
warnings.simplefilter(action="ignore", category=SyntaxWarning)
@@ -198,7 +200,7 @@ def get_functions_to_optimize(
198200
ph("cli-optimizing-git-diff")
199201
functions = get_functions_within_git_diff()
200202
filtered_modified_functions, functions_count = filter_functions(
201-
functions, test_cfg.tests_root, ignore_paths, project_root, module_root
203+
functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
202204
)
203205
logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
204206
return filtered_modified_functions, functions_count
@@ -414,6 +416,7 @@ def filter_functions(
414416
ignore_paths: list[Path],
415417
project_root: Path,
416418
module_root: Path,
419+
previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
417420
disable_logs: bool = False,
418421
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
419422
blocklist_funcs = get_blocklisted_functions()
@@ -430,13 +433,16 @@ def filter_functions(
430433
ignore_paths_removed_count: int = 0
431434
malformed_paths_count: int = 0
432435
submodule_ignored_paths_count: int = 0
436+
blocklist_funcs_removed_count: int = 0
437+
previous_checkpoint_functions_removed_count: int = 0
433438
tests_root_str = str(tests_root)
434439
module_root_str = str(module_root)
435440
# We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
436441
for file_path_path, functions in modified_functions.items():
442+
_functions = functions
437443
file_path = str(file_path_path)
438444
if file_path.startswith(tests_root_str + os.sep):
439-
test_functions_removed_count += len(functions)
445+
test_functions_removed_count += len(_functions)
440446
continue
441447
if file_path in ignore_paths or any(
442448
file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths
@@ -449,27 +455,39 @@ def filter_functions(
449455
submodule_ignored_paths_count += 1
450456
continue
451457
if path_belongs_to_site_packages(Path(file_path)):
452-
site_packages_removed_count += len(functions)
458+
site_packages_removed_count += len(_functions)
453459
continue
454460
if not file_path.startswith(module_root_str + os.sep):
455-
non_modules_removed_count += len(functions)
461+
non_modules_removed_count += len(_functions)
456462
continue
457463
try:
458464
ast.parse(f"import {module_name_from_file_path(Path(file_path), project_root)}")
459465
except SyntaxError:
460466
malformed_paths_count += 1
461467
continue
462468
if blocklist_funcs:
463-
functions = [
464-
function
465-
for function in functions
469+
functions_tmp = []
470+
for function in _functions:
466471
if not (
467472
function.file_path.name in blocklist_funcs
468473
and function.qualified_name in blocklist_funcs[function.file_path.name]
469-
)
470-
]
471-
filtered_modified_functions[file_path] = functions
472-
functions_count += len(functions)
474+
):
475+
blocklist_funcs_removed_count += 1
476+
continue
477+
functions_tmp.append(function)
478+
_functions = functions_tmp
479+
480+
if previous_checkpoint_functions:
481+
functions_tmp = []
482+
for function in _functions:
483+
if function.qualified_name_with_modules_from_root(project_root) in previous_checkpoint_functions:
484+
previous_checkpoint_functions_removed_count += 1
485+
continue
486+
functions_tmp.append(function)
487+
_functions = functions_tmp
488+
489+
filtered_modified_functions[file_path] = _functions
490+
functions_count += len(_functions)
473491

474492
if not disable_logs:
475493
log_info = {
@@ -479,6 +497,8 @@ def filter_functions(
479497
f"{non_modules_removed_count} function{'s' if non_modules_removed_count != 1 else ''} outside module-root": non_modules_removed_count,
480498
f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count,
481499
f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count,
500+
f"{blocklist_funcs_removed_count} function{'s' if blocklist_funcs_removed_count != 1 else ''} as previously optimized": blocklist_funcs_removed_count,
501+
f"{previous_checkpoint_functions_removed_count} function{'s' if previous_checkpoint_functions_removed_count != 1 else ''} skipped from checkpoint": previous_checkpoint_functions_removed_count,
482502
}
483503
log_string = "\n".join([k for k, v in log_info.items() if v > 0])
484504
if log_string:

codeflash/optimization/function_optimizer.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,9 @@ def optimize_function(self) -> Result[BestOptimization, str]:
242242
# request for new optimizations but don't block execution, check for completion later
243243
# adding to control and experiment set but with same traceid
244244
best_optimization = None
245-
for _u, (candidates, exp_type) in enumerate(zip([optimizations_set.control, optimizations_set.experiment],["EXP0","EXP1"])):
245+
for _u, (candidates, exp_type) in enumerate(
246+
zip([optimizations_set.control, optimizations_set.experiment], ["EXP0", "EXP1"])
247+
):
246248
if candidates is None:
247249
continue
248250

@@ -254,7 +256,14 @@ def optimize_function(self) -> Result[BestOptimization, str]:
254256
file_path_to_helper_classes=file_path_to_helper_classes,
255257
exp_type=exp_type,
256258
)
257-
ph("cli-optimize-function-finished", {"function_trace_id": self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id})
259+
ph(
260+
"cli-optimize-function-finished",
261+
{
262+
"function_trace_id": self.function_trace_id[:-4] + exp_type
263+
if self.experiment_id
264+
else self.function_trace_id
265+
},
266+
)
258267

259268
generated_tests = remove_functions_from_generated_tests(
260269
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
@@ -324,7 +333,9 @@ def optimize_function(self) -> Result[BestOptimization, str]:
324333
explanation=explanation,
325334
existing_tests_source=existing_tests,
326335
generated_original_test_source=generated_tests_str,
327-
function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
336+
function_trace_id=self.function_trace_id[:-4] + exp_type
337+
if self.experiment_id
338+
else self.function_trace_id,
328339
coverage_message=coverage_message,
329340
git_remote=self.args.git_remote,
330341
)
@@ -379,15 +390,19 @@ def determine_best_candidate(
379390
# Start a new thread for AI service request, start loop in main thread
380391
# check if aiservice request is complete, when it is complete, append result to the candidates list
381392
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
382-
ai_service_client = self.aiservice_client if exp_type=="EXP0" else self.local_aiservice_client
393+
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
383394
future_line_profile_results = executor.submit(
384395
ai_service_client.optimize_python_code_line_profiler,
385396
source_code=code_context.read_writable_code,
386397
dependency_code=code_context.read_only_context_code,
387398
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
388399
line_profiler_results=original_code_baseline.line_profile_results["str_out"],
389400
num_candidates=10,
390-
experiment_metadata=ExperimentMetadata(id=self.experiment_id, group= "control" if exp_type == "EXP0" else "experiment") if self.experiment_id else None,
401+
experiment_metadata=ExperimentMetadata(
402+
id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment"
403+
)
404+
if self.experiment_id
405+
else None,
391406
)
392407
try:
393408
candidate_index = 0
@@ -462,7 +477,7 @@ def determine_best_candidate(
462477
f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
463478
)
464479
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
465-
tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
480+
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
466481
replay_perf_gain = {}
467482
if self.args.benchmark:
468483
test_results_by_benchmark = (
@@ -528,7 +543,9 @@ def determine_best_candidate(
528543
)
529544
return best_optimization
530545

531-
def log_successful_optimization(self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str) -> None:
546+
def log_successful_optimization(
547+
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
548+
) -> None:
532549
explanation_panel = Panel(
533550
f"⚡️ Optimization successful! 📄 {self.function_to_optimize.qualified_name} in {explanation.file_path}\n"
534551
f"📈 {explanation.perf_improvement_line}\n"
@@ -555,7 +572,9 @@ def log_successful_optimization(self, explanation: Explanation, generated_tests:
555572
ph(
556573
"cli-optimize-success",
557574
{
558-
"function_trace_id": self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
575+
"function_trace_id": self.function_trace_id[:-4] + exp_type
576+
if self.experiment_id
577+
else self.function_trace_id,
559578
"speedup_x": explanation.speedup_x,
560579
"speedup_pct": explanation.speedup_pct,
561580
"best_runtime": explanation.best_runtime_ns,

0 commit comments

Comments
 (0)