11from __future__ import annotations
22
3- from contextlib import AsyncExitStack
3+ from collections .abc import AsyncGenerator
4+ from contextlib import AsyncExitStack , asynccontextmanager
5+ from functools import cached_property
46from pathlib import Path
7+ from typing import Self , override
58
69import anyio
710import asyncer
11+ import httpx
12+ import loguru
813import uvicorn
914import xxhash
10- from attrs import define , field
11- from litestar import Litestar , post
15+ from attrs import Factory , define , field
16+ from litestar import Litestar , get , post
1217from loguru import logger
13- from lsp_client .jsonrpc .types import RawRequest , RawResponsePackage
18+ from lsp_client import Client
19+ from lsp_client .jsonrpc .types import (
20+ RawNotification ,
21+ RawRequest ,
22+ RawRequestPackage ,
23+ RawResponsePackage ,
24+ )
25+ from lsp_client .server import Server , ServerRuntimeError
26+ from lsp_client .server .types import ServerRequest
27+ from lsp_client .utils .channel import Sender
28+ from lsp_client .utils .workspace import Workspace , format_workspace
1429
1530from lsp_cli .client import TargetClient
16- from lsp_cli .settings import RUNTIME_DIR , settings
31+ from lsp_cli .settings import LOG_DIR , RUNTIME_DIR , settings
1732
18- from .models import ManagedClientInfo
33+ from .models import HandshakeRequest , HandshakeResponse , ManagedClientInfo
34+
35+
36+ @define
37+ class ManagerServer (Server ):
38+ uds_path : Path
39+
40+ @cached_property
41+ def client (self ) -> httpx .AsyncClient :
42+ transport = httpx .AsyncHTTPTransport (uds = self .uds_path .as_posix ())
43+ return httpx .AsyncClient (
44+ transport = transport ,
45+ base_url = "http://localhost" ,
46+ timeout = None ,
47+ )
48+
49+ @override
50+ async def check_availability (self ) -> None :
51+ if not self .uds_path .exists ():
52+ raise ServerRuntimeError (self , f"Server socket not found: { self .uds_path } " )
53+ try :
54+ await self .client .get ("/health" )
55+ except httpx .HTTPError as e :
56+ raise ServerRuntimeError (
57+ self , f"Managed server at { self .uds_path } is not responding: { e } "
58+ ) from e
59+
60+ @override
61+ async def request (self , request : RawRequest ) -> RawResponsePackage :
62+ response = await self .client .post ("/" , json = request )
63+ return response .json ()
64+
65+ @override
66+ async def notify (self , notification : RawNotification ) -> None :
67+ await self .client .post ("/" , json = notification )
68+
69+ @override
70+ async def kill (self ) -> None :
71+ # await self.client.post("/shutdown")
72+ return
73+
74+ @override
75+ @asynccontextmanager
76+ async def run (
77+ self , workspace : Workspace , sender : Sender [ServerRequest ]
78+ ) -> AsyncGenerator [Self ]:
79+ if False :
80+ handshake = HandshakeRequest (
81+ client_id = self .uds_path .stem ,
82+ workspace = {name : folder .uri for name , folder in workspace .items ()},
83+ )
84+ resp = await self .client .post (
85+ "/handshake" , content = handshake .model_dump_json ()
86+ )
87+ json = resp .raise_for_status ().json ()
88+ result = HandshakeResponse .model_validate (json )
89+ if result .status != "ok" :
90+ raise ServerRuntimeError (
91+ self ,
92+ f"Handshake rejected: { result .reason } (current: { result .current_client_id or result .current_workspace } )" ,
93+ )
94+
95+ yield self
1996
2097
2198def get_client_id (target : TargetClient ) -> str :
@@ -30,11 +107,15 @@ class ManagedClient:
30107 _server : uvicorn .Server | None = field (init = False , default = None )
31108 _timeout_scope : anyio .CancelScope | None = field (init = False , default = None )
32109 _server_scope : anyio .CancelScope | None = field (init = False , default = None )
33- _deadline : float = field (init = False )
34- _should_exit : bool = field (init = False , default = False )
35110
36- def __attrs_post_init__ (self ) -> None :
37- self ._deadline = anyio .current_time () + settings .idle_timeout
111+ _deadline : float = Factory (lambda : anyio .current_time () + settings .idle_timeout )
112+ _should_exit : bool = False
113+
114+ _log_handler_id : int | None = field (init = False , default = None )
115+
116+ @cached_property
117+ def logger (self ) -> loguru .Logger :
118+ return logger .bind (client_id = self .id )
38119
39120 @property
40121 def id (self ) -> str :
@@ -53,6 +134,7 @@ def info(self) -> ManagedClientInfo:
53134 )
54135
55136 def stop (self ) -> None :
137+ self .logger .info ("Stopping managed client {}" , self .id )
56138 self ._should_exit = True
57139 if self ._server :
58140 self ._server .should_exit = True
@@ -62,22 +144,86 @@ def stop(self) -> None:
62144 self ._timeout_scope .cancel ()
63145
64146 async def run (self ) -> None :
147+ client_log_dir = LOG_DIR / "clients"
148+ client_log_dir .mkdir (parents = True , exist_ok = True )
149+ self ._log_handler_id = logger .add (
150+ client_log_dir / f"{ self .id } .log" ,
151+ rotation = "10 MB" ,
152+ retention = "1 day" ,
153+ level = "DEBUG" ,
154+ filter = lambda record : record ["extra" ].get ("client_id" ) == self .id ,
155+ )
156+
157+ self .logger .info (
158+ "Starting managed client {} for project {} at {}" ,
159+ self .id ,
160+ self .target .project_path ,
161+ self .uds_path ,
162+ )
65163 async with AsyncExitStack () as stack :
164+ if self ._log_handler_id is not None :
165+ stack .callback (lambda : logger .remove (self ._log_handler_id ))
166+
66167 uds_path = anyio .Path (self .uds_path )
67168 await uds_path .unlink (missing_ok = True )
68169 await uds_path .parent .mkdir (parents = True , exist_ok = True )
69170 stack .push_async_callback (uds_path .unlink , missing_ok = True )
70171
71- lsp_client = await stack .enter_async_context (self .target .client_cls ())
172+ client : Client = await stack .enter_async_context (
173+ self .target .client_cls (workspace = self .target .project_path )
174+ )
175+ self .logger .info ("LSP client for {} initialized" , self .id )
176+
177+ @get ("/health" )
178+ async def health () -> str :
179+ return "ok"
180+
181+ @post ("/handshake" )
182+ async def handshake (data : HandshakeRequest ) -> HandshakeResponse :
183+ self .logger .debug ("Handshake received for client {}" , self .id )
184+ self ._reset_timeout ()
185+ if data .client_id != self .id :
186+ return HandshakeResponse (
187+ status = "error" ,
188+ reason = "id_mismatch" ,
189+ current_client_id = self .id ,
190+ )
191+ received_ws = format_workspace (data .workspace )
192+ if client .get_workspace () != received_ws :
193+ return HandshakeResponse (
194+ status = "error" ,
195+ reason = "workspace_mismatch" ,
196+ current_workspace = {
197+ name : folder .uri
198+ for name , folder in client .get_workspace ().items ()
199+ },
200+ )
201+ return HandshakeResponse (status = "ok" )
202+
203+ @post ("/shutdown" )
204+ async def shutdown () -> None :
205+ self .logger .info ("Shutdown requested for client {}" , self .id )
206+ self .stop ()
72207
73208 @post ("/" )
74- async def handle_request (data : RawRequest ) -> RawResponsePackage :
209+ async def handle_package (
210+ data : RawRequestPackage ,
211+ ) -> RawResponsePackage | None :
75212 self ._reset_timeout ()
76- return await lsp_client .get_server ().request (data )
213+ server = client .get_server ()
214+
215+ match data :
216+ case {"id" : _}:
217+ return await server .request (data )
218+ case _:
219+ await server .notify (data )
77220
78- app = Litestar (route_handlers = [handle_request ])
221+ app = Litestar (route_handlers = [handle_package , health , shutdown ])
79222 config = uvicorn .Config (
80- app , uds = str (self .uds_path ), log_level = "error" , loop = "asyncio"
223+ app ,
224+ uds = str (self .uds_path ),
225+ log_level = "error" ,
226+ loop = "asyncio" ,
81227 )
82228 self ._server = uvicorn .Server (config )
83229
@@ -112,7 +258,6 @@ async def _timeout_loop(self) -> None:
112258 self ._timeout_scope = scope
113259 await anyio .sleep (remaining )
114260
115- logger .info ("Client {} idle timeout, shutting down" , self .id )
116261 if self ._server :
117262 self ._server .should_exit = True
118263 if self ._server_scope :
0 commit comments