Skip to content

Commit b60a670

Browse files
committed
working version
1 parent 1971ef4 commit b60a670

File tree

4 files changed

+42
-27
lines changed

4 files changed

+42
-27
lines changed

codeflash/code_utils/checkpoint.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111

1212

1313
class CodeflashRunCheckpoint:
14-
def __init__(self, module_path: Path, checkpoint_dir: str = "/tmp") -> None:
15-
self.module_path = module_path
14+
def __init__(self, module_root: Path, checkpoint_dir: str = "/tmp") -> None:
15+
self.module_root = module_root
1616
self.checkpoint_dir = Path(checkpoint_dir)
1717
# Create a unique checkpoint file name
1818
unique_id = str(uuid.uuid4())[:8]
@@ -26,7 +26,7 @@ def _initialize_checkpoint_file(self) -> None:
2626
"""Create a new checkpoint file with metadata."""
2727
metadata = {
2828
"type": "metadata",
29-
"module_path": str(self.module_path),
29+
"module_root": str(self.module_root),
3030
"created_at": time.time(),
3131
"last_updated": time.time(),
3232
}
@@ -82,7 +82,7 @@ def _update_metadata_timestamp(self) -> None:
8282
f.write(rest_content)
8383

8484

85-
def get_all_historical_functions(checkpoint_dir: Path, module_path) -> dict[str, dict[str, str]]:
85+
def get_all_historical_functions(module_root: Path, checkpoint_dir: Path) -> dict[str, dict[str, str]]:
8686
"""Get information about all processed functions, regardless of status.
8787
8888
Returns:
@@ -97,29 +97,27 @@ def get_all_historical_functions(checkpoint_dir: Path, module_path) -> dict[str,
9797
# Skip the first line (metadata)
9898
first_line = next(f)
9999
metadata = json.loads(first_line)
100-
if metadata.get("timestamp"):
101-
metadata["timestamp"] = datetime.datetime.fromtimestamp(metadata["timestamp"])
102-
if metadata["timestamp"] >= datetime.datetime.now() - datetime.timedelta(days=7):
100+
if metadata.get("last_updated"):
101+
last_updated = datetime.datetime.fromtimestamp(metadata["last_updated"])
102+
if datetime.datetime.now() - last_updated >= datetime.timedelta(days=7):
103103
to_delete.append(file)
104104
continue
105-
else:
106-
metadata["timestamp"] = datetime.datetime.now()
107-
if metadata.get("module_path") != module_path:
105+
if metadata.get("module_root") != str(module_root):
108106
continue
109107

110108
for line in f:
111109
entry = json.loads(line)
112110
if entry.get("type") == "function":
113111
processed_functions[entry["function_name"]] = entry
114112
for file in to_delete:
115-
file.unlink()
113+
file.unlink(missing_ok=True)
116114
return processed_functions
117115

118116

119117
def ask_should_use_checkpoint_get_functions(args: argparse.Namespace) -> Optional[dict[str, dict[str, str]]]:
120118
previous_checkpoint_functions = None
121119
if args.all and (sys.platform == "linux" or sys.platform == "darwin") and Path("/tmp").is_dir():
122-
previous_checkpoint_functions = get_all_historical_functions(args.module_path, Path("/tmp"))
120+
previous_checkpoint_functions = get_all_historical_functions(args.module_root, Path("/tmp"))
123121
if previous_checkpoint_functions and click.confirm(
124122
"Previous Checkpoint detected from an incomplete optimization run, shall I continue the optimization from that point?",
125123
default=True,

codeflash/discovery/functions_to_optimize.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections import defaultdict
99
from functools import cache
1010
from pathlib import Path
11-
from typing import TYPE_CHECKING, Optional
11+
from typing import TYPE_CHECKING, Any, Optional
1212

1313
import git
1414
import libcst as cst
@@ -416,7 +416,7 @@ def filter_functions(
416416
ignore_paths: list[Path],
417417
project_root: Path,
418418
module_root: Path,
419-
previous_checkpoint_functions: dict[Path, list[FunctionToOptimize]] | None = None,
419+
previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
420420
disable_logs: bool = False,
421421
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
422422
blocklist_funcs = get_blocklisted_functions()
@@ -480,9 +480,7 @@ def filter_functions(
480480
if previous_checkpoint_functions:
481481
functions_tmp = []
482482
for function in _functions:
483-
if function.file_path in previous_checkpoint_functions and function in previous_checkpoint_functions[
484-
function.file_path
485-
]:
483+
if function.qualified_name_with_modules_from_root(project_root) in previous_checkpoint_functions:
486484
previous_checkpoint_functions_removed_count += 1
487485
continue
488486
functions_tmp.append(function)
@@ -500,7 +498,7 @@ def filter_functions(
500498
f"{ignore_paths_removed_count} file{'s' if ignore_paths_removed_count != 1 else ''} from ignored paths": ignore_paths_removed_count,
501499
f"{submodule_ignored_paths_count} file{'s' if submodule_ignored_paths_count != 1 else ''} from ignored submodules": submodule_ignored_paths_count,
502500
f"{blocklist_funcs_removed_count} function{'s' if blocklist_funcs_removed_count != 1 else ''} as previously optimized": blocklist_funcs_removed_count,
503-
f"{previous_checkpoint_functions_removed_count} function{'s' if previous_checkpoint_functions_removed_count != 1 else ''} as previously optimized from checkpoint": previous_checkpoint_functions_removed_count,
501+
f"{previous_checkpoint_functions_removed_count} function{'s' if previous_checkpoint_functions_removed_count != 1 else ''} skipped from checkpoint": previous_checkpoint_functions_removed_count,
504502
}
505503
log_string = "\n".join([k for k, v in log_info.items() if v > 0])
506504
if log_string:

codeflash/optimization/function_optimizer.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,9 @@ def optimize_function(self) -> Result[BestOptimization, str]:
242242
# request for new optimizations but don't block execution, check for completion later
243243
# adding to control and experiment set but with same traceid
244244
best_optimization = None
245-
for _u, (candidates, exp_type) in enumerate(zip([optimizations_set.control, optimizations_set.experiment],["EXP0","EXP1"])):
245+
for _u, (candidates, exp_type) in enumerate(
246+
zip([optimizations_set.control, optimizations_set.experiment], ["EXP0", "EXP1"])
247+
):
246248
if candidates is None:
247249
continue
248250

@@ -254,7 +256,14 @@ def optimize_function(self) -> Result[BestOptimization, str]:
254256
file_path_to_helper_classes=file_path_to_helper_classes,
255257
exp_type=exp_type,
256258
)
257-
ph("cli-optimize-function-finished", {"function_trace_id": self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id})
259+
ph(
260+
"cli-optimize-function-finished",
261+
{
262+
"function_trace_id": self.function_trace_id[:-4] + exp_type
263+
if self.experiment_id
264+
else self.function_trace_id
265+
},
266+
)
258267

259268
generated_tests = remove_functions_from_generated_tests(
260269
generated_tests=generated_tests, test_functions_to_remove=test_functions_to_remove
@@ -324,7 +333,9 @@ def optimize_function(self) -> Result[BestOptimization, str]:
324333
explanation=explanation,
325334
existing_tests_source=existing_tests,
326335
generated_original_test_source=generated_tests_str,
327-
function_trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
336+
function_trace_id=self.function_trace_id[:-4] + exp_type
337+
if self.experiment_id
338+
else self.function_trace_id,
328339
coverage_message=coverage_message,
329340
git_remote=self.args.git_remote,
330341
)
@@ -379,15 +390,19 @@ def determine_best_candidate(
379390
# Start a new thread for AI service request, start loop in main thread
380391
# check if aiservice request is complete, when it is complete, append result to the candidates list
381392
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
382-
ai_service_client = self.aiservice_client if exp_type=="EXP0" else self.local_aiservice_client
393+
ai_service_client = self.aiservice_client if exp_type == "EXP0" else self.local_aiservice_client
383394
future_line_profile_results = executor.submit(
384395
ai_service_client.optimize_python_code_line_profiler,
385396
source_code=code_context.read_writable_code,
386397
dependency_code=code_context.read_only_context_code,
387398
trace_id=self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
388399
line_profiler_results=original_code_baseline.line_profile_results["str_out"],
389400
num_candidates=10,
390-
experiment_metadata=ExperimentMetadata(id=self.experiment_id, group= "control" if exp_type == "EXP0" else "experiment") if self.experiment_id else None,
401+
experiment_metadata=ExperimentMetadata(
402+
id=self.experiment_id, group="control" if exp_type == "EXP0" else "experiment"
403+
)
404+
if self.experiment_id
405+
else None,
391406
)
392407
try:
393408
candidate_index = 0
@@ -462,7 +477,7 @@ def determine_best_candidate(
462477
f"loop{'s' if candidate_result.max_loop_count > 1 else ''})"
463478
)
464479
tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%")
465-
tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X")
480+
tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X")
466481
replay_perf_gain = {}
467482
if self.args.benchmark:
468483
test_results_by_benchmark = (
@@ -528,7 +543,9 @@ def determine_best_candidate(
528543
)
529544
return best_optimization
530545

531-
def log_successful_optimization(self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str) -> None:
546+
def log_successful_optimization(
547+
self, explanation: Explanation, generated_tests: GeneratedTestsList, exp_type: str
548+
) -> None:
532549
explanation_panel = Panel(
533550
f"⚡️ Optimization successful! 📄 {self.function_to_optimize.qualified_name} in {explanation.file_path}\n"
534551
f"📈 {explanation.perf_improvement_line}\n"
@@ -555,7 +572,9 @@ def log_successful_optimization(self, explanation: Explanation, generated_tests:
555572
ph(
556573
"cli-optimize-success",
557574
{
558-
"function_trace_id": self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id,
575+
"function_trace_id": self.function_trace_id[:-4] + exp_type
576+
if self.experiment_id
577+
else self.function_trace_id,
559578
"speedup_x": explanation.speedup_x,
560579
"speedup_pct": explanation.speedup_pct,
561580
"best_runtime": explanation.best_runtime_ns,

codeflash/optimization/optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def run(self) -> None:
163163
console.rule()
164164
ph("cli-optimize-discovered-tests", {"num_tests": num_discovered_tests})
165165
if self.args.all:
166-
self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_path)
166+
self.functions_checkpoint = CodeflashRunCheckpoint(self.args.module_root)
167167

168168
for original_module_path in file_to_funcs_to_optimize:
169169
logger.info(f"Examining file {original_module_path!s}…")

0 commit comments

Comments
 (0)