|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import platform |
3 | 4 | from collections.abc import AsyncGenerator |
4 | 5 | from contextlib import asynccontextmanager |
5 | | -from pathlib import Path |
6 | 6 | from typing import final, override |
7 | 7 |
|
8 | 8 | import anyio |
9 | | -from anyio.abc import ByteStream |
| 9 | +import tenacity |
| 10 | +from anyio.abc import ByteStream, IPAddressType |
10 | 11 | from anyio.streams.buffered import BufferedByteReceiveStream |
11 | 12 | from attrs import define, field |
12 | 13 | from loguru import logger |
13 | | -from tenacity import AsyncRetrying, stop_after_delay, wait_exponential |
14 | 14 |
|
15 | 15 | from lsp_client.jsonrpc.parse import read_raw_package, write_raw_package |
16 | 16 | from lsp_client.jsonrpc.types import RawPackage |
| 17 | +from lsp_client.server.exception import ServerRuntimeError |
| 18 | +from lsp_client.utils.types import AnyPath |
17 | 19 | from lsp_client.utils.workspace import Workspace |
18 | 20 |
|
19 | 21 | from .abc import Server |
20 | 22 |
|
| 23 | +type TCPSocket = tuple[IPAddressType, int] |
| 24 | +"""(host, port)""" |
| 25 | + |
| 26 | +type UnixSocket = AnyPath |
| 27 | + |
21 | 28 |
|
22 | 29 | @final |
23 | 30 | @define |
24 | 31 | class SocketServer(Server): |
25 | 32 | """Runtime for socket backend, e.g. connecting to a remote LSP server via TCP or Unix socket.""" |
26 | 33 |
|
27 | | - host: str | None = None |
28 | | - """The host to connect to (TCP only).""" |
29 | | - port: int | None = None |
30 | | - """The port to connect to (TCP only).""" |
31 | | - path: Path | str | None = None |
32 | | - """The path to the Unix socket (Unix only).""" |
| 34 | + connection: TCPSocket | UnixSocket |
| 35 | + """Connection information, either (host, port) for TCP or path for Unix socket.""" |
| 36 | + |
33 | 37 | timeout: float = 10.0 |
34 | 38 | """Timeout for connecting to the socket.""" |
35 | 39 |
|
36 | | - _stream: ByteStream | None = field(init=False, default=None) |
37 | | - _buffered: BufferedByteReceiveStream | None = field(init=False, default=None) |
| 40 | + _stream: ByteStream = field(init=False) |
| 41 | + _buffered: BufferedByteReceiveStream = field(init=False) |
| 42 | + |
| 43 | + @tenacity.retry( |
| 44 | + stop=tenacity.stop_after_delay(10), |
| 45 | + wait=tenacity.wait_exponential(multiplier=0.1, max=1), |
| 46 | + reraise=True, |
| 47 | + ) |
| 48 | + async def connect(self) -> ByteStream: |
| 49 | + match self.connection: |
| 50 | + case (host, port): |
| 51 | + logger.debug("Connecting to {}:{}", host, port) |
| 52 | + return await anyio.connect_tcp(host, port) |
| 53 | + case path: |
| 54 | + if platform.platform().startswith("Windows"): |
| 55 | + raise ServerRuntimeError( |
| 56 | + self, "Unix sockets are not supported on Windows" |
| 57 | + ) |
| 58 | + logger.debug("Connecting to {}", path) |
| 59 | + return await anyio.connect_unix(str(path)) |
38 | 60 |
|
39 | 61 | @override |
40 | 62 | async def check_availability(self) -> None: |
41 | | - if self.host is None and self.port is None and self.path is None: |
42 | | - raise ValueError( |
43 | | - "Either host and port (for TCP), or path (for Unix socket) must be provided" |
44 | | - ) |
| 63 | + try: |
| 64 | + stream = await self.connect() |
| 65 | + await stream.aclose() |
| 66 | + except anyio.ConnectionFailed as e: |
| 67 | + raise ServerRuntimeError(self, f"Failed to connect to socket: {e}") from e |
45 | 68 |
|
46 | 69 | @override |
47 | 70 | async def send(self, package: RawPackage) -> None: |
48 | 71 | if self._stream is None: |
49 | | - raise RuntimeError( |
50 | | - "SocketServer is not running. Use 'async with server.run(...)'" |
51 | | - ) |
| 72 | + raise RuntimeError("SocketServer is not running") |
52 | 73 | await write_raw_package(self._stream, package) |
53 | 74 |
|
54 | 75 | @override |
55 | 76 | async def receive(self) -> RawPackage | None: |
56 | 77 | if self._buffered is None: |
57 | | - raise RuntimeError( |
58 | | - "SocketServer is not running. Use 'async with server.run(...)'" |
59 | | - ) |
| 78 | + raise RuntimeError("SocketServer is not running") |
60 | 79 | try: |
61 | 80 | return await read_raw_package(self._buffered) |
62 | 81 | except (anyio.EndOfStream, anyio.IncompleteRead, anyio.ClosedResourceError): |
63 | 82 | logger.debug("Socket closed") |
64 | | - return None |
| 83 | + return |
65 | 84 |
|
66 | 85 | @override |
67 | 86 | async def kill(self) -> None: |
68 | | - if self._stream: |
69 | | - await self._stream.aclose() |
| 87 | + await self._stream.aclose() |
70 | 88 |
|
71 | 89 | @override |
72 | 90 | @asynccontextmanager |
73 | 91 | async def run_process(self, workspace: Workspace) -> AsyncGenerator[None]: |
74 | 92 | await self.check_availability() |
75 | 93 |
|
76 | | - async def connect() -> ByteStream: |
77 | | - if self.host is not None and self.port is not None: |
78 | | - logger.debug("Connecting to {}:{}", self.host, self.port) |
79 | | - return await anyio.connect_tcp(self.host, self.port) |
80 | | - if self.path is not None: |
81 | | - if not hasattr(anyio, "connect_unix"): |
82 | | - raise RuntimeError( |
83 | | - "Unix sockets are not supported on this platform" |
84 | | - ) |
85 | | - logger.debug("Connecting to {}", self.path) |
86 | | - return await anyio.connect_unix(str(self.path)) |
87 | | - raise ValueError("Either host and port, or path must be provided") |
88 | | - |
89 | | - stream: ByteStream | None = None |
90 | | - async for attempt in AsyncRetrying( |
91 | | - stop=stop_after_delay(self.timeout), |
92 | | - wait=wait_exponential(multiplier=0.1, max=1), |
93 | | - reraise=True, |
94 | | - ): |
95 | | - with attempt: |
96 | | - stream = await connect() |
97 | | - |
98 | | - if stream is None: |
99 | | - raise RuntimeError("Failed to connect to socket") |
| 94 | + stream: ByteStream = await self.connect() |
| 95 | + |
| 96 | + self._stream = stream |
| 97 | + self._buffered = BufferedByteReceiveStream(stream) |
100 | 98 |
|
101 | 99 | async with stream: |
102 | | - self._stream = stream |
103 | | - self._buffered = BufferedByteReceiveStream(stream) |
104 | | - try: |
105 | | - yield |
106 | | - finally: |
107 | | - self._stream = None |
108 | | - self._buffered = None |
| 100 | + yield |
0 commit comments