Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions codeflash/code_utils/code_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,3 +220,7 @@ def exit_with_message(message: str, *, error_on_exit: bool = False) -> None:
paneled_text(message, panel_args={"style": "red"})

sys.exit(1 if error_on_exit else 0)


async def dummy_async_function() -> None:
"""Provide a dummy async function for testing purposes."""
15 changes: 11 additions & 4 deletions codeflash/code_utils/static_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

from pydantic import BaseModel, ConfigDict, field_validator

from codeflash.models.models import FunctionParent

if TYPE_CHECKING:
from codeflash.models.models import FunctionParent

Expand Down Expand Up @@ -139,14 +141,19 @@ def get_first_top_level_function_or_method_ast(


def function_kind(node: ast.FunctionDef | ast.AsyncFunctionDef, parents: list[FunctionParent]) -> FunctionKind | None:
if not parents or parents[0].type in ["FunctionDef", "AsyncFunctionDef"]:
if not parents:
return FunctionKind.FUNCTION
parent_type = parents[0].type
if parent_type in {"FunctionDef", "AsyncFunctionDef"}:
return FunctionKind.FUNCTION
if parents[0].type == "ClassDef":
if parent_type == "ClassDef":
# Search for known decorators efficiently
for decorator in node.decorator_list:
if isinstance(decorator, ast.Name):
if decorator.id == "classmethod":
dec_id = decorator.id
if dec_id == "classmethod":
return FunctionKind.CLASS_METHOD
if decorator.id == "staticmethod":
if dec_id == "staticmethod":
return FunctionKind.STATIC_METHOD
return FunctionKind.INSTANCE_METHOD
return None
Expand Down
9 changes: 9 additions & 0 deletions codeflash/discovery/functions_to_optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,15 @@ def qualified_name(self) -> str:
def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
return f"{module_name_from_file_path(self.file_path, project_root_path)}.{self.qualified_name}"

@property
def server_info(self) -> dict[str, str | int]:
return {
"file_path": str(self.file_path),
"function_name": self.function_name,
"starting_line": self.starting_line,
"ending_line": self.ending_line,
}


def get_functions_to_optimize(
optimize_all: str | None,
Expand Down
Empty file added codeflash/lsp/__init__.py
Empty file.
105 changes: 105 additions & 0 deletions codeflash/lsp/beta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING

from pygls import uris

from codeflash.either import is_successful
from codeflash.lsp.server import CodeflashLanguageServer, CodeflashLanguageServerProtocol

if TYPE_CHECKING:
from lsprotocol import types


@dataclass
class OptimizableFunctionsParams:
textDocument: types.TextDocumentIdentifier # noqa: N815


@dataclass
class OptimizeFunctionParams:
textDocument: types.TextDocumentIdentifier # noqa: N815
functionName: str # noqa: N815


server = CodeflashLanguageServer("codeflash-language-server", "v1.0", protocol_cls=CodeflashLanguageServerProtocol)


@server.feature("getOptimizableFunctions")
def get_optimizable_functions(
server: CodeflashLanguageServer, params: OptimizableFunctionsParams
) -> dict[str, list[str]]:
file_path = Path(uris.to_fs_path(params.textDocument.uri))
server.optimizer.args.file = file_path
server.optimizer.args.previous_checkpoint_functions = False
optimizable_funcs, _ = server.optimizer.get_optimizable_functions()
path_to_qualified_names = {}
for path, functions in optimizable_funcs.items():
path_to_qualified_names[path.as_posix()] = [func.qualified_name for func in functions]
return path_to_qualified_names


@server.feature("optimizeFunction")
def optimize_function(server: CodeflashLanguageServer, params: OptimizeFunctionParams) -> dict[str, str]:
file_path = Path(uris.to_fs_path(params.textDocument.uri))
server.optimizer.args.function = params.functionName
server.optimizer.args.file = file_path
optimizable_funcs, _ = server.optimizer.get_optimizable_functions()
if not optimizable_funcs:
return {"functionName": params.functionName, "status": "not found", "args": None}
fto = optimizable_funcs.popitem()[1][0]
server.optimizer.current_function_being_optimized = fto
return {"functionName": params.functionName, "status": "success", "info": fto.server_info}


@server.feature("second_step_in_optimize_function")
def second_step_in_optimize_function(server: CodeflashLanguageServer, params: OptimizeFunctionParams) -> dict[str, str]:
current_function = server.optimizer.current_function_being_optimized

optimizable_funcs = {current_function.file_path: [current_function]}

function_to_tests, num_discovered_tests = server.optimizer.discover_tests(optimizable_funcs)
# mocking in order to get things going
return {"functionName": params.functionName, "status": "success", "generated_tests": str(num_discovered_tests)}


@server.feature("third_step_in_optimize_function")
def third_step_in_optimize_function(server: CodeflashLanguageServer, params: OptimizeFunctionParams) -> dict[str, str]:
current_function = server.optimizer.current_function_being_optimized

module_prep_result = server.optimizer.prepare_module_for_optimization(current_function.file_path)

validated_original_code, original_module_ast = module_prep_result

function_optimizer = server.optimizer.create_function_optimizer(
current_function,
function_to_optimize_source_code=validated_original_code[current_function.file_path].source_code,
original_module_ast=original_module_ast,
original_module_path=current_function.file_path,
)

server.optimizer.current_function_optimizer = function_optimizer
if not function_optimizer:
return {"functionName": params.functionName, "status": "error", "message": "No function optimizer found"}

initialization_result = function_optimizer.can_be_optimized()
if not is_successful(initialization_result):
return {"functionName": params.functionName, "status": "error", "message": initialization_result.failure()}

should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()

return {
"functionName": params.functionName,
"status": "success",
"message": "Function can be optimized",
"extra": original_helper_code,
}


if __name__ == "__main__":
from codeflash.cli_cmds.console import console

console.quiet = True
server.start_io()
56 changes: 56 additions & 0 deletions codeflash/lsp/server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any

from lsprotocol.types import INITIALIZE
from pygls import uris
from pygls.protocol import LanguageServerProtocol, lsp_method
from pygls.server import LanguageServer

if TYPE_CHECKING:
from lsprotocol.types import InitializeParams, InitializeResult


class CodeflashLanguageServerProtocol(LanguageServerProtocol):
_server: CodeflashLanguageServer

@lsp_method(INITIALIZE)
def lsp_initialize(self, params: InitializeParams) -> InitializeResult:
server = self._server
initialize_result: InitializeResult = super().lsp_initialize(params)

workspace_uri = params.root_uri
if workspace_uri:
workspace_path = uris.to_fs_path(workspace_uri)
pyproject_toml_path = self._find_pyproject_toml(workspace_path)
if pyproject_toml_path:
server.initialize_optimizer(pyproject_toml_path)
server.show_message(f"Found pyproject.toml at: {pyproject_toml_path}")
else:
server.show_message("No pyproject.toml found in workspace.")
else:
server.show_message("No workspace URI provided.")

return initialize_result

def _find_pyproject_toml(self, workspace_path: str) -> Path | None:
workspace_path_obj = Path(workspace_path)
for file_path in workspace_path_obj.rglob("pyproject.toml"):
return file_path.resolve()
return None


class CodeflashLanguageServer(LanguageServer):
def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
super().__init__(*args, **kwargs)
self.optimizer = None

def initialize_optimizer(self, config_file: Path) -> None:
from codeflash.cli_cmds.cli import parse_args, process_pyproject_config
from codeflash.optimization.optimizer import Optimizer

args = parse_args()
args.config_file = config_file
args = process_pyproject_config(args)
self.optimizer = Optimizer(args)
11 changes: 10 additions & 1 deletion codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def __init__(
self.total_benchmark_timings = total_benchmark_timings if total_benchmark_timings else {}
self.replay_tests_dir = replay_tests_dir if replay_tests_dir else None

def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
def can_be_optimized(self) -> Result[tuple[bool, CodeOptimizationContext, dict[Path, str]], str]:
should_run_experiment = self.experiment_id is not None
logger.debug(f"Function Trace ID: {self.function_trace_id}")
ph("cli-optimize-function-start", {"function_trace_id": self.function_trace_id})
Expand All @@ -171,6 +171,15 @@ def optimize_function(self) -> Result[BestOptimization, str]: # noqa: PLR0911
):
return Failure("Function optimization previously attempted, skipping.")

return Success((should_run_experiment, code_context, original_helper_code))

def optimize_function(self) -> Result[BestOptimization, str]:
initialization_result = self.can_be_optimized()
if not is_successful(initialization_result):
return Failure(initialization_result.failure())

should_run_experiment, code_context, original_helper_code = initialization_result.unwrap()

code_print(code_context.read_writable_code)
generated_test_paths = [
get_test_file_path(
Expand Down
Loading
Loading