Skip to content

Commit d2461d7

Browse files
Merge pull request #873 from codeflash-ai/lsp/demo-optimization
[LSP] Demo optimization (find_common_tags)
2 parents 1f6cf3f + f810152 commit d2461d7

File tree

7 files changed

+164
-68
lines changed

7 files changed

+164
-68
lines changed

codeflash/api/cfapi.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import json
44
import os
5+
from dataclasses import dataclass
56
from functools import lru_cache
67
from pathlib import Path
78
from typing import TYPE_CHECKING, Any, Optional
@@ -26,14 +27,24 @@
2627

2728
from packaging import version
2829

29-
if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() == "local":
30-
CFAPI_BASE_URL = "http://localhost:3001"
31-
CFWEBAPP_BASE_URL = "http://localhost:3000"
32-
logger.info(f"Using local CF API at {CFAPI_BASE_URL}.")
33-
console.rule()
34-
else:
35-
CFAPI_BASE_URL = "https://app.codeflash.ai"
36-
CFWEBAPP_BASE_URL = "https://app.codeflash.ai"
30+
31+
@dataclass
32+
class BaseUrls:
33+
cfapi_base_url: Optional[str] = None
34+
cfwebapp_base_url: Optional[str] = None
35+
36+
37+
@lru_cache(maxsize=1)
38+
def get_cfapi_base_urls() -> BaseUrls:
39+
if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() == "local":
40+
cfapi_base_url = "http://localhost:3001"
41+
cfwebapp_base_url = "http://localhost:3000"
42+
logger.info(f"Using local CF API at {cfapi_base_url}.")
43+
console.rule()
44+
else:
45+
cfapi_base_url = "https://app.codeflash.ai"
46+
cfwebapp_base_url = "https://app.codeflash.ai"
47+
return BaseUrls(cfapi_base_url=cfapi_base_url, cfwebapp_base_url=cfwebapp_base_url)
3748

3849

3950
def make_cfapi_request(
@@ -53,8 +64,9 @@ def make_cfapi_request(
5364
:param suppress_errors: If True, suppress error logging for HTTP errors.
5465
:return: The response object from the API.
5566
"""
56-
url = f"{CFAPI_BASE_URL}/cfapi{endpoint}"
57-
cfapi_headers = {"Authorization": f"Bearer {api_key or get_codeflash_api_key()}"}
67+
url = f"{get_cfapi_base_urls().cfapi_base_url}/cfapi{endpoint}"
68+
final_api_key = api_key or get_codeflash_api_key()
69+
cfapi_headers = {"Authorization": f"Bearer {final_api_key}"}
5870
if extra_headers:
5971
cfapi_headers.update(extra_headers)
6072
try:

codeflash/cli_cmds/cmd_init.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
167167
console.rule()
168168

169169
if run_tests:
170-
bubble_sort_path, bubble_sort_test_path = create_bubble_sort_file_and_test(args)
171-
run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)
170+
file_path = create_find_common_tags_file(args, "find_common_tags.py")
171+
run_end_to_end_test(args, file_path)
172172

173173

174174
def is_valid_pyproject_toml(pyproject_toml_path: Path) -> tuple[bool, dict[str, Any] | None, str]: # noqa: PLR0911
@@ -1207,6 +1207,35 @@ def enter_api_key_and_save_to_rc() -> None:
12071207
os.environ["CODEFLASH_API_KEY"] = api_key
12081208

12091209

1210+
def create_find_common_tags_file(args: Namespace, file_name: str) -> Path:
1211+
find_common_tags_content = """def find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:
1212+
if not articles:
1213+
return set()
1214+
1215+
common_tags = articles[0]["tags"]
1216+
for article in articles[1:]:
1217+
common_tags = [tag for tag in common_tags if tag in article["tags"]]
1218+
return set(common_tags)
1219+
"""
1220+
1221+
file_path = Path(args.module_root) / file_name
1222+
lsp_enabled = is_LSP_enabled()
1223+
if file_path.exists() and not lsp_enabled:
1224+
from rich.prompt import Confirm
1225+
1226+
overwrite = Confirm.ask(
1227+
f"🤔 {file_path} already exists. Do you want to overwrite it?", default=True, show_default=False
1228+
)
1229+
if not overwrite:
1230+
apologize_and_exit()
1231+
console.rule()
1232+
1233+
file_path.write_text(find_common_tags_content, encoding="utf8")
1234+
logger.info(f"Created demo optimization file: {file_path}")
1235+
1236+
return file_path
1237+
1238+
12101239
def create_bubble_sort_file_and_test(args: Namespace) -> tuple[str, str]:
12111240
bubble_sort_content = """from typing import Union, List
12121241
def sorter(arr: Union[List[int],List[float]]) -> Union[List[int],List[float]]:
@@ -1276,7 +1305,7 @@ def test_sort():
12761305
return str(bubble_sort_path), str(bubble_sort_test_path)
12771306

12781307

1279-
def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test_path: str) -> None:
1308+
def run_end_to_end_test(args: Namespace, find_common_tags_path: Path) -> None:
12801309
try:
12811310
check_formatter_installed(args.formatter_cmds)
12821311
except Exception:
@@ -1285,7 +1314,7 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test
12851314
)
12861315
return
12871316

1288-
command = ["codeflash", "--file", "bubble_sort.py", "--function", "sorter"]
1317+
command = ["codeflash", "--file", "find_common_tags.py", "--function", "find_common_tags"]
12891318
if args.no_pr:
12901319
command.append("--no-pr")
12911320
if args.verbose:
@@ -1316,10 +1345,8 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test
13161345
console.rule()
13171346
# Delete the bubble_sort.py file after the test
13181347
logger.info("🧹 Cleaning up…")
1319-
for path in [bubble_sort_path, bubble_sort_test_path]:
1320-
console.rule()
1321-
Path(path).unlink(missing_ok=True)
1322-
logger.info(f"🗑️ Deleted {path}")
1348+
find_common_tags_path.unlink(missing_ok=True)
1349+
logger.info(f"🗑️ Deleted {find_common_tags_path}")
13231350

13241351

13251352
def ask_for_telemetry() -> bool:

codeflash/code_utils/git_worktree_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def remove_worktree(worktree_dir: Path) -> None:
8181

8282

8383
def create_diff_patch_from_worktree(
84-
worktree_dir: Path, files: list[str], fto_name: Optional[str] = None
84+
worktree_dir: Path, files: list[Path], fto_name: Optional[str] = None
8585
) -> Optional[Path]:
8686
repository = git.Repo(worktree_dir, search_parent_directories=True)
8787
uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)

codeflash/lsp/beta.py

Lines changed: 93 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import threading
88
from dataclasses import dataclass
99
from pathlib import Path
10-
from typing import TYPE_CHECKING, Optional
10+
from typing import TYPE_CHECKING, Optional, Union
1111

1212
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
1313
from codeflash.cli_cmds.cli import process_pyproject_config
@@ -16,12 +16,14 @@
1616
VsCodeSetupInfo,
1717
configure_pyproject_toml,
1818
create_empty_pyproject_toml,
19+
create_find_common_tags_file,
1920
get_formatter_cmds,
2021
get_suggestions,
2122
get_valid_subdirs,
2223
is_valid_pyproject_toml,
2324
)
2425
from codeflash.code_utils.git_utils import git_root_dir
26+
from codeflash.code_utils.git_worktree_utils import create_worktree_snapshot_commit
2527
from codeflash.code_utils.shell_utils import save_api_key_to_rc
2628
from codeflash.discovery.functions_to_optimize import (
2729
filter_functions,
@@ -39,6 +41,7 @@
3941
from lsprotocol import types
4042

4143
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
44+
from codeflash.lsp.server import WrappedInitializationResultT
4245

4346

4447
@dataclass
@@ -55,11 +58,15 @@ class FunctionOptimizationInitParams:
5558

5659
@dataclass
5760
class FunctionOptimizationParams:
58-
textDocument: types.TextDocumentIdentifier # noqa: N815
5961
functionName: str # noqa: N815
6062
task_id: str
6163

6264

65+
@dataclass
66+
class DemoOptimizationParams:
67+
functionName: str # noqa: N815
68+
69+
6370
@dataclass
6471
class ProvideApiKeyParams:
6572
api_key: str
@@ -257,10 +264,8 @@ def init_project(params: ValidateProjectParams) -> dict[str, str]:
257264

258265
def _initialize_optimizer_if_api_key_is_valid(api_key: Optional[str] = None) -> dict[str, str]:
259266
key_check_result = _check_api_key_validity(api_key)
260-
if key_check_result.get("status") != "success":
261-
return key_check_result
262-
263-
_init()
267+
if key_check_result.get("status") == "success":
268+
_init()
264269
return key_check_result
265270

266271

@@ -303,8 +308,8 @@ def _init() -> Namespace:
303308
def check_api_key(_params: any) -> dict[str, str]:
304309
try:
305310
return _initialize_optimizer_if_api_key_is_valid()
306-
except Exception:
307-
return {"status": "error", "message": "something went wrong while validating the api key"}
311+
except Exception as ex:
312+
return {"status": "error", "message": "something went wrong while validating the api key " + str(ex)}
308313

309314

310315
@server.feature("provideApiKey")
@@ -353,6 +358,56 @@ def cleanup_optimizer(_params: any) -> dict[str, str]:
353358
return {"status": "success"}
354359

355360

361+
def _initialize_current_function_optimizer() -> Union[dict[str, str], WrappedInitializationResultT]:
362+
"""Initialize the current function optimizer.
363+
364+
Returns:
365+
Union[dict[str, str], WrappedInitializationResultT]:
366+
error dict with status error,
367+
or a wrapped initializationresult if the optimizer is initialized.
368+
369+
"""
370+
if not server.optimizer:
371+
return {"status": "error", "message": "Optimizer not initialized yet."}
372+
373+
function_name = server.optimizer.args.function
374+
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
375+
376+
if count == 0:
377+
server.show_message_log(f"No optimizable functions found for {function_name}", "Warning")
378+
server.cleanup_the_optimizer()
379+
return {"functionName": function_name, "status": "error", "message": "not found", "args": None}
380+
381+
fto = optimizable_funcs.popitem()[1][0]
382+
383+
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
384+
if not module_prep_result:
385+
return {
386+
"functionName": function_name,
387+
"status": "error",
388+
"message": "Failed to prepare module for optimization",
389+
}
390+
391+
validated_original_code, original_module_ast = module_prep_result
392+
393+
function_optimizer = server.optimizer.create_function_optimizer(
394+
fto,
395+
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
396+
original_module_ast=original_module_ast,
397+
original_module_path=fto.file_path,
398+
function_to_tests={},
399+
)
400+
401+
server.optimizer.current_function_optimizer = function_optimizer
402+
if not function_optimizer:
403+
return {"functionName": function_name, "status": "error", "message": "No function optimizer found"}
404+
405+
initialization_result = function_optimizer.can_be_optimized()
406+
if not is_successful(initialization_result):
407+
return {"functionName": function_name, "status": "error", "message": initialization_result.failure()}
408+
return initialization_result
409+
410+
356411
@server.feature("initializeFunctionOptimization")
357412
def initialize_function_optimization(params: FunctionOptimizationInitParams) -> dict[str, str]:
358413
with execution_context(task_id=getattr(params, "task_id", None)):
@@ -377,52 +432,47 @@ def initialize_function_optimization(params: FunctionOptimizationInitParams) ->
377432
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
378433
)
379434

380-
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
381-
382-
if count == 0:
383-
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
384-
server.cleanup_the_optimizer()
385-
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
386-
387-
fto = optimizable_funcs.popitem()[1][0]
388-
389-
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
390-
if not module_prep_result:
391-
return {
392-
"functionName": params.functionName,
393-
"status": "error",
394-
"message": "Failed to prepare module for optimization",
395-
}
396-
397-
validated_original_code, original_module_ast = module_prep_result
398-
399-
function_optimizer = server.optimizer.create_function_optimizer(
400-
fto,
401-
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
402-
original_module_ast=original_module_ast,
403-
original_module_path=fto.file_path,
404-
function_to_tests={},
405-
)
406-
407-
server.optimizer.current_function_optimizer = function_optimizer
408-
if not function_optimizer:
409-
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
410-
411-
initialization_result = function_optimizer.can_be_optimized()
412-
if not is_successful(initialization_result):
413-
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
435+
initialization_result = _initialize_current_function_optimizer()
436+
if isinstance(initialization_result, dict):
437+
return initialization_result
414438

415439
server.current_optimization_init_result = initialization_result.unwrap()
416440
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
417441

418-
files = [function_optimizer.function_to_optimize.file_path]
442+
files = [document.path]
419443

420444
_, _, original_helpers = server.current_optimization_init_result
421445
files.extend([str(helper_path) for helper_path in original_helpers])
422446

423447
return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
424448

425449

450+
@server.feature("startDemoOptimization")
451+
async def start_demo_optimization(params: DemoOptimizationParams) -> dict[str, str]:
452+
try:
453+
_init()
454+
# start by creating the worktree so that the demo file is not created in user workspace
455+
server.optimizer.worktree_mode()
456+
file_path = create_find_common_tags_file(server.args, params.functionName + ".py")
457+
# commit the new file for diff generation later
458+
create_worktree_snapshot_commit(server.optimizer.current_worktree, "added sample optimization file")
459+
460+
server.optimizer.args.file = file_path
461+
server.optimizer.args.function = params.functionName
462+
server.optimizer.args.previous_checkpoint_functions = False
463+
464+
initialization_result = _initialize_current_function_optimizer()
465+
if isinstance(initialization_result, dict):
466+
return initialization_result
467+
468+
server.current_optimization_init_result = initialization_result.unwrap()
469+
return await perform_function_optimization(
470+
FunctionOptimizationParams(functionName=params.functionName, task_id=None)
471+
)
472+
finally:
473+
server.cleanup_the_optimizer()
474+
475+
426476
@server.feature("performFunctionOptimization")
427477
async def perform_function_optimization(params: FunctionOptimizationParams) -> dict[str, str]:
428478
with execution_context(task_id=getattr(params, "task_id", None)):

codeflash/lsp/features/perform_optimization.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, cancel_event: thr
5959
generated_perf_test_paths,
6060
instrumented_unittests_created_for_function,
6161
original_conftest_content,
62+
function_references,
6263
) = test_setup_result.unwrap()
6364

6465
baseline_setup_result = function_optimizer.setup_and_establish_baseline(
@@ -94,6 +95,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, cancel_event: thr
9495
generated_tests=generated_tests,
9596
test_functions_to_remove=test_functions_to_remove,
9697
concolic_test_str=concolic_test_str,
98+
function_references=function_references,
9799
)
98100

99101
abort_if_cancelled(cancel_event)

codeflash/lsp/server.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,27 @@
11
from __future__ import annotations
22

3+
from pathlib import Path
34
from typing import TYPE_CHECKING
45

56
from lsprotocol.types import LogMessageParams, MessageType
67
from pygls.lsp.server import LanguageServer
78
from pygls.protocol import LanguageServerProtocol
89

9-
if TYPE_CHECKING:
10-
from pathlib import Path
10+
from codeflash.either import Result
11+
from codeflash.models.models import CodeOptimizationContext
1112

12-
from codeflash.models.models import CodeOptimizationContext
13+
if TYPE_CHECKING:
1314
from codeflash.optimization.optimizer import Optimizer
1415

1516

1617
class CodeflashLanguageServerProtocol(LanguageServerProtocol):
1718
_server: CodeflashLanguageServer
1819

1920

21+
InitializationResultT = tuple[bool, CodeOptimizationContext, dict[Path, str]]
22+
WrappedInitializationResultT = Result[InitializationResultT, str]
23+
24+
2025
class CodeflashLanguageServer(LanguageServer):
2126
def __init__(self, name: str, version: str, protocol_cls: type[LanguageServerProtocol]) -> None:
2227
super().__init__(name, version, protocol_cls=protocol_cls)

0 commit comments

Comments
 (0)