
Commit 1971ef4 ("first attempt")

Parent: 9639294

3 files changed: +210, -30 lines

3 files changed

+210
-30
lines changed

codeflash/code_utils/checkpoint.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
import argparse
import datetime
import json
import sys
import time
import uuid
from pathlib import Path
from typing import Any, Optional

import click


class CodeflashRunCheckpoint:
    def __init__(self, module_path: Path, checkpoint_dir: str = "/tmp") -> None:
        self.module_path = module_path
        self.checkpoint_dir = Path(checkpoint_dir)
        # Create a unique checkpoint file name
        unique_id = str(uuid.uuid4())[:8]
        checkpoint_filename = f"codeflash_checkpoint_{unique_id}.jsonl"
        self.checkpoint_path = self.checkpoint_dir / checkpoint_filename

        # Initialize the checkpoint file with metadata
        self._initialize_checkpoint_file()

    def _initialize_checkpoint_file(self) -> None:
        """Create a new checkpoint file with metadata."""
        metadata = {
            "type": "metadata",
            "module_path": str(self.module_path),
            "created_at": time.time(),
            "last_updated": time.time(),
        }

        with open(self.checkpoint_path, "w") as f:
            f.write(json.dumps(metadata) + "\n")

    def add_function_to_checkpoint(
        self,
        function_fully_qualified_name: str,
        status: str = "optimized",
        additional_info: Optional[dict[str, Any]] = None,
    ) -> None:
        """Add a function to the checkpoint after it has been processed.

        Args:
            function_fully_qualified_name: The fully qualified name of the function
            status: Status of optimization (e.g., "optimized", "failed", "skipped")
            additional_info: Any additional information to store about the function

        """
        if additional_info is None:
            additional_info = {}

        function_data = {
            "type": "function",
            "function_name": function_fully_qualified_name,
            "status": status,
            "timestamp": time.time(),
            **additional_info,
        }

        with open(self.checkpoint_path, "a") as f:
            f.write(json.dumps(function_data) + "\n")

        # Update the metadata last_updated timestamp
        self._update_metadata_timestamp()

    def _update_metadata_timestamp(self) -> None:
        """Update the last_updated timestamp in the metadata."""
        # Read the first line (metadata) and keep the remaining entries
        with self.checkpoint_path.open() as f:
            metadata = json.loads(f.readline())
            rest_content = f.read()

        # Update the timestamp
        metadata["last_updated"] = time.time()

        # Rewrite the file: updated metadata first, then the remaining entries
        with self.checkpoint_path.open("w") as f:
            f.write(json.dumps(metadata) + "\n")
            f.write(rest_content)


def get_all_historical_functions(checkpoint_dir: Path, module_path) -> dict[str, dict[str, str]]:
    """Get information about all processed functions, regardless of status.

    Returns:
        Dictionary mapping function names to their processing information

    """
    processed_functions = {}
    to_delete = []

    for file in checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"):
        with file.open() as f:
            # The first line is the metadata record
            first_line = next(f)
            metadata = json.loads(first_line)
            if metadata.get("timestamp"):
                metadata["timestamp"] = datetime.datetime.fromtimestamp(metadata["timestamp"])
                if metadata["timestamp"] >= datetime.datetime.now() - datetime.timedelta(days=7):
                    to_delete.append(file)
                    continue
            else:
                metadata["timestamp"] = datetime.datetime.now()
            if metadata.get("module_path") != module_path:
                continue

            for line in f:
                entry = json.loads(line)
                if entry.get("type") == "function":
                    processed_functions[entry["function_name"]] = entry
    for file in to_delete:
        file.unlink()
    return processed_functions


def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> Optional[dict[str, dict[str, str]]]:
    previous_checkpoint_functions = None
    if args.all and (sys.platform == "linux" or sys.platform == "darwin") and Path("/tmp").is_dir():
        # Argument order follows the signature: checkpoint_dir first, then module_path
        previous_checkpoint_functions = get_all_historical_functions(Path("/tmp"), args.module_path)
        if previous_checkpoint_functions and click.confirm(
            "Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?",
            default=True,
        ):
            pass
        else:
            previous_checkpoint_functions = None
    return previous_checkpoint_functions
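
For orientation, here is a minimal usage sketch of the checkpoint API added above. The module root and function names are invented for illustration; the read call follows the get_all_historical_functions(checkpoint_dir, module_path) signature defined in this file.

from pathlib import Path

from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint, get_all_historical_functions

# Writing: one metadata line, then one JSONL record per processed function.
checkpoint = CodeflashRunCheckpoint(module_path=Path("my_project"))  # invented module root
checkpoint.add_function_to_checkpoint("my_project.utils.parse_config", status="optimized")
checkpoint.add_function_to_checkpoint("my_project.utils.load_data", status="failed")

# Reading on a later run: collect every recorded function for the same module path.
previous = get_all_historical_functions(Path("/tmp"), "my_project")
print(sorted(previous))
# ['my_project.utils.load_data', 'my_project.utils.parse_config']
# (assuming no other checkpoint files exist for this module path)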

codeflash/discovery/functions_to_optimize.py

Lines changed: 36 additions & 14 deletions
@@ -145,6 +145,7 @@ def qualified_name(self) -> str:
     def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
         return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"
 
+
 def get_functions_to_optimize(
     optimize_all: str | None,
     replay_test: str | None,
@@ -154,10 +155,11 @@ def get_functions_to_optimize(
     ignore_paths: list[Path],
     project_root: Path,
     module_root: Path,
+    previous_checkpoint_functions: dict[str, dict[str, str]] | None = None,
 ) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
-    assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
-        "Only one of optimize_all, replay_test, or file should be provided"
-    )
+    assert (
+        sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1
+    ), "Only one of optimize_all, replay_test, or file should be provided"
     functions: dict[str, list[FunctionToOptimize]]
     with warnings.catch_warnings():
         warnings.simplefilter(action="ignore", category=SyntaxWarning)
@@ -198,7 +200,7 @@ def get_functions_to_optimize(
             ph("cli-optimizing-git-diff")
             functions = get_functions_within_git_diff()
         filtered_modified_functions, functions_count = filter_functions(
-            functions, test_cfg.tests_root, ignore_paths, project_root, module_root
+            functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
         )
         logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
         return filtered_modified_functions, functions_count
@@ -414,6 +416,7 @@ def filter_functions(
     ignore_paths: list[Path],
     project_root: Path,
     module_root: Path,
+    previous_checkpoint_functions: dict[Path, list[FunctionToOptimize]] | None = None,
     disable_logs: bool = False,
 ) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
     blocklist_funcs = get_blocklisted_functions()
@@ -430,13 +433,16 @@ def filter_functions(
     ignore_paths_removed_count: int = 0
     malformed_paths_count: int = 0
     submodule_ignored_paths_count: int = 0
+    blocklist_funcs_removed_count: int = 0
+    previous_checkpoint_functions_removed_count: int = 0
     tests_root_str = str(tests_root)
     module_root_str = str(module_root)
     # We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
     for file_path_path, functions in modified_functions.items():
+        _functions = functions
         file_path = str(file_path_path)
         if file_path.startswith(tests_root_str + os.sep):
-            test_functions_removed_count += len(functions)
+            test_functions_removed_count += len(_functions)
             continue
         if file_path in ignore_paths or any(
             file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths
@@ -449,27 +455,41 @@ def filter_functions(
             submodule_ignored_paths_count += 1
             continue
         if path_belongs_to_site_packages(Path(file_path)):
-            site_packages_removed_count += len(functions)
+            site_packages_removed_count += len(_functions)
             continue
         if not file_path.startswith(module_root_str + os.sep):
-            non_modules_removed_count += len(functions)
+            non_modules_removed_count += len(_functions)
             continue
         try:
             ast.parse(f"import {module_name_from_file_path(Path(file_path), project_root)}")
         except SyntaxError:
             malformed_paths_count += 1
             continue
         if blocklist_funcs:
-            functions = [
-                function
-                for function in functions
+            functions_tmp = []
+            for function in _functions:
                 if not (
                     function.file_path.name in blocklist_funcs
                     and function.qualified_name in blocklist_funcs[function.file_path.name]
-                )
-            ]
-        filtered_modified_functions[file_path] = functions
-        functions_count += len(functions)
+                ):
+                    blocklist_funcs_removed_count += 1
+                    continue
+                functions_tmp.append(function)
+            _functions = functions_tmp
+
+        if previous_checkpoint_functions:
+            functions_tmp = []
+            for function in _functions:
+                if function.file_path in previous_checkpoint_functions and function in previous_checkpoint_functions[
+                    function.file_path
+                ]:
+                    previous_checkpoint_functions_removed_count += 1
+                    continue
+                functions_tmp.append(function)
+            _functions = functions_tmp
+
+        filtered_modified_functions[file_path] = _functions
+        functions_count += len(_functions)
 
     if not disable_logs:
         log_info = {
@@ -479,6 +499,8 @@ def filter_functions(
             f"{non_modules_removed_count} function{'s' if non_modules_removed_count != 1 else ''} outside module-root": non_modules_removed_count,
             f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count,
             f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count,
+            f"{blocklist_funcs_removed_count} function{'s' if blocklist_funcs_removed_count != 1 else ''} as previously optimized": blocklist_funcs_removed_count,
+            f"{previous_checkpoint_functions_removed_count} function{'s' if previous_checkpoint_functions_removed_count != 1 else ''} as previously optimized from checkpoint": previous_checkpoint_functions_removed_count,
         }
         log_string = "\n".join([k for k, v in log_info.items() if v > 0])
         if log_string:
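
The hunk above replaces the blocklist list comprehension with an explicit filter-and-count loop and adds a second pass for functions already recorded in a checkpoint. For readers skimming the diff, here is a standalone sketch of that filter-and-count pattern; the function, names, and candidate list are invented for illustration and are not part of the codebase.

from collections import defaultdict

def filter_with_counts(functions: list[str], already_done: set[str]) -> tuple[list[str], dict[str, int]]:
    # Walk each candidate, count what gets dropped per reason, keep the rest.
    removed: dict[str, int] = defaultdict(int)
    kept: list[str] = []
    for fn in functions:
        if fn in already_done:
            removed["previously optimized from checkpoint"] += 1
            continue
        kept.append(fn)
    return kept, dict(removed)

kept, removed = filter_with_counts(
    ["pkg.mod.foo", "pkg.mod.bar", "pkg.mod.baz"], already_done={"pkg.mod.bar"}
)
print(kept)     # ['pkg.mod.foo', 'pkg.mod.baz']
print(removed)  # {'previously optimized from checkpoint': 1}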

codeflash/optimization/optimizer.py

Lines changed: 44 additions & 16 deletions
@@ -17,6 +17,7 @@
 from codeflash.benchmarking.utils import print_benchmark_table, validate_and_format_benchmark_table
 from codeflash.cli_cmds.console import console, logger, progress_bar
 from codeflash.code_utils import env_utils
+from codeflash.code_utils.checkpoint import CodeflashRunCheckpoint, ask_should_use_checkpoint_get_functions
 from codeflash.code_utils.code_replacer import normalize_code, normalize_node
 from codeflash.code_utils.code_utils import get_run_tmp_file
 from codeflash.code_utils.static_analysis import analyze_imported_modules, get_first_top_level_function_or_method_ast
@@ -52,6 +53,8 @@ def __init__(self, args: Namespace) -> None:
         self.experiment_id = os.getenv("CODEFLASH_EXPERIMENT_ID", None)
         self.local_aiservice_client = LocalAiServiceClient() if self.experiment_id else None
         self.replay_tests_dir = None
+        self.functions_checkpoint: CodeflashRunCheckpoint | None = None
+
     def create_function_optimizer(
         self,
         function_to_optimize: FunctionToOptimize,
@@ -71,7 +74,7 @@ def create_function_optimizer(
             args=self.args,
             function_benchmark_timings=function_benchmark_timings if function_benchmark_timings else None,
             total_benchmark_timings=total_benchmark_timings if total_benchmark_timings else None,
-            replay_tests_dir = self.replay_tests_dir
+            replay_tests_dir=self.replay_tests_dir,
         )
 
     def run(self) -> None:
@@ -83,7 +86,7 @@ def run(self) -> None:
         function_optimizer = None
         file_to_funcs_to_optimize: dict[Path, list[FunctionToOptimize]]
         num_optimizable_functions: int
-
+        previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(self.args)
         # discover functions
         (file_to_funcs_to_optimize, num_optimizable_functions) = get_functions_to_optimize(
             optimize_all=self.args.all,
@@ -94,14 +97,12 @@ def run(self) -> None:
             ignore_paths=self.args.ignore_paths,
             project_root=self.args.project_root,
             module_root=self.args.module_root,
+            previous_checkpoint_functions=previous_checkpoint_functions,
         )
         function_benchmark_timings: dict[str, dict[BenchmarkKey, int]] = {}
         total_benchmark_timings: dict[BenchmarkKey, int] = {}
         if self.args.benchmark and num_optimizable_functions > 0:
-            with progress_bar(
-                f"Running benchmarks in {self.args.benchmarks_root}",
-                transient=True,
-            ):
+            with progress_bar(f"Running benchmarks in {self.args.benchmarks_root}", transient=True):
                 # Insert decorator
                 file_path_to_source_code = defaultdict(str)
                 for file in file_to_funcs_to_optimize:
@@ -113,15 +114,23 @@ def run(self) -> None:
                     if trace_file.exists():
                         trace_file.unlink()
 
-                    self.replay_tests_dir = Path(tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root))
-                    trace_benchmarks_pytest(self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file)  # Run all tests that use pytest-benchmark
+                    self.replay_tests_dir = Path(
+                        tempfile.mkdtemp(prefix="codeflash_replay_tests_", dir=self.args.benchmarks_root)
+                    )
+                    trace_benchmarks_pytest(
+                        self.args.benchmarks_root, self.args.tests_root, self.args.project_root, trace_file
+                    )  # Run all tests that use pytest-benchmark
                     replay_count = generate_replay_test(trace_file, self.replay_tests_dir)
                     if replay_count == 0:
-                        logger.info(f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization")
+                        logger.info(
+                            f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization"
+                        )
                     else:
                         function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
                         total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
-                        function_to_results = validate_and_format_benchmark_table(function_benchmark_timings, total_benchmark_timings)
+                        function_to_results = validate_and_format_benchmark_table(
+                            function_benchmark_timings, total_benchmark_timings
+                        )
                         print_benchmark_table(function_to_results)
                 except Exception as e:
                     logger.info(f"Error while tracing existing benchmarks: {e}")
@@ -148,10 +157,13 @@ def run(self) -> None:
             function_to_tests: dict[str, list[FunctionCalledInTest]] = discover_unit_tests(self.test_cfg)
             num_discovered_tests: int = sum([len(value) for value in function_to_tests.values()])
             console.rule()
-            logger.info(f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}")
+            logger.info(
+                f"Discovered {num_discovered_tests} existing unit tests in {(time.time() - start_time):.1f}s at {self.test_cfg.tests_root}"
+            )
             console.rule()
             ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests})
-
+            if self.args.all:
+                self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_path)
 
             for original_module_path in file_to_funcs_to_optimize:
                 logger.info(f"Examining file {original_module_path!s}…")
@@ -212,17 +224,33 @@ def run(self) -> None:
                     qualified_name_w_module = function_to_optimize.qualified_name_with_modules_from_root(
                         self.args.project_root
                     )
-                    if self.args.benchmark and function_benchmark_timings and qualified_name_w_module in function_benchmark_timings and total_benchmark_timings:
+                    if (
+                        self.args.benchmark
+                        and function_benchmark_timings
+                        and qualified_name_w_module in function_benchmark_timings
+                        and total_benchmark_timings
+                    ):
                         function_optimizer = self.create_function_optimizer(
-                            function_to_optimize, function_to_optimize_ast, function_to_tests, validated_original_code[original_module_path].source_code, function_benchmark_timings[qualified_name_w_module], total_benchmark_timings
+                            function_to_optimize,
+                            function_to_optimize_ast,
+                            function_to_tests,
+                            validated_original_code[original_module_path].source_code,
+                            function_benchmark_timings[qualified_name_w_module],
+                            total_benchmark_timings,
                         )
                     else:
                         function_optimizer = self.create_function_optimizer(
-                            function_to_optimize, function_to_optimize_ast, function_to_tests,
-                            validated_original_code[original_module_path].source_code
+                            function_to_optimize,
+                            function_to_optimize_ast,
+                            function_to_tests,
+                            validated_original_code[original_module_path].source_code,
                         )
 
                     best_optimization = function_optimizer.optimize_function()
+                    if self.functions_checkpoint:
+                        self.functions_checkpoint.add_function_to_checkpoint(
+                            function_to_optimize.qualified_name_with_modules_from_root(self.args.project_root)
+                        )
                     if is_successful(best_optimization):
                         optimizations_found += 1
                     else:
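
Taken together, the optimizer changes wire the checkpoint into the run loop: ask whether to resume on startup, exclude already-recorded functions during discovery, and append one record after each processed function. Below is a plain-Python mock of that flow; the file path, candidate names, and helper functions are invented stand-ins for illustration, not codeflash's API.

import json
import time
from pathlib import Path

CHECKPOINT = Path("/tmp/codeflash_checkpoint_demo.jsonl")  # illustrative path, not the real naming scheme

def load_already_done() -> set[str]:
    # Mirror of get_all_historical_functions: skip the metadata line, collect function names.
    if not CHECKPOINT.exists():
        return set()
    with CHECKPOINT.open() as f:
        next(f, None)
        return {json.loads(line)["function_name"] for line in f if line.strip()}

def mock_run(module_path: Path, resume: bool) -> None:
    already_done = load_already_done() if resume else set()

    candidates = ["pkg.mod.foo", "pkg.mod.bar", "pkg.mod.baz"]  # invented candidate functions
    to_optimize = [fn for fn in candidates if fn not in already_done]

    if not resume:
        CHECKPOINT.write_text(json.dumps({"type": "metadata", "module_path": str(module_path)}) + "\n")

    for fn in to_optimize:
        # ... the real code calls function_optimizer.optimize_function() here ...
        with CHECKPOINT.open("a") as f:
            f.write(json.dumps({"type": "function", "function_name": fn, "timestamp": time.time()}) + "\n")

mock_run(Path("pkg"), resume=False)  # first (interrupted) run writes the checkpoint
mock_run(Path("pkg"), resume=True)   # second run skips anything already recorded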
