Skip to content

Commit beef54e

Browse files
demo optimization with find_common_tags
1 parent b9d5f16 commit beef54e

File tree

6 files changed

+159
-66
lines changed

6 files changed

+159
-66
lines changed

codeflash/api/cfapi.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import json
44
import os
55
import sys
6+
from dataclasses import dataclass
67
from functools import lru_cache
78
from pathlib import Path
89
from typing import TYPE_CHECKING, Any, Optional
@@ -26,14 +27,24 @@
2627

2728
from packaging import version
2829

29-
if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() == "local":
30-
CFAPI_BASE_URL = "http://localhost:3001"
31-
CFWEBAPP_BASE_URL = "http://localhost:3000"
32-
logger.info(f"Using local CF API at {CFAPI_BASE_URL}.")
33-
console.rule()
34-
else:
35-
CFAPI_BASE_URL = "https://app.codeflash.ai"
36-
CFWEBAPP_BASE_URL = "https://app.codeflash.ai"
30+
31+
@dataclass
32+
class BaseUrls:
33+
cfapi_base_url: Optional[str] = None
34+
cfwebapp_base_url: Optional[str] = None
35+
36+
37+
@lru_cache(maxsize=1)
38+
def get_cfapi_base_urls() -> BaseUrls:
39+
if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() == "local":
40+
cfapi_base_url = "http://localhost:3001"
41+
cfwebapp_base_url = "http://localhost:3000"
42+
logger.info(f"Using local CF API at {cfapi_base_url}.")
43+
console.rule()
44+
else:
45+
cfapi_base_url = "https://app.codeflash.ai"
46+
cfwebapp_base_url = "https://app.codeflash.ai"
47+
return BaseUrls(cfapi_base_url=cfapi_base_url, cfwebapp_base_url=cfwebapp_base_url)
3748

3849

3950
def make_cfapi_request(
@@ -53,8 +64,9 @@ def make_cfapi_request(
5364
:param suppress_errors: If True, suppress error logging for HTTP errors.
5465
:return: The response object from the API.
5566
"""
56-
url = f"{CFAPI_BASE_URL}/cfapi{endpoint}"
57-
cfapi_headers = {"Authorization": f"Bearer {api_key or get_codeflash_api_key()}"}
67+
url = f"{get_cfapi_base_urls().cfapi_base_url}/cfapi{endpoint}"
68+
final_api_key = api_key or get_codeflash_api_key()
69+
cfapi_headers = {"Authorization": f"Bearer {final_api_key}"}
5870
if extra_headers:
5971
cfapi_headers.update(extra_headers)
6072
try:

codeflash/cli_cmds/cmd_init.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,8 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
167167
console.rule()
168168

169169
if run_tests:
170-
bubble_sort_path, bubble_sort_test_path = create_bubble_sort_file_and_test(args)
171-
run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)
170+
file_path = create_find_common_tags_file(args, "find_common_tags.py")
171+
run_end_to_end_test(args, file_path)
172172

173173

174174
def is_valid_pyproject_toml(pyproject_toml_path: Path) -> tuple[bool, dict[str, Any] | None, str]: # noqa: PLR0911
@@ -1207,6 +1207,34 @@ def enter_api_key_and_save_to_rc() -> None:
12071207
os.environ["CODEFLASH_API_KEY"] = api_key
12081208

12091209

1210+
def create_find_common_tags_file(args: Namespace, file_name: str) -> Path:
1211+
find_common_tags_content = """def find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:
1212+
if not articles:
1213+
return set()
1214+
1215+
common_tags = articles[0]["tags"]
1216+
for article in articles[1:]:
1217+
common_tags = [tag for tag in common_tags if tag in article["tags"]]
1218+
return set(common_tags)
1219+
"""
1220+
1221+
file_path = Path(args.module_root) / file_name
1222+
lsp_enabled = is_LSP_enabled()
1223+
if file_path.exists() and not lsp_enabled:
1224+
from rich.prompt import Confirm
1225+
1226+
overwrite = Confirm.ask(
1227+
f"🤔 {file_path} already exists. Do you want to overwrite it?", default=True, show_default=False
1228+
)
1229+
if not overwrite:
1230+
apologize_and_exit()
1231+
console.rule()
1232+
1233+
file_path.write_text(find_common_tags_content, encoding="utf8")
1234+
1235+
return file_path
1236+
1237+
12101238
def create_bubble_sort_file_and_test(args: Namespace) -> tuple[str, str]:
12111239
bubble_sort_content = """from typing import Union, List
12121240
def sorter(arr: Union[List[int],List[float]]) -> Union[List[int],List[float]]:
@@ -1276,7 +1304,7 @@ def test_sort():
12761304
return str(bubble_sort_path), str(bubble_sort_test_path)
12771305

12781306

1279-
def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test_path: str) -> None:
1307+
def run_end_to_end_test(args: Namespace, find_common_tags_path: Path) -> None:
12801308
try:
12811309
check_formatter_installed(args.formatter_cmds)
12821310
except Exception:
@@ -1285,7 +1313,7 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test
12851313
)
12861314
return
12871315

1288-
command = ["codeflash", "--file", "bubble_sort.py", "--function", "sorter"]
1316+
command = ["codeflash", "--file", "find_common_tags.py", "--function", "find_common_tags"]
12891317
if args.no_pr:
12901318
command.append("--no-pr")
12911319
if args.verbose:
@@ -1316,10 +1344,8 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test
13161344
console.rule()
13171345
# Delete the bubble_sort.py file after the test
13181346
logger.info("🧹 Cleaning up…")
1319-
for path in [bubble_sort_path, bubble_sort_test_path]:
1320-
console.rule()
1321-
Path(path).unlink(missing_ok=True)
1322-
logger.info(f"🗑️ Deleted {path}")
1347+
find_common_tags_path.unlink(missing_ok=True)
1348+
logger.info(f"🗑️ Deleted {find_common_tags_path}")
13231349

13241350

13251351
def ask_for_telemetry() -> bool:

codeflash/code_utils/git_worktree_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def remove_worktree(worktree_dir: Path) -> None:
8181

8282

8383
def create_diff_patch_from_worktree(
84-
worktree_dir: Path, files: list[str], fto_name: Optional[str] = None
84+
worktree_dir: Path, files: list[Path], fto_name: Optional[str] = None
8585
) -> Optional[Path]:
8686
repository = git.Repo(worktree_dir, search_parent_directories=True)
8787
uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)

codeflash/lsp/beta.py

Lines changed: 91 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import threading
88
from dataclasses import dataclass
99
from pathlib import Path
10-
from typing import TYPE_CHECKING, Optional
10+
from typing import TYPE_CHECKING, Optional, Union
1111

1212
from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
1313
from codeflash.cli_cmds.cli import process_pyproject_config
@@ -16,12 +16,14 @@
1616
VsCodeSetupInfo,
1717
configure_pyproject_toml,
1818
create_empty_pyproject_toml,
19+
create_find_common_tags_file,
1920
get_formatter_cmds,
2021
get_suggestions,
2122
get_valid_subdirs,
2223
is_valid_pyproject_toml,
2324
)
2425
from codeflash.code_utils.git_utils import git_root_dir
26+
from codeflash.code_utils.git_worktree_utils import create_worktree_snapshot_commit
2527
from codeflash.code_utils.shell_utils import save_api_key_to_rc
2628
from codeflash.discovery.functions_to_optimize import (
2729
filter_functions,
@@ -38,6 +40,7 @@
3840
from lsprotocol import types
3941

4042
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
43+
from codeflash.lsp.server import WrappedInitializationResultT
4144

4245

4346
@dataclass
@@ -54,11 +57,15 @@ class FunctionOptimizationInitParams:
5457

5558
@dataclass
5659
class FunctionOptimizationParams:
57-
textDocument: types.TextDocumentIdentifier # noqa: N815
5860
functionName: str # noqa: N815
5961
task_id: str
6062

6163

64+
@dataclass
65+
class DemoOptimizationParams:
66+
functionName: str # noqa: N815
67+
68+
6269
@dataclass
6370
class ProvideApiKeyParams:
6471
api_key: str
@@ -256,10 +263,8 @@ def init_project(params: ValidateProjectParams) -> dict[str, str]:
256263

257264
def _initialize_optimizer_if_api_key_is_valid(api_key: Optional[str] = None) -> dict[str, str]:
258265
key_check_result = _check_api_key_validity(api_key)
259-
if key_check_result.get("status") != "success":
260-
return key_check_result
261-
262-
_init()
266+
if key_check_result.get("status") == "success":
267+
_init()
263268
return key_check_result
264269

265270

@@ -351,6 +356,56 @@ def cleanup_optimizer(_params: any) -> dict[str, str]:
351356
return {"status": "success"}
352357

353358

359+
def _initialize_current_function_optimizer() -> Union[dict[str, str], WrappedInitializationResultT]:
360+
"""Initialize the current function optimizer.
361+
362+
Returns:
363+
Union[dict[str, str], WrappedInitializationResultT]:
364+
error dict with status error,
365+
or a wrapped initializationresult if the optimizer is initialized.
366+
367+
"""
368+
if not server.optimizer:
369+
return {"status": "error", "message": "Optimizer not initialized yet."}
370+
371+
function_name = server.optimizer.args.function
372+
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
373+
374+
if count == 0:
375+
server.show_message_log(f"No optimizable functions found for {function_name}", "Warning")
376+
server.cleanup_the_optimizer()
377+
return {"functionName": function_name, "status": "error", "message": "not found", "args": None}
378+
379+
fto = optimizable_funcs.popitem()[1][0]
380+
381+
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
382+
if not module_prep_result:
383+
return {
384+
"functionName": function_name,
385+
"status": "error",
386+
"message": "Failed to prepare module for optimization",
387+
}
388+
389+
validated_original_code, original_module_ast = module_prep_result
390+
391+
function_optimizer = server.optimizer.create_function_optimizer(
392+
fto,
393+
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
394+
original_module_ast=original_module_ast,
395+
original_module_path=fto.file_path,
396+
function_to_tests={},
397+
)
398+
399+
server.optimizer.current_function_optimizer = function_optimizer
400+
if not function_optimizer:
401+
return {"functionName": function_name, "status": "error", "message": "No function optimizer found"}
402+
403+
initialization_result = function_optimizer.can_be_optimized()
404+
if not is_successful(initialization_result):
405+
return {"functionName": function_name, "status": "error", "message": initialization_result.failure()}
406+
return initialization_result
407+
408+
354409
@server.feature("initializeFunctionOptimization")
355410
def initialize_function_optimization(params: FunctionOptimizationInitParams) -> dict[str, str]:
356411
with execution_context(task_id=getattr(params, "task_id", None)):
@@ -375,52 +430,47 @@ def initialize_function_optimization(params: FunctionOptimizationInitParams) ->
375430
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
376431
)
377432

378-
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()
379-
380-
if count == 0:
381-
server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
382-
server.cleanup_the_optimizer()
383-
return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}
384-
385-
fto = optimizable_funcs.popitem()[1][0]
386-
387-
module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
388-
if not module_prep_result:
389-
return {
390-
"functionName": params.functionName,
391-
"status": "error",
392-
"message": "Failed to prepare module for optimization",
393-
}
394-
395-
validated_original_code, original_module_ast = module_prep_result
396-
397-
function_optimizer = server.optimizer.create_function_optimizer(
398-
fto,
399-
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
400-
original_module_ast=original_module_ast,
401-
original_module_path=fto.file_path,
402-
function_to_tests={},
403-
)
404-
405-
server.optimizer.current_function_optimizer = function_optimizer
406-
if not function_optimizer:
407-
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}
408-
409-
initialization_result = function_optimizer.can_be_optimized()
410-
if not is_successful(initialization_result):
411-
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
433+
initialization_result = _initialize_current_function_optimizer()
434+
if isinstance(initialization_result, dict):
435+
return initialization_result
412436

413437
server.current_optimization_init_result = initialization_result.unwrap()
414438
server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")
415439

416-
files = [function_optimizer.function_to_optimize.file_path]
440+
files = [document.path]
417441

418442
_, _, original_helpers = server.current_optimization_init_result
419443
files.extend([str(helper_path) for helper_path in original_helpers])
420444

421445
return {"functionName": params.functionName, "status": "success", "files_inside_context": files}
422446

423447

448+
@server.feature("startDemoOptimization")
449+
async def start_demo_optimization(params: DemoOptimizationParams) -> dict[str, str]:
450+
try:
451+
_init()
452+
# start by creating the worktree so that the demo file is not created in user workspace
453+
server.optimizer.worktree_mode()
454+
file_path = create_find_common_tags_file(server.args, params.functionName + ".py")
455+
# commit the new file for diff generation later
456+
create_worktree_snapshot_commit(server.optimizer.current_worktree, "added sample optimization file")
457+
458+
server.optimizer.args.file = file_path
459+
server.optimizer.args.function = params.functionName
460+
server.optimizer.args.previous_checkpoint_functions = False
461+
462+
initialization_result = _initialize_current_function_optimizer()
463+
if isinstance(initialization_result, dict):
464+
return initialization_result
465+
466+
server.current_optimization_init_result = initialization_result.unwrap()
467+
return await perform_function_optimization(
468+
FunctionOptimizationParams(functionName=params.functionName, task_id=None)
469+
)
470+
finally:
471+
server.cleanup_the_optimizer()
472+
473+
424474
@server.feature("performFunctionOptimization")
425475
async def perform_function_optimization(params: FunctionOptimizationParams) -> dict[str, str]:
426476
with execution_context(task_id=getattr(params, "task_id", None)):

codeflash/lsp/server.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,35 @@
11
from __future__ import annotations
22

33
import contextvars
4+
from pathlib import Path
45
from typing import TYPE_CHECKING
56

67
from lsprotocol.types import LogMessageParams, MessageType
78
from pygls.lsp.server import LanguageServer
89
from pygls.protocol import LanguageServerProtocol
910

10-
if TYPE_CHECKING:
11-
from pathlib import Path
11+
from codeflash.either import Result
12+
from codeflash.models.models import CodeOptimizationContext
1213

13-
from codeflash.models.models import CodeOptimizationContext
14+
if TYPE_CHECKING:
1415
from codeflash.optimization.optimizer import Optimizer
1516

1617

1718
class CodeflashLanguageServerProtocol(LanguageServerProtocol):
1819
_server: CodeflashLanguageServer
1920

2021

22+
InitializationResultT = tuple[bool, CodeOptimizationContext, dict[Path, str]]
23+
WrappedInitializationResultT = Result[InitializationResultT, str]
24+
25+
2126
class CodeflashLanguageServer(LanguageServer):
2227
def __init__(self, name: str, version: str, protocol_cls: type[LanguageServerProtocol]) -> None:
2328
super().__init__(name, version, protocol_cls=protocol_cls)
2429
self.initialized: bool = False
2530
self.optimizer: Optimizer | None = None
2631
self.args = None
27-
self.current_optimization_init_result: tuple[bool, CodeOptimizationContext, dict[Path, str]] | None = None
32+
self.current_optimization_init_result: InitializationResultT | None = None
2833
self.execution_context_vars: contextvars.ContextVar[dict[str, str]] = contextvars.ContextVar(
2934
"execution_context_vars",
3035
default={}, # noqa: B039

codeflash/optimization/function_optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from rich.tree import Tree
2020

2121
from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient
22-
from codeflash.api.cfapi import CFWEBAPP_BASE_URL, add_code_context_hash, create_staging, mark_optimization_success
22+
from codeflash.api.cfapi import add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success
2323
from codeflash.benchmarking.utils import process_benchmark_data
2424
from codeflash.cli_cmds.console import code_print, console, logger, lsp_log, progress_bar
2525
from codeflash.code_utils import env_utils
@@ -1487,7 +1487,7 @@ def process_review(
14871487
response = create_staging(**data)
14881488
if response.status_code == 200:
14891489
trace_id = self.function_trace_id[:-4] + exp_type if self.experiment_id else self.function_trace_id
1490-
staging_url = f"{CFWEBAPP_BASE_URL}/review-optimizations/{trace_id}"
1490+
staging_url = f"{get_cfapi_base_urls().cfwebapp_base_url}/review-optimizations/{trace_id}"
14911491
console.print(
14921492
Panel(
14931493
f"[bold green]✅ Staging created:[/bold green]\n[link={staging_url}]{staging_url}[/link]",

0 commit comments

Comments
 (0)