Skip to content

Commit c3f1b1f

Browse files
create function optimizer during the optimization initialization step & remove optimization metadata logic
1 parent e973c69 commit c3f1b1f

File tree

4 files changed: +101 additions, −182 deletions

codeflash/code_utils/git_worktree_utils.py

Lines changed: 5 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
import json
43
import subprocess
54
import tempfile
65
import time
@@ -9,15 +8,12 @@
98
from typing import TYPE_CHECKING, Optional
109

1110
import git
12-
from filelock import FileLock
1311

1412
from codeflash.cli_cmds.console import logger
1513
from codeflash.code_utils.compat import codeflash_cache_dir
1614
from codeflash.code_utils.git_utils import check_running_in_git_repo, git_root_dir
1715

1816
if TYPE_CHECKING:
19-
from typing import Any
20-
2117
from git import Repo
2218

2319

@@ -100,71 +96,24 @@ def get_patches_dir_for_project() -> Path:
10096
return Path(patches_dir / project_id)
10197

10298

103-
def get_patches_metadata() -> dict[str, Any]:
104-
project_patches_dir = get_patches_dir_for_project()
105-
meta_file = project_patches_dir / "metadata.json"
106-
if meta_file.exists():
107-
with meta_file.open("r", encoding="utf-8") as f:
108-
return json.load(f)
109-
return {"id": get_git_project_id() or "", "patches": []}
110-
111-
112-
def save_patches_metadata(patch_metadata: dict) -> dict:
113-
project_patches_dir = get_patches_dir_for_project()
114-
meta_file = project_patches_dir / "metadata.json"
115-
lock_file = project_patches_dir / "metadata.json.lock"
116-
117-
# we are not supporting multiple concurrent optimizations within the same process, but keep that in case we decide to do so in the future.
118-
with FileLock(lock_file, timeout=10):
119-
metadata = get_patches_metadata()
120-
121-
patch_metadata["id"] = time.strftime("%Y%m%d-%H%M%S")
122-
metadata["patches"].append(patch_metadata)
123-
124-
meta_file.write_text(json.dumps(metadata, indent=2))
125-
126-
return patch_metadata
127-
128-
129-
def overwrite_patch_metadata(patches: list[dict]) -> bool:
130-
project_patches_dir = get_patches_dir_for_project()
131-
meta_file = project_patches_dir / "metadata.json"
132-
lock_file = project_patches_dir / "metadata.json.lock"
133-
134-
with FileLock(lock_file, timeout=10):
135-
metadata = get_patches_metadata()
136-
metadata["patches"] = patches
137-
meta_file.write_text(json.dumps(metadata, indent=2))
138-
return True
139-
140-
14199
def create_diff_patch_from_worktree(
142-
worktree_dir: Path,
143-
files: list[str],
144-
fto_name: Optional[str] = None,
145-
metadata_input: Optional[dict[str, Any]] = None,
146-
) -> dict[str, Any]:
100+
worktree_dir: Path, files: list[str], fto_name: Optional[str] = None
101+
) -> Optional[Path]:
147102
repository = git.Repo(worktree_dir, search_parent_directories=True)
148103
uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)
149104

150105
if not uni_diff_text:
151106
logger.warning("No changes found in worktree.")
152-
return {}
107+
return None
153108

154109
if not uni_diff_text.endswith("\n"):
155110
uni_diff_text += "\n"
156111

157112
project_patches_dir = get_patches_dir_for_project()
158113
project_patches_dir.mkdir(parents=True, exist_ok=True)
159114

160-
final_function_name = fto_name or metadata_input.get("fto_name", "unknown")
161-
patch_path = project_patches_dir / f"{worktree_dir.name}.{final_function_name}.patch"
115+
patch_path = project_patches_dir / f"{worktree_dir.name}.{fto_name}.patch"
162116
with patch_path.open("w", encoding="utf8") as f:
163117
f.write(uni_diff_text)
164118

165-
final_metadata = {"patch_path": str(patch_path)}
166-
if metadata_input:
167-
final_metadata.update(metadata_input)
168-
final_metadata = save_patches_metadata(final_metadata)
169-
170-
return final_metadata
119+
return patch_path

codeflash/lsp/beta.py

Lines changed: 91 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,8 @@
1212
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
1313
from codeflash.cli_cmds.cli import process_pyproject_config
1414
from codeflash.cli_cmds.console import code_print
15-
from codeflash.code_utils.git_worktree_utils import (
16-
create_diff_patch_from_worktree,
17-
get_patches_metadata,
18-
overwrite_patch_metadata,
19-
)
15+
from codeflash.code_utils.git_utils import git_root_dir
16+
from codeflash.code_utils.git_worktree_utils import create_diff_patch_from_worktree
2017
from codeflash.code_utils.shell_utils import save_api_key_to_rc
2118
from codeflash.discovery.functions_to_optimize import (
2219
filter_functions,
@@ -39,10 +36,17 @@ class OptimizableFunctionsParams:
3936
textDocument: types.TextDocumentIdentifier # noqa: N815
4037

4138

39+
@dataclass
40+
class FunctionOptimizationInitParams:
41+
textDocument: types.TextDocumentIdentifier # noqa: N815
42+
functionName: str # noqa: N815
43+
44+
4245
@dataclass
4346
class FunctionOptimizationParams:
4447
textDocument: types.TextDocumentIdentifier # noqa: N815
4548
functionName: str # noqa: N815
49+
task_id: str
4650

4751

4852
@dataclass
@@ -59,7 +63,7 @@ class ValidateProjectParams:
5963

6064
@dataclass
6165
class OnPatchAppliedParams:
62-
patch_id: str
66+
task_id: str
6367

6468

6569
@dataclass
@@ -132,42 +136,6 @@ def get_optimizable_functions(
132136
return path_to_qualified_names
133137

134138

135-
@server.feature("initializeFunctionOptimization")
136-
def initialize_function_optimization(
137-
server: CodeflashLanguageServer, params: FunctionOptimizationParams
138-
) -> dict[str, str]:
139-
file_path = Path(uris.to_fs_path(params.textDocument.uri))
140-
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info")
141-
142-
if server.optimizer is None:
143-
_initialize_optimizer_if_api_key_is_valid(server)
144-
145-
server.optimizer.worktree_mode()
146-
147-
original_args, _ = server.optimizer.original_args_and_test_cfg
148-
149-
server.optimizer.args.function = params.functionName
150-
original_relative_file_path = file_path.relative_to(original_args.project_root)
151-
server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
152-
server.optimizer.args.previous_checkpoint_functions = False
153-
154-
server.show_message_log(
155-
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
156-
)
157-
158-
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
159-
160-
if count == 0:
161-
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
162-
server.cleanup_the_optimizer()
163-
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
164-
165-
fto = optimizable_funcs.popitem()[1][0]
166-
server.optimizer.current_function_being_optimized = fto
167-
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
168-
return {"functionName": params.functionName, "status": "success"}
169-
170-
171139
def _find_pyproject_toml(workspace_path: str) -> Path | None:
172140
workspace_path_obj = Path(workspace_path)
173141
max_depth = 2
@@ -207,13 +175,18 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
207175
if pyproject_toml_path:
208176
server.prepare_optimizer_arguments(pyproject_toml_path)
209177
else:
210-
return {
211-
"status": "error",
212-
"message": "No pyproject.toml found in workspace.",
213-
} # TODO: enhancec this message to say there is not tool.codeflash in pyproject.toml or smth
178+
return {"status": "error", "message": "No pyproject.toml found in workspace."}
179+
180+
# since we are using worktrees, optimization diffs are generated with respect to the root of the repo, also the args.project_root is set to the root of the repo when creating a worktree
181+
root = str(git_root_dir())
214182

215183
if getattr(params, "skip_validation", False):
216-
return {"status": "success", "moduleRoot": server.args.module_root, "pyprojectPath": pyproject_toml_path}
184+
return {
185+
"status": "success",
186+
"moduleRoot": server.args.module_root,
187+
"pyprojectPath": pyproject_toml_path,
188+
"root": root,
189+
}
217190

218191
server.show_message_log("Validating project...", "Info")
219192
config = is_valid_pyproject_toml(pyproject_toml_path)
@@ -234,7 +207,7 @@ def init_project(server: CodeflashLanguageServer, params: ValidateProjectParams)
234207
except Exception:
235208
return {"status": "error", "message": "Repository has no commits (unborn HEAD)"}
236209

237-
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path}
210+
return {"status": "success", "moduleRoot": args.module_root, "pyprojectPath": pyproject_toml_path, "root": root}
238211

239212

240213
def _initialize_optimizer_if_api_key_is_valid(
@@ -296,78 +269,85 @@ def provide_api_key(server: CodeflashLanguageServer, params: ProvideApiKeyParams
296269
return {"status": "error", "message": "something went wrong while saving the api key"}
297270

298271

299-
@server.feature("retrieveSuccessfulOptimizations")
300-
def retrieve_successful_optimizations(_server: CodeflashLanguageServer, _params: any) -> dict[str, str]:
301-
metadata = get_patches_metadata()
302-
return {"status": "success", "patches": metadata["patches"]}
272+
@server.feature("initializeFunctionOptimization")
273+
def initialize_function_optimization(
274+
server: CodeflashLanguageServer, params: FunctionOptimizationInitParams
275+
) -> dict[str, str]:
276+
file_path = Path(uris.to_fs_path(params.textDocument.uri))
277+
server.show_message_log(f"Initializing optimization for function: {params.functionName} in {file_path}", "Info")
278+
279+
if server.optimizer is None:
280+
_initialize_optimizer_if_api_key_is_valid(server)
281+
282+
server.optimizer.worktree_mode()
303283

284+
original_args, _ = server.optimizer.original_args_and_test_cfg
304285

305-
@server.feature("onPatchApplied")
306-
def on_patch_applied(_server: CodeflashLanguageServer, params: OnPatchAppliedParams) -> dict[str, str]:
307-
# first remove the patch from the metadata
308-
metadata = get_patches_metadata()
286+
server.optimizer.args.function = params.functionName
287+
original_relative_file_path = file_path.relative_to(original_args.project_root)
288+
server.optimizer.args.file = server.optimizer.current_worktree / original_relative_file_path
289+
server.optimizer.args.previous_checkpoint_functions = False
309290

310-
deleted_patch_file = None
311-
new_patches = []
312-
for patch in metadata["patches"]:
313-
if patch["id"] == params.patch_id:
314-
deleted_patch_file = patch["patch_path"]
315-
continue
316-
new_patches.append(patch)
291+
server.show_message_log(
292+
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
293+
)
317294

318-
# then remove the patch file
319-
if deleted_patch_file:
320-
overwrite_patch_metadata(new_patches)
321-
patch_path = Path(deleted_patch_file)
322-
patch_path.unlink(missing_ok=True)
323-
return {"status": "success"}
324-
return {"status": "error", "message": "Patch not found"}
295+
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
325296

297+
if count == 0:
298+
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
299+
server.cleanup_the_optimizer()
300+
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
326301

327-
@server.feature("performFunctionOptimization")
328-
@server.thread()
329-
def perform_function_optimization( # noqa: PLR0911
330-
server: CodeflashLanguageServer, params: FunctionOptimizationParams
331-
) -> dict[str, str]:
332-
try:
333-
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
334-
current_function = server.optimizer.current_function_being_optimized
302+
fto = optimizable_funcs.popitem()[1][0]
335303

336-
if not current_function:
337-
server.show_message_log(f"No current function being optimized for {params.functionName}", "Error")
338-
return {
339-
"functionName": params.functionName,
340-
"status": "error",
341-
"message": "No function currently being optimized",
342-
}
304+
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
305+
if not module_prep_result:
306+
return {
307+
"functionName": params.functionName,
308+
"status": "error",
309+
"message": "Failed to prepare module for optimization",
310+
}
343311

344-
module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)
345-
if not module_prep_result:
346-
return {
347-
"functionName": params.functionName,
348-
"status": "error",
349-
"message": "Failed to prepare module for optimization",
350-
}
312+
validated_original_code, original_module_ast = module_prep_result
351313

352-
validated_original_code, original_module_ast = module_prep_result
314+
function_optimizer = server.optimizer.create_function_optimizer(
315+
fto,
316+
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
317+
original_module_ast=original_module_ast,
318+
original_module_path=fto.file_path,
319+
function_to_tests={},
320+
)
353321

354-
function_optimizer = server.optimizer.create_function_optimizer(
355-
current_function,
356-
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
357-
original_module_ast=original_module_ast,
358-
original_module_path=current_function.file_path,
359-
function_to_tests={},
360-
)
322+
server.optimizer.current_function_optimizer = function_optimizer
323+
if not function_optimizer:
324+
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
325+
326+
initialization_result = function_optimizer.can_be_optimized()
327+
if not is_successful(initialization_result):
328+
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
329+
330+
server.current_optimization_init_result = initialization_result.unwrap()
331+
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
332+
333+
files = [function_optimizer.function_to_optimize.file_path]
334+
335+
_, _, original_helpers = server.current_optimization_init_result
336+
files.extend([str(helper_path) for helper_path in original_helpers])
361337

362-
server.optimizer.current_function_optimizer = function_optimizer
363-
if not function_optimizer:
364-
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
338+
return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
365339

366-
initialization_result = function_optimizer.can_be_optimized()
367-
if not is_successful(initialization_result):
368-
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
369340

370-
should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()
341+
@server.feature("performFunctionOptimization")
342+
@server.thread()
343+
def perform_function_optimization(
344+
server: CodeflashLanguageServer, params: FunctionOptimizationParams
345+
) -> dict[str, str]:
346+
try:
347+
server.show_message_log(f"Starting optimization for function: {params.functionName}", "Info")
348+
should_run_experiment, code_context, original_helper_code = server.current_optimization_init_result
349+
function_optimizer = server.optimizer.current_function_optimizer
350+
current_function = function_optimizer.function_to_optimize
371351

372352
code_print(
373353
code_context.read_writable_code.flat,
@@ -447,20 +427,8 @@ def perform_function_optimization( # noqa: PLR0911
447427

448428
speedup = original_code_baseline.runtime / best_optimization.runtime
449429

450-
# get the original file path in the actual project (not in the worktree)
451-
original_args, _ = server.optimizer.original_args_and_test_cfg
452-
relative_file_path = current_function.file_path.relative_to(server.optimizer.current_worktree)
453-
original_file_path = Path(original_args.project_root / relative_file_path).resolve()
454-
455-
metadata = create_diff_patch_from_worktree(
456-
server.optimizer.current_worktree,
457-
relative_file_paths,
458-
metadata_input={
459-
"fto_name": function_to_optimize_qualified_name,
460-
"explanation": best_optimization.explanation_v2,
461-
"file_path": str(original_file_path),
462-
"speedup": speedup,
463-
},
430+
patch_path = create_diff_patch_from_worktree(
431+
server.optimizer.current_worktree, relative_file_paths, function_to_optimize_qualified_name
464432
)
465433

466434
server.show_message_log(f"Optimization completed for {params.functionName} with {speedup:.2f}x speedup", "Info")
@@ -470,8 +438,8 @@ def perform_function_optimization( # noqa: PLR0911
470438
"status": "success",
471439
"message": "Optimization completed successfully",
472440
"extra": f"Speedup: {speedup:.2f}x faster",
473-
"patch_file": metadata["patch_path"],
474-
"patch_id": metadata["id"],
441+
"patch_file": patch_path,
442+
"task_id": params.task_id,
475443
"explanation": best_optimization.explanation_v2,
476444
}
477445
finally:

0 commit comments

Comments (0)