Skip to content

Commit d53daf4

Browse files
fix: AiohttpClientInterface's request
1 parent 20e63e0 commit d53daf4

File tree

2 files changed

+24
-14
lines changed

2 files changed

+24
-14
lines changed

src/graia/amnesia/builtins/aiohttp.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import asyncio
24
import contextlib
35
import pathlib
@@ -38,7 +40,7 @@
3840
from graia.amnesia.transport.common.http import AbstractServerRequestIO
3941
from graia.amnesia.transport.common.http import HttpEndpoint as HttpEndpoint
4042
from graia.amnesia.transport.common.http.extra import HttpRequest, HttpResponse
41-
from graia.amnesia.transport.common.http.io import AbstactClientRequestIO
43+
from graia.amnesia.transport.common.http.io import AbstractClientRequestIO
4244
from graia.amnesia.transport.common.status import ConnectionStatus
4345
from graia.amnesia.transport.common.websocket import AbstractWebsocketIO
4446
from graia.amnesia.transport.common.websocket import (
@@ -84,11 +86,14 @@ async def wait_for_drop(self) -> None:
8486
await self.wait_for_update()
8587

8688

87-
class ClientRequestIO(AbstactClientRequestIO):
89+
class ClientRequestIO(AbstractClientRequestIO):
90+
rider: ClientConnectionRider[ClientResponse]
8891
response: ClientResponse
8992

90-
def __init__(self, response: ClientResponse) -> None:
91-
self.response = response
93+
def __init__(self, rider: ClientConnectionRider) -> None:
94+
assert rider.response
95+
self.rider = rider
96+
self.response = rider.response
9297

9398
async def read(self) -> bytes:
9499
return await self.response.read()
@@ -102,12 +107,18 @@ async def extra(self, signature):
102107
self.response.url,
103108
)
104109

110+
def close(self):
111+
self.rider.status.drop = True
112+
105113

106114
class ClientWebsocketIO(AbstractWebsocketIO):
115+
rider: ClientConnectionRider[ClientWebSocketResponse]
107116
connection: ClientWebSocketResponse
108117

109-
def __init__(self, connection: ClientWebSocketResponse) -> None:
110-
self.connection = connection
118+
def __init__(self, rider: ClientConnectionRider) -> None:
119+
assert rider.response
120+
self.rider = rider
121+
self.connection = rider.response
111122

112123
async def cookies(self) -> Dict[str, str]:
113124
return {k: v.value for k, v in self.connection._response.cookies.items()}
@@ -164,7 +175,7 @@ async def wait_for_ready(self):
164175
class ClientConnectionRider(TransportRider[str, T], Generic[T]):
165176
def __init__(
166177
self,
167-
interface: "AiohttpClientInterface",
178+
interface: AiohttpClientInterface,
168179
conn_func: Callable[..., AsyncContextManager[T]],
169180
call_param: Dict[str, Any],
170181
) -> None:
@@ -220,11 +231,10 @@ def io(self, id=None) -> ...:
220231
if not self.status.connected:
221232
raise RuntimeError("the connection is not ready, please await the instance to ensure connection")
222233
assert self.response
223-
self.status.drop = True
224234
if isinstance(self.response, ClientWebSocketResponse):
225-
return ClientWebsocketIO(self.response)
235+
return ClientWebsocketIO(self)
226236
elif isinstance(self.response, ClientResponse):
227-
return ClientRequestIO(self.response)
237+
return ClientRequestIO(self)
228238
else:
229239
raise TypeError("this response is not a ClientResponse or ClientWebSocketResponse")
230240

@@ -246,7 +256,7 @@ async def connection_manage(self):
246256
assert isinstance(
247257
self.response, ClientWebSocketResponse
248258
), f"{self.response} is not a ClientWebSocketResponse"
249-
io = ClientWebsocketIO(self.response)
259+
io = ClientWebsocketIO(self)
250260
await self.trigger_callbacks(WebsocketConnectEvent, io)
251261
with contextlib.suppress(ConnectionClosed):
252262
async for data in io.packets():
@@ -285,9 +295,9 @@ def use(self, transport: Transport):
285295

286296

287297
class AiohttpClientInterface(ExportInterface["AiohttpService"]):
288-
service: "AiohttpService"
298+
service: AiohttpService
289299

290-
def __init__(self, service: "AiohttpService") -> None:
300+
def __init__(self, service: AiohttpService) -> None:
291301
self.service = service
292302

293303
def request(

src/graia/amnesia/transport/common/http/io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ async def cookies(self) -> Dict[str, str]:
2323
return req.cookies
2424

2525

26-
class AbstactClientRequestIO(ReadonlyIO[bytes]):
26+
class AbstractClientRequestIO(ReadonlyIO[bytes]):
2727
@abstractmethod
2828
async def read(self) -> bytes:
2929
raise NotImplementedError

0 commit comments

Comments
 (0)