32 changes: 22 additions & 10 deletions codeflash/api/cfapi.py
@@ -2,6 +2,7 @@

import json
import os
from dataclasses import dataclass
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
@@ -26,14 +27,24 @@

from packaging import version

if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() == "local":
    CFAPI_BASE_URL = "http://localhost:3001"
    CFWEBAPP_BASE_URL = "http://localhost:3000"
    logger.info(f"Using local CF API at {CFAPI_BASE_URL}.")
    console.rule()
else:
    CFAPI_BASE_URL = "https://app.codeflash.ai"
    CFWEBAPP_BASE_URL = "https://app.codeflash.ai"

@dataclass
class BaseUrls:
    cfapi_base_url: Optional[str] = None
    cfwebapp_base_url: Optional[str] = None


@lru_cache(maxsize=1)
def get_cfapi_base_urls() -> BaseUrls:
    if os.environ.get("CODEFLASH_CFAPI_SERVER", "prod").lower() == "local":
        cfapi_base_url = "http://localhost:3001"
        cfwebapp_base_url = "http://localhost:3000"
        logger.info(f"Using local CF API at {cfapi_base_url}.")
        console.rule()
    else:
        cfapi_base_url = "https://app.codeflash.ai"
        cfwebapp_base_url = "https://app.codeflash.ai"
    return BaseUrls(cfapi_base_url=cfapi_base_url, cfwebapp_base_url=cfwebapp_base_url)


def make_cfapi_request(
@@ -53,8 +64,9 @@ def make_cfapi_request(
    :param suppress_errors: If True, suppress error logging for HTTP errors.
    :return: The response object from the API.
    """
    url = f"{CFAPI_BASE_URL}/cfapi{endpoint}"
    cfapi_headers = {"Authorization": f"Bearer {api_key or get_codeflash_api_key()}"}
    url = f"{get_cfapi_base_urls().cfapi_base_url}/cfapi{endpoint}"
    final_api_key = api_key or get_codeflash_api_key()
    cfapi_headers = {"Authorization": f"Bearer {final_api_key}"}
    if extra_headers:
        cfapi_headers.update(extra_headers)
    try:
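A quick note on the new lookup (an illustrative sketch, not part of this diff): because get_cfapi_base_urls() is wrapped in lru_cache(maxsize=1), CODEFLASH_CFAPI_SERVER is read once per process and the resolved URLs are reused on every later call, so anything that flips the environment variable afterwards (tests, for example) would need to clear the cache first:

    import os
    from codeflash.api.cfapi import get_cfapi_base_urls

    os.environ["CODEFLASH_CFAPI_SERVER"] = "local"
    print(get_cfapi_base_urls().cfapi_base_url)    # http://localhost:3001 (first call populates the cache)

    os.environ["CODEFLASH_CFAPI_SERVER"] = "prod"
    print(get_cfapi_base_urls().cfapi_base_url)    # still http://localhost:3001 — the cached BaseUrls is returned

    get_cfapi_base_urls.cache_clear()              # standard lru_cache reset; forces the env var to be re-read
    print(get_cfapi_base_urls().cfapi_base_url)    # https://app.codeflash.ai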
43 changes: 35 additions & 8 deletions codeflash/cli_cmds/cmd_init.py
@@ -167,8 +167,8 @@ def ask_run_end_to_end_test(args: Namespace) -> None:
    console.rule()

    if run_tests:
        bubble_sort_path, bubble_sort_test_path = create_bubble_sort_file_and_test(args)
        run_end_to_end_test(args, bubble_sort_path, bubble_sort_test_path)
        file_path = create_find_common_tags_file(args, "find_common_tags.py")
        run_end_to_end_test(args, file_path)


def is_valid_pyproject_toml(pyproject_toml_path: Path) -> tuple[bool, dict[str, Any] | None, str]: # noqa: PLR0911
@@ -1207,6 +1207,35 @@ def enter_api_key_and_save_to_rc() -> None:
    os.environ["CODEFLASH_API_KEY"] = api_key


def create_find_common_tags_file(args: Namespace, file_name: str) -> Path:
    find_common_tags_content = """def find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:
    if not articles:
        return set()

    common_tags = articles[0]["tags"]
    for article in articles[1:]:
        common_tags = [tag for tag in common_tags if tag in article["tags"]]
    return set(common_tags)
"""

    file_path = Path(args.module_root) / file_name
    lsp_enabled = is_LSP_enabled()
    if file_path.exists() and not lsp_enabled:
        from rich.prompt import Confirm

        overwrite = Confirm.ask(
            f"🤔 {file_path} already exists. Do you want to overwrite it?", default=True, show_default=False
        )
        if not overwrite:
            apologize_and_exit()
        console.rule()

    file_path.write_text(find_common_tags_content, encoding="utf8")
    logger.info(f"Created demo optimization file: {file_path}")

    return file_path


def create_bubble_sort_file_and_test(args: Namespace) -> tuple[str, str]:
    bubble_sort_content = """from typing import Union, List
def sorter(arr: Union[List[int],List[float]]) -> Union[List[int],List[float]]:
@@ -1276,7 +1305,7 @@ def test_sort():
    return str(bubble_sort_path), str(bubble_sort_test_path)


def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test_path: str) -> None:
def run_end_to_end_test(args: Namespace, find_common_tags_path: Path) -> None:
    try:
        check_formatter_installed(args.formatter_cmds)
    except Exception:
@@ -1285,7 +1314,7 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test
        )
        return

    command = ["codeflash", "--file", "bubble_sort.py", "--function", "sorter"]
    command = ["codeflash", "--file", "find_common_tags.py", "--function", "find_common_tags"]
    if args.no_pr:
        command.append("--no-pr")
    if args.verbose:
@@ -1316,10 +1345,8 @@ def run_end_to_end_test(args: Namespace, bubble_sort_path: str, bubble_sort_test
    console.rule()
    # Delete the bubble_sort.py file after the test
    logger.info("🧹 Cleaning up…")
    for path in [bubble_sort_path, bubble_sort_test_path]:
        console.rule()
        Path(path).unlink(missing_ok=True)
        logger.info(f"🗑️ Deleted {path}")
    find_common_tags_path.unlink(missing_ok=True)
    logger.info(f"🗑️ Deleted {find_common_tags_path}")


def ask_for_telemetry() -> bool:
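On the choice of demo function (a hedged sketch, not code from this PR): the seeded find_common_tags rescans every article's tag list for each surviving tag, so it is a natural optimization target. One plausible optimized form, keeping the same return value, uses set intersection:

    def find_common_tags(articles: list[dict[str, list[str]]]) -> set[str]:
        # Same semantics as the seeded demo file, but each tag list is hashed once
        # and intersected, instead of repeated linear membership scans.
        if not articles:
            return set()
        common_tags = set(articles[0]["tags"])
        for article in articles[1:]:
            common_tags &= set(article["tags"])
        return common_tags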
2 changes: 1 addition & 1 deletion codeflash/code_utils/git_worktree_utils.py
@@ -81,7 +81,7 @@ def remove_worktree(worktree_dir: Path) -> None:


def create_diff_patch_from_worktree(
    worktree_dir: Path, files: list[str], fto_name: Optional[str] = None
    worktree_dir: Path, files: list[Path], fto_name: Optional[str] = None
) -> Optional[Path]:
    repository = git.Repo(worktree_dir, search_parent_directories=True)
    uni_diff_text = repository.git.diff(None, "HEAD", *files, ignore_blank_lines=True, ignore_space_at_eol=True)
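With the annotation change above, callers hand create_diff_patch_from_worktree a list of Path objects rather than strings. A minimal usage sketch (the worktree location and file name below are invented for illustration):

    from pathlib import Path

    patch_file = create_diff_patch_from_worktree(
        worktree_dir=Path("/tmp/codeflash-worktrees/demo"),  # hypothetical worktree path
        files=[Path("find_common_tags.py")],                  # Path objects, matching the new list[Path] annotation
        fto_name="find_common_tags",
    )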
136 changes: 93 additions & 43 deletions codeflash/lsp/beta.py
@@ -7,7 +7,7 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from codeflash.api.cfapi import get_codeflash_api_key, get_user_id
from codeflash.cli_cmds.cli import process_pyproject_config
@@ -16,12 +16,14 @@
    VsCodeSetupInfo,
    configure_pyproject_toml,
    create_empty_pyproject_toml,
    create_find_common_tags_file,
    get_formatter_cmds,
    get_suggestions,
    get_valid_subdirs,
    is_valid_pyproject_toml,
)
from codeflash.code_utils.git_utils import git_root_dir
from codeflash.code_utils.git_worktree_utils import create_worktree_snapshot_commit
from codeflash.code_utils.shell_utils import save_api_key_to_rc
from codeflash.discovery.functions_to_optimize import (
    filter_functions,
@@ -39,6 +41,7 @@
    from lsprotocol import types

    from codeflash.discovery.functions_to_optimize import FunctionToOptimize
    from codeflash.lsp.server import WrappedInitializationResultT


@dataclass
@@ -55,11 +58,15 @@ class FunctionOptimizationInitParams:

@dataclass
class FunctionOptimizationParams:
    textDocument: types.TextDocumentIdentifier # noqa: N815
    functionName: str # noqa: N815
    task_id: str


@dataclass
class DemoOptimizationParams:
    functionName: str # noqa: N815


@dataclass
class ProvideApiKeyParams:
    api_key: str
@@ -257,10 +264,8 @@ def init_project(params: ValidateProjectParams) -> dict[str, str]:

def _initialize_optimizer_if_api_key_is_valid(api_key: Optional[str] = None) -> dict[str, str]:
    key_check_result = _check_api_key_validity(api_key)
    if key_check_result.get("status") != "success":
        return key_check_result

    _init()
    if key_check_result.get("status") == "success":
        _init()
    return key_check_result


@@ -303,8 +308,8 @@ def _init() -> Namespace:
def check_api_key(_params: any) -> dict[str, str]:
    try:
        return _initialize_optimizer_if_api_key_is_valid()
    except Exception:
        return {"status": "error", "message": "something went wrong while validating the api key"}
    except Exception as ex:
        return {"status": "error", "message": "something went wrong while validating the api key " + str(ex)}


@server.feature("provideApiKey")
@@ -353,6 +358,56 @@ def cleanup_optimizer(_params: any) -> dict[str, str]:
return {"status": "success"}


def _initialize_current_function_optimizer() -> Union[dict[str, str], WrappedInitializationResultT]:
"""Initialize the current function optimizer.

Returns:
Union[dict[str, str], WrappedInitializationResultT]:
error dict with status error,
or a wrapped initializationresult if the optimizer is initialized.

"""
if not server.optimizer:
return {"status": "error", "message": "Optimizer not initialized yet."}

function_name = server.optimizer.args.function
optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()

if count == 0:
server.show_message_log(f"No optimizable functions found for {function_name}", "Warning")
server.cleanup_the_optimizer()
return {"functionName": function_name, "status": "error", "message": "not found", "args": None}

fto = optimizable_funcs.popitem()[1][0]

module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
if not module_prep_result:
return {
"functionName": function_name,
"status": "error",
"message": "Failed to prepare module for optimization",
}

validated_original_code, original_module_ast = module_prep_result

function_optimizer = server.optimizer.create_function_optimizer(
fto,
function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
original_module_ast=original_module_ast,
original_module_path=fto.file_path,
function_to_tests={},
)

server.optimizer.current_function_optimizer = function_optimizer
if not function_optimizer:
return {"functionName": function_name, "status": "error", "message": "No function optimizer found"}

initialization_result = function_optimizer.can_be_optimized()
if not is_successful(initialization_result):
return {"functionName": function_name, "status": "error", "message": initialization_result.failure()}
return initialization_result


@server.feature("initializeFunctionOptimization")
def initialize_function_optimization(params: FunctionOptimizationInitParams) -> dict[str, str]:
    with execution_context(task_id=getattr(params, "task_id", None)):
@@ -377,52 +432,47 @@ def initialize_function_optimization(params: FunctionOptimizationInitParams) ->
f"Args set - function: {server.optimizer.args.function}, file: {server.optimizer.args.file}", "Info"
)

        optimizable_funcs, count, _ = server.optimizer.get_optimizable_functions()

        if count == 0:
            server.show_message_log(f"No optimizable functions found for {params.functionName}", "Warning")
            server.cleanup_the_optimizer()
            return {"functionName": params.functionName, "status": "error", "message": "not found", "args": None}

        fto = optimizable_funcs.popitem()[1][0]

        module_prep_result = server.optimizer.prepare_module_for_optimization(fto.file_path)
        if not module_prep_result:
            return {
                "functionName": params.functionName,
                "status": "error",
                "message": "Failed to prepare module for optimization",
            }

        validated_original_code, original_module_ast = module_prep_result

        function_optimizer = server.optimizer.create_function_optimizer(
            fto,
            function_to_optimize_source_code=validated_original_code[fto.file_path].source_code,
            original_module_ast=original_module_ast,
            original_module_path=fto.file_path,
            function_to_tests={},
        )

        server.optimizer.current_function_optimizer = function_optimizer
        if not function_optimizer:
            return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}

        initialization_result = function_optimizer.can_be_optimized()
        if not is_successful(initialization_result):
            return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}
        initialization_result = _initialize_current_function_optimizer()
        if isinstance(initialization_result, dict):
            return initialization_result

        server.current_optimization_init_result = initialization_result.unwrap()
        server.show_message_log(f"Successfully initialized optimization for {params.functionName}", "Info")

        files = [function_optimizer.function_to_optimize.file_path]
        files = [document.path]

        _, _, original_helpers = server.current_optimization_init_result
        files.extend([str(helper_path) for helper_path in original_helpers])

        return {"functionName": params.functionName, "status": "success", "files_inside_context": files}


@server.feature("startDemoOptimization")
async def start_demo_optimization(params: DemoOptimizationParams) -> dict[str, str]:
    try:
        _init()
        # start by creating the worktree so that the demo file is not created in user workspace
        server.optimizer.worktree_mode()
        file_path = create_find_common_tags_file(server.args, params.functionName + ".py")
        # commit the new file for diff generation later
        create_worktree_snapshot_commit(server.optimizer.current_worktree, "added sample optimization file")

        server.optimizer.args.file = file_path
        server.optimizer.args.function = params.functionName
        server.optimizer.args.previous_checkpoint_functions = False

        initialization_result = _initialize_current_function_optimizer()
        if isinstance(initialization_result, dict):
            return initialization_result

        server.current_optimization_init_result = initialization_result.unwrap()
        return await perform_function_optimization(
            FunctionOptimizationParams(functionName=params.functionName, task_id=None)
        )
    finally:
        server.cleanup_the_optimizer()


@server.feature("performFunctionOptimization")
async def perform_function_optimization(params: FunctionOptimizationParams) -> dict[str, str]:
    with execution_context(task_id=getattr(params, "task_id", None)):
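For orientation (an illustrative sketch, not part of the diff): startDemoOptimization is a custom LSP request, so a client such as an editor extension would send a JSON-RPC message shaped roughly like the dict below. Only the method name and the functionName param mirror the handler and DemoOptimizationParams above; the surrounding transport is assumed.

    demo_request = {
        "jsonrpc": "2.0",
        "id": 1,                                          # any request id the client chooses
        "method": "startDemoOptimization",                # matches @server.feature("startDemoOptimization")
        "params": {"functionName": "find_common_tags"},   # DemoOptimizationParams.functionName; the handler appends ".py"
    }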
2 changes: 2 additions & 0 deletions codeflash/lsp/features/perform_optimization.py
@@ -59,6 +59,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, cancel_event: thr
        generated_perf_test_paths,
        instrumented_unittests_created_for_function,
        original_conftest_content,
        function_references,
    ) = test_setup_result.unwrap()

    baseline_setup_result = function_optimizer.setup_and_establish_baseline(
@@ -94,6 +95,7 @@ def sync_perform_optimization(server: CodeflashLanguageServer, cancel_event: thr
        generated_tests=generated_tests,
        test_functions_to_remove=test_functions_to_remove,
        concolic_test_str=concolic_test_str,
        function_references=function_references,
    )

    abort_if_cancelled(cancel_event)
11 changes: 8 additions & 3 deletions codeflash/lsp/server.py
@@ -1,22 +1,27 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

from lsprotocol.types import LogMessageParams, MessageType
from pygls.lsp.server import LanguageServer
from pygls.protocol import LanguageServerProtocol

if TYPE_CHECKING:
    from pathlib import Path
    from codeflash.either import Result
    from codeflash.models.models import CodeOptimizationContext

from codeflash.models.models import CodeOptimizationContext
if TYPE_CHECKING:
    from codeflash.optimization.optimizer import Optimizer


class CodeflashLanguageServerProtocol(LanguageServerProtocol):
    _server: CodeflashLanguageServer


InitializationResultT = tuple[bool, CodeOptimizationContext, dict[Path, str]]
WrappedInitializationResultT = Result[InitializationResultT, str]


class CodeflashLanguageServer(LanguageServer):
    def __init__(self, name: str, version: str, protocol_cls: type[LanguageServerProtocol]) -> None:
        super().__init__(name, version, protocol_cls=protocol_cls)
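Reading aid for the new aliases (a sketch mirroring how beta.py consumes them, not code from this PR): a WrappedInitializationResultT is a Result that either carries a failure message or unwraps to the InitializationResultT tuple.

    result = function_optimizer.can_be_optimized()            # returns a WrappedInitializationResultT (as in beta.py)
    if not is_successful(result):
        error_message = result.failure()                      # the str error side of the Result
    else:
        _, code_context, original_helpers = result.unwrap()   # (bool, CodeOptimizationContext, dict[Path, str])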