Skip to content

Commit 2ed03c3

Browse files
committed
feat: add more reasonable default server startup
1 parent 319e3c3 commit 2ed03c3

File tree

1 file changed

+51
-45
lines changed

1 file changed

+51
-45
lines changed

src/lsp_client/client/abc.py

Lines changed: 51 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import os
44
from abc import ABC, abstractmethod
5-
from collections.abc import AsyncGenerator
5+
from collections.abc import AsyncGenerator, Iterable
66
from contextlib import asynccontextmanager
77
from pathlib import Path
88
from typing import Any, Self, override
@@ -30,75 +30,83 @@
3030
CapabilityClientProtocol,
3131
CapabilityProtocol,
3232
)
33-
from lsp_client.server import LocalServer
34-
from lsp_client.server.abc import LSPServer
33+
from lsp_client.server import DefaultServers, Server, ServerRuntimeError
3534
from lsp_client.server.types import ServerRequest
3635
from lsp_client.utils.channel import Receiver, channel
3736
from lsp_client.utils.types import AnyPath, Notification, Request, Response, lsp_type
3837
from lsp_client.utils.workspace import (
3938
DEFAULT_WORKSPACE_DIR,
4039
RawWorkspace,
4140
Workspace,
42-
WorkspaceFolder,
41+
format_workspace,
4342
)
4443

4544

4645
@define
47-
class LSPClient(
46+
class Client(
4847
# text sync support is mandatory
4948
WithNotifyTextDocumentSynchronize,
5049
CapabilityClientProtocol,
5150
AsyncContextManagerMixin,
5251
ABC,
5352
):
54-
_server: LSPServer | None = field(alias="server", default=None)
55-
_workspace: RawWorkspace = field(alias="workspace", factory=Path.cwd)
53+
_server_arg: Server | None = field(alias="server", default=None)
54+
_workspace_arg: RawWorkspace = field(alias="workspace", factory=Path.cwd)
5655

5756
sync_file: bool = True
5857
request_timeout: float = 5.0
5958

59+
_server: Server = field(init=False)
60+
_workspace: Workspace = field(init=False)
6061
_buffer: LSPFileBuffer = field(factory=LSPFileBuffer, init=False)
6162

62-
def get_server(self) -> LSPServer:
63-
return self._server or self.create_default_server()
63+
def _iter_candidate_servers(self) -> Iterable[Server]:
64+
"""
65+
Server candidates in order of priority:
66+
1. User-provided server
67+
2. Containerized server
68+
3. Local server (maybe with auto-installation)
69+
"""
70+
71+
if self._server_arg:
72+
yield self._server_arg
73+
defaults = self.create_default_servers()
74+
yield defaults.container
75+
yield defaults.local
76+
77+
@asynccontextmanager
78+
async def _run_server(
79+
self,
80+
) -> AsyncGenerator[tuple[Server, Receiver[ServerRequest]]]:
81+
async with channel[ServerRequest].create() as (sender, receiver):
82+
errors: list[ServerRuntimeError] = []
83+
for server in self._iter_candidate_servers():
84+
try:
85+
async with server.run(self.get_workspace(), sender=sender) as s: # ty: ignore[invalid-argument-type]
86+
yield s, receiver
87+
return
88+
except ServerRuntimeError as e:
89+
logger.debug("Failed to start server {}: {}", server, e)
90+
errors.append(e)
91+
92+
raise ExceptionGroup(
93+
f"All servers failed to start for {type(self).__name__}", errors
94+
)
6495

6596
@override
6697
def get_workspace(self) -> Workspace:
67-
match self._workspace:
68-
case str() | os.PathLike() as root_folder_path:
69-
return Workspace(
70-
{
71-
DEFAULT_WORKSPACE_DIR: WorkspaceFolder(
72-
uri=Path(root_folder_path).as_uri(),
73-
name="root",
74-
)
75-
}
76-
)
77-
case Workspace() as ws:
78-
return ws
79-
case _ as mapping:
80-
return Workspace(
81-
{
82-
name: WorkspaceFolder(uri=Path(path).as_uri(), name=name)
83-
for name, path in mapping.items()
84-
}
85-
)
98+
return self._workspace
99+
100+
def get_server(self) -> Server:
101+
return self._server
86102

87103
@abstractmethod
88104
def get_language_id(self) -> lsp_type.LanguageKind:
89105
"""The language ID of the client."""
90106

91107
@abstractmethod
92-
def create_default_server(self) -> LSPServer:
93-
"""Create the default server for this client."""
94-
95-
@abstractmethod
96-
async def ensure_installed(self) -> None:
97-
"""
98-
Check and install the server if necessary.
99-
100-
Note: For local runtime only.
101-
"""
108+
def create_default_servers(self) -> DefaultServers:
109+
"""Create default servers for this client."""
102110

103111
@abstractmethod
104112
def create_initialization_options(self) -> dict[str, Any]:
@@ -230,21 +238,19 @@ async def _exit(self) -> None:
230238
@asynccontextmanager
231239
@logger.catch(reraise=True)
232240
async def __asynccontextmanager__(self) -> AsyncGenerator[Self]:
233-
if isinstance(self._server, LocalServer):
234-
await self.ensure_installed()
235-
236-
self._hook = build_server_request_hooks(self)
237-
client_capabilities = build_client_capabilities(self.__class__)
241+
self._workspace = format_workspace(self._workspace_arg)
238242

239243
async with (
240244
asyncer.create_task_group() as tg,
241-
channel[ServerRequest].create() as (sender, receiver),
242-
self.get_server().serve(workspace=self.get_workspace(), sender=sender), # ty: ignore[invalid-argument-type]
245+
self._run_server() as (server, receiver), # ty: ignore[invalid-argument-type]
243246
):
247+
self._server = server
248+
244249
# start to receive server requests here,
245250
# since server notification can be sent before `initialize`
246251
tg.soonify(self._dispatch_server_requests)(receiver) # ty: ignore[invalid-argument-type]
247252

253+
client_capabilities = build_client_capabilities(self.__class__)
248254
root_workspace = self.get_workspace().get(DEFAULT_WORKSPACE_DIR)
249255
root_path = root_workspace.path.as_posix() if root_workspace else None
250256
root_uri = root_workspace.uri if root_workspace else None

0 commit comments

Comments
 (0)