|
| 1 | +import sys |
1 | 2 | from contextlib import asynccontextmanager |
2 | 3 | from pathlib import Path |
3 | | -from typing import Any, AsyncGenerator |
| 4 | +from typing import Any, AsyncGenerator, Optional |
4 | 5 |
|
5 | 6 | from fastmcp import FastMCP |
| 7 | +from lsprotocol.types import TextDocumentItem |
6 | 8 |
|
7 | 9 | from codeflash.lsp.server import CodeflashLanguageServer |
8 | 10 | from codeflash.lsp.beta import perform_function_optimization, FunctionOptimizationParams, \ |
9 | | - initialize_function_optimization |
| 11 | + initialize_function_optimization, validate_project, discover_function_tests |
10 | 12 | from tests.scripts.end_to_end_test_utilities import TestConfig, run_codeflash_command |
11 | 13 | from lsprotocol import types |
12 | 14 |
|
| 15 | +# dummy method for getting pyproject.toml path |
| 16 | +def _find_pyproject_toml(workspace_path: str) -> Optional[Path]: |
| 17 | + workspace_path_obj = Path(workspace_path) |
| 18 | + for file_path in workspace_path_obj.rglob("pyproject.toml"): |
| 19 | + return file_path.resolve() |
| 20 | + return None |
| 21 | + |
13 | 22 |
|
14 | 23 | # Define lifespan context manager |
15 | 24 | @asynccontextmanager |
16 | 25 | async def lifespan(mcp: FastMCP) -> AsyncGenerator[None, Any]: |
17 | 26 | print("Starting up...") |
18 | 27 | print(mcp.name) |
19 | | - # Do startup work here (connect to DB, initialize cache, etc.) |
20 | | - server = CodeflashLanguageServer(name = "codeflash", version = "0.0.1") |
21 | | - config_file = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/pyproject.toml") |
22 | | - file = "/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/bubble_sort.py" |
23 | | - function = "sorter" |
24 | | - params = FunctionOptimizationParams(functionName=function, textDocument=types.TextDocumentIdentifier(Path(file).as_uri())) |
25 | | - server.prepare_optimizer_arguments(config_file) |
26 | | - initialize_function_optimization(server, params) |
27 | | - perform_function_optimization(server, params) |
28 | | - #optimize_code(file, function) |
| 28 | + # # Do startup work here (connect to DB, initialize cache, etc.) |
| 29 | + # server = CodeflashLanguageServer(name = "codeflash", version = "0.0.1") |
| 30 | + # config_file = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/pyproject.toml") |
| 31 | + # file = "/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/bubble_sort.py" |
| 32 | + # function = "sorter" |
| 33 | + # params = FunctionOptimizationParams(functionName=function, textDocument=types.TextDocumentIdentifier(Path(file).as_uri())) |
| 34 | + # server.prepare_optimizer_arguments(config_file) |
| 35 | + # initialize_function_optimization(server, params) |
| 36 | + # perform_function_optimization(server, params) |
| 37 | + # #optimize_code(file, function) |
| 38 | + |
| 39 | + #################### initialize the server ############################# |
| 40 | + server = CodeflashLanguageServer("codeflash-language-server", "v1.0") |
| 41 | + # suppose the pyproject.toml is in the current directory |
| 42 | + server.prepare_optimizer_arguments(_find_pyproject_toml(".")) |
| 43 | + result = validate_project(server, None) |
| 44 | + if result["status"] == "error": |
| 45 | + # handle if the project is not valid, it can be because pyproject.toml is not valid or the repository is in bare state or the repository has no commits, which will stop the worktree from working |
| 46 | + print(result["message"]) |
| 47 | + sys.exit(1) |
| 48 | + |
| 49 | + #################### start the optimization for file, function ############################# |
| 50 | + file_path = "/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize/bubble_sort.py" |
| 51 | + function_name = "sorter" |
| 52 | + |
| 53 | + # This is not necessary, just for testing |
| 54 | + server.args.module_root = Path("/Users/codeflash/Downloads/codeflash-dev/codeflash/code_to_optimize") |
| 55 | + result = initialize_function_optimization(server, FunctionOptimizationParams( |
| 56 | + functionName=function_name, |
| 57 | + textDocument=TextDocumentItem( |
| 58 | + uri=file_path, |
| 59 | + language_id="python", |
| 60 | + version=1, |
| 61 | + text="" |
| 62 | + ) |
| 63 | + ) |
| 64 | + ) |
| 65 | + if result["status"] == "error": |
| 66 | + # handle if the function is not optimizable |
| 67 | + print(result["message"]) |
| 68 | + sys.exit(1) |
| 69 | + |
| 70 | + discover_function_tests(server, FunctionOptimizationParams(functionName=function_name, textDocument=None)) |
| 71 | + final_result = perform_function_optimization(server, FunctionOptimizationParams(functionName=function_name, |
| 72 | + textDocument=None)) |
| 73 | + if final_result["status"] == "success": |
| 74 | + print(final_result) |
29 | 75 | yield |
30 | 76 | # Cleanup work after shutdown |
31 | 77 | print("Shutting down...") |
|
0 commit comments