1+ from __future__ import annotations
2+
13import asyncio
24import contextlib
35import pathlib
3840from graia .amnesia .transport .common .http import AbstractServerRequestIO
3941from graia .amnesia .transport .common .http import HttpEndpoint as HttpEndpoint
4042from graia .amnesia .transport .common .http .extra import HttpRequest , HttpResponse
41- from graia .amnesia .transport .common .http .io import AbstactClientRequestIO
43+ from graia .amnesia .transport .common .http .io import AbstractClientRequestIO
4244from graia .amnesia .transport .common .status import ConnectionStatus
4345from graia .amnesia .transport .common .websocket import AbstractWebsocketIO
4446from graia .amnesia .transport .common .websocket import (
@@ -84,11 +86,14 @@ async def wait_for_drop(self) -> None:
8486 await self .wait_for_update ()
8587
8688
87- class ClientRequestIO (AbstactClientRequestIO ):
89+ class ClientRequestIO (AbstractClientRequestIO ):
90+ rider : ClientConnectionRider [ClientResponse ]
8891 response : ClientResponse
8992
90- def __init__ (self , response : ClientResponse ) -> None :
91- self .response = response
93+ def __init__ (self , rider : ClientConnectionRider ) -> None :
94+ assert rider .response
95+ self .rider = rider
96+ self .response = rider .response
9297
9398 async def read (self ) -> bytes :
9499 return await self .response .read ()
@@ -102,12 +107,18 @@ async def extra(self, signature):
102107 self .response .url ,
103108 )
104109
110+ def close (self ):
111+ self .rider .status .drop = True
112+
105113
106114class ClientWebsocketIO (AbstractWebsocketIO ):
115+ rider : ClientConnectionRider [ClientWebSocketResponse ]
107116 connection : ClientWebSocketResponse
108117
109- def __init__ (self , connection : ClientWebSocketResponse ) -> None :
110- self .connection = connection
118+ def __init__ (self , rider : ClientConnectionRider ) -> None :
119+ assert rider .response
120+ self .rider = rider
121+ self .connection = rider .response
111122
112123 async def cookies (self ) -> Dict [str , str ]:
113124 return {k : v .value for k , v in self .connection ._response .cookies .items ()}
@@ -164,7 +175,7 @@ async def wait_for_ready(self):
164175class ClientConnectionRider (TransportRider [str , T ], Generic [T ]):
165176 def __init__ (
166177 self ,
167- interface : " AiohttpClientInterface" ,
178+ interface : AiohttpClientInterface ,
168179 conn_func : Callable [..., AsyncContextManager [T ]],
169180 call_param : Dict [str , Any ],
170181 ) -> None :
@@ -220,11 +231,10 @@ def io(self, id=None) -> ...:
220231 if not self .status .connected :
221232 raise RuntimeError ("the connection is not ready, please await the instance to ensure connection" )
222233 assert self .response
223- self .status .drop = True
224234 if isinstance (self .response , ClientWebSocketResponse ):
225- return ClientWebsocketIO (self . response )
235+ return ClientWebsocketIO (self )
226236 elif isinstance (self .response , ClientResponse ):
227- return ClientRequestIO (self . response )
237+ return ClientRequestIO (self )
228238 else :
229239 raise TypeError ("this response is not a ClientResponse or ClientWebSocketResponse" )
230240
@@ -246,7 +256,7 @@ async def connection_manage(self):
246256 assert isinstance (
247257 self .response , ClientWebSocketResponse
248258 ), f"{ self .response } is not a ClientWebSocketResponse"
249- io = ClientWebsocketIO (self . response )
259+ io = ClientWebsocketIO (self )
250260 await self .trigger_callbacks (WebsocketConnectEvent , io )
251261 with contextlib .suppress (ConnectionClosed ):
252262 async for data in io .packets ():
@@ -285,9 +295,9 @@ def use(self, transport: Transport):
285295
286296
287297class AiohttpClientInterface (ExportInterface ["AiohttpService" ]):
288- service : " AiohttpService"
298+ service : AiohttpService
289299
290- def __init__ (self , service : " AiohttpService" ) -> None :
300+ def __init__ (self , service : AiohttpService ) -> None :
291301 self .service = service
292302
293303 def request (
0 commit comments