55
66if collab_endpoint .enable ():
77 from asyncio import create_task
8+ from anyio import get_cancelled_exc_class , Lock
9+ from contextlib import asynccontextmanager
810 from fastapi import WebSocket
9- from pycrdt_websocket import WebsocketServer
10- from pycrdt_websocket .websocket import HttpxWebsocket
11+ from httpx_ws import aconnect_ws
12+ from pycrdt import Array , Doc , Provider
13+ from pycrdt .websocket import WebsocketServer
14+ from pycrdt .websocket .websocket import HttpxWebsocket
1115 from .routes import router
1216
1317
18+ class Websocket :
19+ def __init__ (self , websocket , path : str ):
20+ self ._websocket = websocket
21+ self ._path = path
22+ self ._send_lock = Lock ()
23+
24+ @property
25+ def path (self ) -> str :
26+ return self ._path
27+
28+ def __aiter__ (self ):
29+ return self
30+
31+ async def __anext__ (self ) -> bytes :
32+ try :
33+ message = await self .recv ()
34+ except Exception :
35+ raise StopAsyncIteration ()
36+ return message
37+
38+ async def send (self , message : bytes ):
39+ async with self ._send_lock :
40+ await self ._websocket .send_bytes (message )
41+
42+ async def recv (self ) -> bytes :
43+ b = await self ._websocket .receive_bytes ()
44+ return bytes (b )
45+
46+
47+ @asynccontextmanager
48+ async def aprovider_factory (port , room_name , ydoc = None , log = None ):
49+ ydoc = Doc () if ydoc is None else ydoc
50+ server_websocket = None
51+ connect = aconnect_ws (f"http://localhost:{ port } /{ room_name } " )
52+ try :
53+ async with connect as websocket :
54+ websocket_provider = Provider (ydoc , Websocket (websocket , room_name ), log )
55+ async with websocket_provider as websocket_provider :
56+ yield ydoc , server_websocket
57+ except get_cancelled_exc_class ():
58+ pass
59+
60+
61+ def provider_factory (path , doc , log ):
62+ return aprovider_factory (collab_endpoint .port (), path , ydoc = doc , log = log )
63+
64+
1465 @router .websocket ("/collaboration/{path:path}" )
1566 async def websocket_endpoint (path : str , websocket : WebSocket ):
1667 await websocket .accept ()
@@ -21,7 +72,7 @@ async def websocket_endpoint(path: str, websocket: WebSocket):
2172 async def get_websocket_server ():
2273 global WEBSOCKET_SERVER
2374 if WEBSOCKET_SERVER is None :
24- WEBSOCKET_SERVER = WebsocketServer ()
75+ WEBSOCKET_SERVER = WebsocketServer (provider_factory = provider_factory )
2576 create_task (WEBSOCKET_SERVER .start ())
2677 await WEBSOCKET_SERVER .started .wait ()
2778 return WEBSOCKET_SERVER
0 commit comments