Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions codeflash/code_utils/checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import argparse
import datetime
import json
import sys
import time
import uuid
from pathlib import Path
from typing import Any, Optional

import click


class CodeflashRunCheckpoint:
    """Persist per-function optimization progress to a JSONL checkpoint file.

    File layout: the first line is a metadata record (module_root plus
    created_at/last_updated timestamps); every subsequent line is one
    processed-function record. An interrupted run can later be resumed by
    replaying this file (see get_all_historical_functions).
    """

    def __init__(self, module_root: Path, checkpoint_dir: Path = Path("/tmp")) -> None:
        self.module_root = module_root
        self.checkpoint_dir = Path(checkpoint_dir)
        # Unique suffix so concurrent runs do not clobber each other's files.
        unique_id = str(uuid.uuid4())[:8]
        checkpoint_filename = f"codeflash_checkpoint_{unique_id}.jsonl"
        self.checkpoint_path = self.checkpoint_dir / checkpoint_filename

        # Initialize the checkpoint file with metadata
        self._initialize_checkpoint_file()

    def _initialize_checkpoint_file(self) -> None:
        """Create a new checkpoint file with metadata."""
        metadata = {
            "type": "metadata",
            "module_root": str(self.module_root),
            "created_at": time.time(),
            "last_updated": time.time(),
        }

        with self.checkpoint_path.open("w", encoding="utf-8") as f:
            f.write(json.dumps(metadata) + "\n")

    def add_function_to_checkpoint(
        self,
        function_fully_qualified_name: str,
        status: str = "optimized",
        additional_info: Optional[dict[str, Any]] = None,
    ) -> None:
        """Add a function to the checkpoint after it has been processed.

        Args:
            function_fully_qualified_name: The fully qualified name of the function
            status: Status of optimization (e.g., "optimized", "failed", "skipped")
            additional_info: Any additional information to store about the function

        """
        if additional_info is None:
            additional_info = {}

        function_data = {
            "type": "function",
            "function_name": function_fully_qualified_name,
            "status": status,
            "timestamp": time.time(),
            **additional_info,
        }

        with self.checkpoint_path.open("a", encoding="utf-8") as f:
            f.write(json.dumps(function_data) + "\n")

        # Update the metadata last_updated timestamp
        self._update_metadata_timestamp()

    def _update_metadata_timestamp(self) -> None:
        """Update the last_updated timestamp in the metadata (first line)."""
        # Read the metadata line and keep the remainder verbatim so function
        # records are never altered by the rewrite.
        with self.checkpoint_path.open(encoding="utf-8") as f:
            metadata = json.loads(f.readline())
            rest_content = f.read()

        metadata["last_updated"] = time.time()

        # Rewrite the file: fresh metadata line followed by the untouched rest.
        with self.checkpoint_path.open("w", encoding="utf-8") as f:
            f.write(json.dumps(metadata) + "\n")
            f.write(rest_content)

    def cleanup(self) -> None:
        """Unlink all the checkpoint files for this module_root."""
        to_delete = []
        self.checkpoint_path.unlink(missing_ok=True)

        for file in self.checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"):
            try:
                with file.open(encoding="utf-8") as f:
                    first_line = f.readline()
                if not first_line:
                    # Empty/truncated checkpoint file: treat as stale garbage.
                    to_delete.append(file)
                    continue
                metadata = json.loads(first_line)
            except (OSError, json.JSONDecodeError):
                # Unreadable or corrupt file must not abort cleanup of the rest.
                continue
            # Files lacking a module_root are assumed to belong to this run.
            if metadata.get("module_root", str(self.module_root)) == str(self.module_root):
                to_delete.append(file)
        for file in to_delete:
            file.unlink(missing_ok=True)


def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dict[str, dict[str, str]]:
    """Get information about all processed functions, regardless of status.

    Scans every checkpoint file in checkpoint_dir; files whose metadata is
    older than 7 days (or that are empty/truncated) are deleted as stale,
    files for other module_roots are ignored, and corrupt files are skipped
    instead of aborting the whole scan.

    Returns:
        Dictionary mapping function names to their processing information

    """
    processed_functions: dict[str, dict[str, str]] = {}
    to_delete: list[Path] = []

    for file in checkpoint_dir.glob("codeflash_checkpoint_*.jsonl"):
        try:
            with file.open(encoding="utf-8") as f:
                # First line is the metadata record.
                first_line = f.readline()
                if not first_line:
                    # Empty/truncated checkpoint file: remove it.
                    to_delete.append(file)
                    continue
                metadata = json.loads(first_line)
                if metadata.get("last_updated"):
                    last_updated = datetime.datetime.fromtimestamp(metadata["last_updated"])
                    # Expire checkpoints older than a week regardless of module.
                    if datetime.datetime.now() - last_updated >= datetime.timedelta(days=7):
                        to_delete.append(file)
                        continue
                if metadata.get("module_root") != str(module_root):
                    continue

                for line in f:
                    entry = json.loads(line)
                    if entry.get("type") == "function":
                        processed_functions[entry["function_name"]] = entry
        except (OSError, json.JSONDecodeError):
            # Unreadable or corrupt checkpoint: skip it rather than crash.
            continue
    for file in to_delete:
        file.unlink(missing_ok=True)
    return processed_functions


def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> Optional[dict[str, dict[str, str]]]:
    """Offer to resume from a previous checkpoint when optimizing with --all.

    Checkpoints are only supported on Linux/macOS where /tmp exists. Returns
    the previously processed functions mapping when a checkpoint is found and
    the user confirms resuming; otherwise returns None.
    """
    # Guard clause: checkpoints apply only to --all runs on POSIX with /tmp.
    if not (args.all and sys.platform in ("linux", "darwin") and Path("/tmp").is_dir()):
        return None
    previous_checkpoint_functions = get_all_historical_functions(args.module_root, Path("/tmp"))
    if previous_checkpoint_functions and click.confirm(
        "Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?",
        default=True,
    ):
        return previous_checkpoint_functions
    return None
50 changes: 35 additions & 15 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections import defaultdict
from functools import cache
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional

import git
import libcst as cst
Expand Down Expand Up @@ -145,6 +145,7 @@ def qualified_name(self) -> str:
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"


def get_functions_to_optimize(
optimize_all: str | None,
replay_test: str | None,
Expand All @@ -154,10 +155,11 @@ def get_functions_to_optimize(
ignore_paths: list[Path],
project_root: Path,
module_root: Path,
previous_checkpoint_functions: dict[str, dict[str, str]] | None = None,
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
assert sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1, (
"Only one of optimize_all, replay_test, or file should be provided"
)
assert (
sum([bool(optimize_all), bool(replay_test), bool(file)]) <= 1
), "Only one of optimize_all, replay_test, or file should be provided"
functions: dict[str, list[FunctionToOptimize]]
with warnings.catch_warnings():
warnings.simplefilter(action="ignore", category=SyntaxWarning)
Expand Down Expand Up @@ -198,7 +200,7 @@ def get_functions_to_optimize(
ph("cli-optimizing-git-diff")
functions = get_functions_within_git_diff()
filtered_modified_functions, functions_count = filter_functions(
functions, test_cfg.tests_root, ignore_paths, project_root, module_root
functions, test_cfg.tests_root, ignore_paths, project_root, module_root, previous_checkpoint_functions
)
logger.info(f"Found {functions_count} function{'s' if functions_count > 1 else ''} to optimize")
return filtered_modified_functions, functions_count
Expand Down Expand Up @@ -414,6 +416,7 @@ def filter_functions(
ignore_paths: list[Path],
project_root: Path,
module_root: Path,
previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
disable_logs: bool = False,
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
blocklist_funcs = get_blocklisted_functions()
Expand All @@ -430,13 +433,16 @@ def filter_functions(
ignore_paths_removed_count: int = 0
malformed_paths_count: int = 0
submodule_ignored_paths_count: int = 0
blocklist_funcs_removed_count: int = 0
previous_checkpoint_functions_removed_count: int = 0
tests_root_str = str(tests_root)
module_root_str = str(module_root)
# We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
for file_path_path, functions in modified_functions.items():
_functions = functions
file_path = str(file_path_path)
if file_path.startswith(tests_root_str + os.sep):
test_functions_removed_count += len(functions)
test_functions_removed_count += len(_functions)
continue
if file_path in ignore_paths or any(
file_path.startswith(str(ignore_path) + os.sep) for ignore_path in ignore_paths
Expand All @@ -449,27 +455,39 @@ def filter_functions(
submodule_ignored_paths_count += 1
continue
if path_belongs_to_site_packages(Path(file_path)):
site_packages_removed_count += len(functions)
site_packages_removed_count += len(_functions)
continue
if not file_path.startswith(module_root_str + os.sep):
non_modules_removed_count += len(functions)
non_modules_removed_count += len(_functions)
continue
try:
ast.parse(f"import {module_name_from_file_path(Path(file_path), project_root)}")
except SyntaxError:
malformed_paths_count += 1
continue
if blocklist_funcs:
functions = [
function
for function in functions
functions_tmp = []
for function in _functions:
if not (
function.file_path.name in blocklist_funcs
and function.qualified_name in blocklist_funcs[function.file_path.name]
)
]
filtered_modified_functions[file_path] = functions
functions_count += len(functions)
):
blocklist_funcs_removed_count += 1
continue
functions_tmp.append(function)
_functions = functions_tmp

if previous_checkpoint_functions:
functions_tmp = []
for function in _functions:
if function.qualified_name_with_modules_from_root(project_root) in previous_checkpoint_functions:
previous_checkpoint_functions_removed_count += 1
continue
functions_tmp.append(function)
_functions = functions_tmp

filtered_modified_functions[file_path] = _functions
functions_count += len(_functions)

if not disable_logs:
log_info = {
Expand All @@ -479,6 +497,8 @@ def filter_functions(
f"{non_modules_removed_count} function{'s' if non_modules_removed_count != 1 else ''} outside module-root": non_modules_removed_count,
f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count,
f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count,
f"{blocklist_funcs_removed_count} function{'s' if blocklist_funcs_removed_count != 1 else ''} as previously optimized": blocklist_funcs_removed_count,
f"{previous_checkpoint_functions_removed_count} function{'s' if previous_checkpoint_functions_removed_count != 1 else ''} skipped from checkpoint": previous_checkpoint_functions_removed_count,
}
log_string = "\n".join([k for k, v in log_info.items() if v > 0])
if log_string:
Expand Down
35 changes: 27 additions & 8 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,9 @@ def optimize_function(self) -> Result[BestOptimization, str]:
# request for new optimizations but don't block execution, check for completion later
# adding to control and experiment set but with same traceid
best_optimization = None
for _u, (candidates, exp_type) in enumerate(zip([optimizations_set.control, optimizations_set.experiment],["EXP0","EXP1"])):
for _u, (candidates, exp_type) in enumerate(
zip([optimizations_set.control, optimizations_set.experiment], ["EXP0", "EXP1"])
):
if candidates is None:
continue

Expand All @@ -254,7 +256,14 @@ def optimize_function(self) -> Result[BestOptimization, str]:
file_path_to_helper_classes=file_path_to_helper_classes,
exp_type=exp_type,
)
ph("cli-optimize-function-finished", {"function_trace_id": self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id})
ph(
"cli-optimize-function-finished",
{
"function_trace_id": self.function_trace_id[:-4] + exp_type
if self.experiment_id
else self.function_trace_id
},
)

generated_tests = remove_functions_from_generated_tests(
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
Expand Down Expand Up @@ -324,7 +333,9 @@ def optimize_function(self) -> Result[BestOptimization, str]:
explanation=explanation,
existing_tests_source=existing_tests,
generated_original_test_source=generated_tests_str,
function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
function_trace_id=self.function_trace_id[:-4] + exp_type
if self.experiment_id
else self.function_trace_id,
coverage_message=coverage_message,
git_remote=self.args.git_remote,
)
Expand Down Expand Up @@ -379,15 +390,19 @@ def determine_best_candidate(
# Start a new thread for AI service request, start loop in main thread
# check if aiservice request is complete, when it is complete, append result to the candidates list
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
ai_service_client = self.aiservice_client if exp_type=="EXP0" else self.local_aiservice_client
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
future_line_profile_results = executor.submit(
ai_service_client.optimize_python_code_line_profiler,
source_code=code_context.read_writable_code,
dependency_code=code_context.read_only_context_code,
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
line_profiler_results=original_code_baseline.line_profile_results["str_out"],
num_candidates=10,
experiment_metadata=ExperimentMetadata(id=self.experiment_id, group= "control" if exp_type == "EXP0" else "experiment") if self.experiment_id else None,
experiment_metadata=ExperimentMetadata(
id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment"
)
if self.experiment_id
else None,
)
try:
candidate_index = 0
Expand Down Expand Up @@ -462,7 +477,7 @@ def determine_best_candidate(
f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
)
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
replay_perf_gain = {}
if self.args.benchmark:
test_results_by_benchmark = (
Expand Down Expand Up @@ -528,7 +543,9 @@ def determine_best_candidate(
)
return best_optimization

def log_successful_optimization(self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str) -> None:
def log_successful_optimization(
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
) -> None:
explanation_panel = Panel(
f"⚡️ Optimization successful! 📄 {self.function_to_optimize.qualified_name} in {explanation.file_path}\n"
f"📈 {explanation.perf_improvement_line}\n"
Expand All @@ -555,7 +572,9 @@ def log_successful_optimization(self, explanation: Explanation, generated_tests:
ph(
"cli-optimize-success",
{
"function_trace_id": self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
"function_trace_id": self.function_trace_id[:-4] + exp_type
if self.experiment_id
else self.function_trace_id,
"speedup_x": explanation.speedup_x,
"speedup_pct": explanation.speedup_pct,
"best_runtime": explanation.best_runtime_ns,
Expand Down
Loading
Loading