|
2 | 2 |
|
3 | 3 | import os |
4 | 4 | from abc import ABC, abstractmethod |
5 | | -from collections.abc import AsyncGenerator |
| 5 | +from collections.abc import AsyncGenerator, Iterable |
6 | 6 | from contextlib import asynccontextmanager |
7 | 7 | from pathlib import Path |
8 | 8 | from typing import Any, Self, override |
|
30 | 30 | CapabilityClientProtocol, |
31 | 31 | CapabilityProtocol, |
32 | 32 | ) |
33 | | -from lsp_client.server import LocalServer |
34 | | -from lsp_client.server.abc import LSPServer |
| 33 | +from lsp_client.server import DefaultServers, Server, ServerRuntimeError |
35 | 34 | from lsp_client.server.types import ServerRequest |
36 | 35 | from lsp_client.utils.channel import Receiver, channel |
37 | 36 | from lsp_client.utils.types import AnyPath, Notification, Request, Response, lsp_type |
38 | 37 | from lsp_client.utils.workspace import ( |
39 | 38 | DEFAULT_WORKSPACE_DIR, |
40 | 39 | RawWorkspace, |
41 | 40 | Workspace, |
42 | | - WorkspaceFolder, |
| 41 | + format_workspace, |
43 | 42 | ) |
44 | 43 |
|
45 | 44 |
|
46 | 45 | @define |
47 | | -class LSPClient( |
| 46 | +class Client( |
48 | 47 | # text sync support is mandatory |
49 | 48 | WithNotifyTextDocumentSynchronize, |
50 | 49 | CapabilityClientProtocol, |
51 | 50 | AsyncContextManagerMixin, |
52 | 51 | ABC, |
53 | 52 | ): |
54 | | - _server: LSPServer | None = field(alias="server", default=None) |
55 | | - _workspace: RawWorkspace = field(alias="workspace", factory=Path.cwd) |
| 53 | + _server_arg: Server | None = field(alias="server", default=None) |
| 54 | + _workspace_arg: RawWorkspace = field(alias="workspace", factory=Path.cwd) |
56 | 55 |
|
57 | 56 | sync_file: bool = True |
58 | 57 | request_timeout: float = 5.0 |
59 | 58 |
|
| 59 | + _server: Server = field(init=False) |
| 60 | + _workspace: Workspace = field(init=False) |
60 | 61 | _buffer: LSPFileBuffer = field(factory=LSPFileBuffer, init=False) |
61 | 62 |
|
62 | | - def get_server(self) -> LSPServer: |
63 | | - return self._server or self.create_default_server() |
| 63 | + def _iter_candidate_servers(self) -> Iterable[Server]: |
| 64 | + """ |
| 65 | + Server candidates in order of priority: |
| 66 | + 1. User-provided server |
| 67 | + 2. Containerized server |
| 68 | + 3. Local server (maybe with auto-installation) |
| 69 | + """ |
| 70 | + |
| 71 | + if self._server_arg: |
| 72 | + yield self._server_arg |
| 73 | + defaults = self.create_default_servers() |
| 74 | + yield defaults.container |
| 75 | + yield defaults.local |
| 76 | + |
| 77 | + @asynccontextmanager |
| 78 | + async def _run_server( |
| 79 | + self, |
| 80 | + ) -> AsyncGenerator[tuple[Server, Receiver[ServerRequest]]]: |
| 81 | + async with channel[ServerRequest].create() as (sender, receiver): |
| 82 | + errors: list[ServerRuntimeError] = [] |
| 83 | + for server in self._iter_candidate_servers(): |
| 84 | + try: |
| 85 | + async with server.run(self.get_workspace(), sender=sender) as s: # ty: ignore[invalid-argument-type] |
| 86 | + yield s, receiver |
| 87 | + return |
| 88 | + except ServerRuntimeError as e: |
| 89 | + logger.debug("Failed to start server {}: {}", server, e) |
| 90 | + errors.append(e) |
| 91 | + |
| 92 | + raise ExceptionGroup( |
| 93 | + f"All servers failed to start for {type(self).__name__}", errors |
| 94 | + ) |
64 | 95 |
|
65 | 96 | @override |
66 | 97 | def get_workspace(self) -> Workspace: |
67 | | - match self._workspace: |
68 | | - case str() | os.PathLike() as root_folder_path: |
69 | | - return Workspace( |
70 | | - { |
71 | | - DEFAULT_WORKSPACE_DIR: WorkspaceFolder( |
72 | | - uri=Path(root_folder_path).as_uri(), |
73 | | - name="root", |
74 | | - ) |
75 | | - } |
76 | | - ) |
77 | | - case Workspace() as ws: |
78 | | - return ws |
79 | | - case _ as mapping: |
80 | | - return Workspace( |
81 | | - { |
82 | | - name: WorkspaceFolder(uri=Path(path).as_uri(), name=name) |
83 | | - for name, path in mapping.items() |
84 | | - } |
85 | | - ) |
| 98 | + return self._workspace |
| 99 | + |
| 100 | + def get_server(self) -> Server: |
| 101 | + return self._server |
86 | 102 |
|
87 | 103 | @abstractmethod |
88 | 104 | def get_language_id(self) -> lsp_type.LanguageKind: |
89 | 105 | """The language ID of the client.""" |
90 | 106 |
|
91 | 107 | @abstractmethod |
92 | | - def create_default_server(self) -> LSPServer: |
93 | | - """Create the default server for this client.""" |
94 | | - |
95 | | - @abstractmethod |
96 | | - async def ensure_installed(self) -> None: |
97 | | - """ |
98 | | - Check and install the server if necessary. |
99 | | -
|
100 | | - Note: For local runtime only. |
101 | | - """ |
| 108 | + def create_default_servers(self) -> DefaultServers: |
| 109 | + """Create default servers for this client.""" |
102 | 110 |
|
103 | 111 | @abstractmethod |
104 | 112 | def create_initialization_options(self) -> dict[str, Any]: |
@@ -230,21 +238,19 @@ async def _exit(self) -> None: |
230 | 238 | @asynccontextmanager |
231 | 239 | @logger.catch(reraise=True) |
232 | 240 | async def __asynccontextmanager__(self) -> AsyncGenerator[Self]: |
233 | | - if isinstance(self._server, LocalServer): |
234 | | - await self.ensure_installed() |
235 | | - |
236 | | - self._hook = build_server_request_hooks(self) |
237 | | - client_capabilities = build_client_capabilities(self.__class__) |
| 241 | + self._workspace = format_workspace(self._workspace_arg) |
238 | 242 |
|
239 | 243 | async with ( |
240 | 244 | asyncer.create_task_group() as tg, |
241 | | - channel[ServerRequest].create() as (sender, receiver), |
242 | | - self.get_server().serve(workspace=self.get_workspace(), sender=sender), |
| 245 | + self._run_server() as (server, receiver), # ty: ignore[invalid-argument-type] |
243 | 246 | ): |
| 247 | + self._server = server |
| 248 | + |
244 | 249 | # start to receive server requests here, |
245 | 250 | # since server notification can be sent before `initialize` |
246 | | - tg.soonify(self._dispatch_server_requests)(receiver) |
| 251 | + tg.soonify(self._dispatch_server_requests)(receiver) # ty: ignore[invalid-argument-type] |
247 | 252 |
|
| 253 | + client_capabilities = build_client_capabilities(self.__class__) |
248 | 254 | root_workspace = self.get_workspace().get(DEFAULT_WORKSPACE_DIR) |
249 | 255 | root_path = root_workspace.path.as_posix() if root_workspace else None |
250 | 256 | root_uri = root_workspace.uri if root_workspace else None |
|
0 commit comments