Skip to content

Commit f569528

Browse files
committed
feat: update socket server
1 parent d467041 commit f569528

File tree

1 file changed

+47
-55
lines changed

1 file changed

+47
-55
lines changed

src/lsp_client/server/socket.py

Lines changed: 47 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,100 @@
11
from __future__ import annotations
22

3+
import platform
34
from collections.abc import AsyncGenerator
45
from contextlib import asynccontextmanager
5-
from pathlib import Path
66
from typing import final, override
77

88
import anyio
9-
from anyio.abc import ByteStream
9+
import tenacity
10+
from anyio.abc import ByteStream, IPAddressType
1011
from anyio.streams.buffered import BufferedByteReceiveStream
1112
from attrs import define, field
1213
from loguru import logger
13-
from tenacity import AsyncRetrying, stop_after_delay, wait_exponential
1414

1515
from lsp_client.jsonrpc.parse import read_raw_package, write_raw_package
1616
from lsp_client.jsonrpc.types import RawPackage
17+
from lsp_client.server.exception import ServerRuntimeError
18+
from lsp_client.utils.types import AnyPath
1719
from lsp_client.utils.workspace import Workspace
1820

1921
from .abc import Server
2022

23+
type TCPSocket = tuple[IPAddressType, int]
24+
"""(host, port)"""
25+
26+
type UnixSocket = AnyPath
27+
2128

2229
@final
2330
@define
2431
class SocketServer(Server):
2532
"""Runtime for socket backend, e.g. connecting to a remote LSP server via TCP or Unix socket."""
2633

27-
host: str | None = None
28-
"""The host to connect to (TCP only)."""
29-
port: int | None = None
30-
"""The port to connect to (TCP only)."""
31-
path: Path | str | None = None
32-
"""The path to the Unix socket (Unix only)."""
34+
connection: TCPSocket | UnixSocket
35+
"""Connection information, either (host, port) for TCP or path for Unix socket."""
36+
3337
timeout: float = 10.0
3438
"""Timeout for connecting to the socket."""
3539

36-
_stream: ByteStream | None = field(init=False, default=None)
37-
_buffered: BufferedByteReceiveStream | None = field(init=False, default=None)
40+
_stream: ByteStream = field(init=False)
41+
_buffered: BufferedByteReceiveStream = field(init=False)
42+
43+
@tenacity.retry(
44+
stop=tenacity.stop_after_delay(10),
45+
wait=tenacity.wait_exponential(multiplier=0.1, max=1),
46+
reraise=True,
47+
)
48+
async def connect(self) -> ByteStream:
49+
match self.connection:
50+
case (host, port):
51+
logger.debug("Connecting to {}:{}", host, port)
52+
return await anyio.connect_tcp(host, port)
53+
case path:
54+
if platform.platform().startswith("Windows"):
55+
raise ServerRuntimeError(
56+
self, "Unix sockets are not supported on Windows"
57+
)
58+
logger.debug("Connecting to {}", path)
59+
return await anyio.connect_unix(str(path))
3860

3961
@override
4062
async def check_availability(self) -> None:
41-
if self.host is None and self.port is None and self.path is None:
42-
raise ValueError(
43-
"Either host and port (for TCP), or path (for Unix socket) must be provided"
44-
)
63+
try:
64+
stream = await self.connect()
65+
await stream.aclose()
66+
except anyio.ConnectionFailed as e:
67+
raise ServerRuntimeError(self, f"Failed to connect to socket: {e}") from e
4568

4669
@override
4770
async def send(self, package: RawPackage) -> None:
4871
if self._stream is None:
49-
raise RuntimeError(
50-
"SocketServer is not running. Use 'async with server.run(...)'"
51-
)
72+
raise RuntimeError("SocketServer is not running")
5273
await write_raw_package(self._stream, package)
5374

5475
@override
5576
async def receive(self) -> RawPackage | None:
5677
if self._buffered is None:
57-
raise RuntimeError(
58-
"SocketServer is not running. Use 'async with server.run(...)'"
59-
)
78+
raise RuntimeError("SocketServer is not running")
6079
try:
6180
return await read_raw_package(self._buffered)
6281
except (anyio.EndOfStream, anyio.IncompleteRead, anyio.ClosedResourceError):
6382
logger.debug("Socket closed")
64-
return None
83+
return
6584

6685
@override
6786
async def kill(self) -> None:
68-
if self._stream:
69-
await self._stream.aclose()
87+
await self._stream.aclose()
7088

7189
@override
7290
@asynccontextmanager
7391
async def run_process(self, workspace: Workspace) -> AsyncGenerator[None]:
7492
await self.check_availability()
7593

76-
async def connect() -> ByteStream:
77-
if self.host is not None and self.port is not None:
78-
logger.debug("Connecting to {}:{}", self.host, self.port)
79-
return await anyio.connect_tcp(self.host, self.port)
80-
if self.path is not None:
81-
if not hasattr(anyio, "connect_unix"):
82-
raise RuntimeError(
83-
"Unix sockets are not supported on this platform"
84-
)
85-
logger.debug("Connecting to {}", self.path)
86-
return await anyio.connect_unix(str(self.path))
87-
raise ValueError("Either host and port, or path must be provided")
88-
89-
stream: ByteStream | None = None
90-
async for attempt in AsyncRetrying(
91-
stop=stop_after_delay(self.timeout),
92-
wait=wait_exponential(multiplier=0.1, max=1),
93-
reraise=True,
94-
):
95-
with attempt:
96-
stream = await connect()
97-
98-
if stream is None:
99-
raise RuntimeError("Failed to connect to socket")
94+
stream: ByteStream = await self.connect()
95+
96+
self._stream = stream
97+
self._buffered = BufferedByteReceiveStream(stream)
10098

10199
async with stream:
102-
self._stream = stream
103-
self._buffered = BufferedByteReceiveStream(stream)
104-
try:
105-
yield
106-
finally:
107-
self._stream = None
108-
self._buffered = None
100+
yield

0 commit comments

Comments
 (0)