Skip to content

Commit a3866fb

Browse files
committed
🚧 OMG THIS WORKS
1 parent 8dd4b4f commit a3866fb

33 files changed

+2055
-76
lines changed

discord/asset.py

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from typing import TYPE_CHECKING, Any, Literal
3131

3232
import yarl
33+
from typing_extensions import Final, override
3334

3435
from . import utils
3536
from .errors import DiscordException, InvalidArgument
@@ -39,6 +40,7 @@
3940
if TYPE_CHECKING:
4041
ValidStaticFormatTypes = Literal["webp", "jpeg", "jpg", "png"]
4142
ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"]
43+
from .state import ConnectionState
4244

4345
VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"})
4446
VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"}
@@ -49,7 +51,7 @@
4951

5052
class AssetMixin:
5153
url: str
52-
_state: Any | None
54+
_state: ConnectionState | None
5355

5456
async def read(self) -> bytes:
5557
"""|coro|
@@ -77,7 +79,9 @@ async def read(self) -> bytes:
7779

7880
async def save(
7981
self,
80-
fp: str | bytes | os.PathLike | io.BufferedIOBase,
82+
fp: (
83+
str | bytes | os.PathLike | io.BufferedIOBase
84+
), # pyright: ignore [reportMissingTypeArgument]
8185
*,
8286
seek_begin: bool = True,
8387
) -> int:
@@ -117,7 +121,7 @@ async def save(
117121
fp.seek(0)
118122
return written
119123
else:
120-
with open(fp, "wb") as f:
124+
with open(fp, "wb") as f: # pyright: ignore [reportUnknownArgumentType]
121125
return f.write(data)
122126

123127

@@ -154,16 +158,23 @@ class Asset(AssetMixin):
154158
"_key",
155159
)
156160

157-
BASE = "https://cdn.discordapp.com"
161+
BASE: Final = "https://cdn.discordapp.com"
158162

159-
def __init__(self, state, *, url: str, key: str, animated: bool = False):
160-
self._state = state
161-
self._url = url
162-
self._animated = animated
163-
self._key = key
163+
def __init__(
164+
self,
165+
state: ConnectionState | None,
166+
*,
167+
url: str,
168+
key: str,
169+
animated: bool = False,
170+
):
171+
self._state: ConnectionState | None = state
172+
self._url: str = url
173+
self._animated: bool = animated
174+
self._key: str = key
164175

165176
@classmethod
166-
def _from_default_avatar(cls, state, index: int) -> Asset:
177+
def _from_default_avatar(cls, state: ConnectionState, index: int) -> Asset:
167178
return cls(
168179
state,
169180
url=f"{cls.BASE}/embed/avatars/{index}.png",
@@ -172,7 +183,7 @@ def _from_default_avatar(cls, state, index: int) -> Asset:
172183
)
173184

174185
@classmethod
175-
def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset:
186+
def _from_avatar(cls, state: ConnectionState, user_id: int, avatar: str) -> Asset:
176187
animated = avatar.startswith("a_")
177188
format = "gif" if animated else "png"
178189
return cls(
@@ -184,7 +195,10 @@ def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset:
184195

185196
@classmethod
186197
def _from_avatar_decoration(
187-
cls, state, user_id: int, avatar_decoration: str
198+
cls,
199+
state: ConnectionState,
200+
user_id: int,
201+
avatar_decoration: str, # pyright: ignore [reportUnusedParameter]
188202
) -> Asset:
189203
animated = avatar_decoration.startswith("a_")
190204
endpoint = (
@@ -201,7 +215,7 @@ def _from_avatar_decoration(
201215

202216
@classmethod
203217
def _from_guild_avatar(
204-
cls, state, guild_id: int, member_id: int, avatar: str
218+
cls, state: ConnectionState, guild_id: int, member_id: int, avatar: str
205219
) -> Asset:
206220
animated = avatar.startswith("a_")
207221
format = "gif" if animated else "png"
@@ -214,7 +228,7 @@ def _from_guild_avatar(
214228

215229
@classmethod
216230
def _from_guild_banner(
217-
cls, state, guild_id: int, member_id: int, banner: str
231+
cls, state: ConnectionState, guild_id: int, member_id: int, banner: str
218232
) -> Asset:
219233
animated = banner.startswith("a_")
220234
format = "gif" if animated else "png"
@@ -226,7 +240,9 @@ def _from_guild_banner(
226240
)
227241

228242
@classmethod
229-
def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
243+
def _from_icon(
244+
cls, state: ConnectionState, object_id: int, icon_hash: str, path: str
245+
) -> Asset:
230246
return cls(
231247
state,
232248
url=f"{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024",
@@ -235,7 +251,9 @@ def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset:
235251
)
236252

237253
@classmethod
238-
def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset:
254+
def _from_cover_image(
255+
cls, state: ConnectionState, object_id: int, cover_image_hash: str
256+
) -> Asset:
239257
return cls(
240258
state,
241259
url=f"{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024",
@@ -244,7 +262,9 @@ def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asse
244262
)
245263

246264
@classmethod
247-
def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset:
265+
def _from_guild_image(
266+
cls, state: ConnectionState, guild_id: int, image: str, path: str
267+
) -> Asset:
248268
animated = False
249269
format = "png"
250270
if path == "banners":
@@ -259,7 +279,9 @@ def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset
259279
)
260280

261281
@classmethod
262-
def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset:
282+
def _from_guild_icon(
283+
cls, state: ConnectionState, guild_id: int, icon_hash: str
284+
) -> Asset:
263285
animated = icon_hash.startswith("a_")
264286
format = "gif" if animated else "png"
265287
return cls(
@@ -270,7 +292,7 @@ def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset:
270292
)
271293

272294
@classmethod
273-
def _from_sticker_banner(cls, state, banner: int) -> Asset:
295+
def _from_sticker_banner(cls, state: ConnectionState, banner: int) -> Asset:
274296
return cls(
275297
state,
276298
url=f"{cls.BASE}/app-assets/710982414301790216/store/{banner}.png",
@@ -279,7 +301,9 @@ def _from_sticker_banner(cls, state, banner: int) -> Asset:
279301
)
280302

281303
@classmethod
282-
def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:
304+
def _from_user_banner(
305+
cls, state: ConnectionState, user_id: int, banner_hash: str
306+
) -> Asset:
283307
animated = banner_hash.startswith("a_")
284308
format = "gif" if animated else "png"
285309
return cls(
@@ -291,7 +315,7 @@ def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset:
291315

292316
@classmethod
293317
def _from_scheduled_event_image(
294-
cls, state, event_id: int, cover_hash: str
318+
cls, state: ConnectionState, event_id: int, cover_hash: str
295319
) -> Asset:
296320
return cls(
297321
state,
@@ -300,24 +324,29 @@ def _from_scheduled_event_image(
300324
animated=False,
301325
)
302326

327+
@override
303328
def __str__(self) -> str:
304329
return self._url
305330

306331
def __len__(self) -> int:
307332
return len(self._url)
308333

334+
@override
309335
def __repr__(self):
310336
shorten = self._url.replace(self.BASE, "")
311337
return f"<Asset url={shorten!r}>"
312338

313-
def __eq__(self, other):
339+
@override
340+
def __eq__(self, other: Any): # pyright: ignore [reportExplicitAny]
314341
return isinstance(other, Asset) and self._url == other._url
315342

343+
@override
316344
def __hash__(self):
317345
return hash(self._url)
318346

319347
@property
320-
def url(self) -> str:
348+
@override
349+
def url(self) -> str: # pyright: ignore [reportIncompatibleVariableOverride]
321350
"""Returns the underlying URL of the asset."""
322351
return self._url
323352

discord/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
import aiohttp
3737

38-
from . import utils
38+
from . import models, utils
3939
from .activity import ActivityTypes, BaseActivity, create_activity
4040
from .appinfo import AppInfo, PartialAppInfo
4141
from .application_role_connection import ApplicationRoleConnectionMetadata
@@ -1840,7 +1840,7 @@ async def fetch_user(self, user_id: int, /) -> User:
18401840
:exc:`HTTPException`
18411841
Fetching the user failed.
18421842
"""
1843-
data = await self.http.get_user(user_id)
1843+
data: models.User = await self.http.get_user(user_id)
18441844
return User(state=self._connection, data=data)
18451845

18461846
async def fetch_channel(

discord/gateway.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@
3535
import traceback
3636
import zlib
3737
from collections import deque, namedtuple
38+
from typing import Any
3839

3940
import aiohttp
41+
from pydantic import BaseModel
42+
43+
from discord import models
4044

4145
from . import utils
4246
from .activity import BaseActivity
@@ -548,11 +552,20 @@ async def received_message(self, msg, /):
548552
)
549553

550554
try:
551-
func = self._discord_parsers[event]
555+
func: Any = self._discord_parsers[event]
552556
except KeyError:
553557
_log.debug("Unknown event %s.", event)
554558
else:
555-
func(data)
559+
if hasattr(func, "_supports_model") and issubclass(
560+
func._supports_model, models.gateway.GatewayEvent
561+
):
562+
func(
563+
func._supports_model(
564+
**msg
565+
).d # pyright: ignore [reportUnknownMemberType, reportAttributeAccessIssue]
566+
)
567+
else:
568+
func(data)
556569

557570
# remove the dispatched listeners
558571
removed = []

discord/http.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@
3333
from urllib.parse import quote as _uriquote
3434

3535
import aiohttp
36+
from typing_extensions import overload
3637

37-
from . import __version__, utils
38+
from . import __version__, models, utils
3839
from .errors import (
3940
DiscordServerError,
4041
Forbidden,
@@ -52,6 +53,8 @@
5253
if TYPE_CHECKING:
5354
from types import TracebackType
5455

56+
from pydantic import BaseModel
57+
5558
from .enums import AuditLogAction, InteractionResponseType
5659
from .file import File
5760
from .types import (
@@ -87,9 +90,11 @@
8790
T = TypeVar("T")
8891
BE = TypeVar("BE", bound=BaseException)
8992
MU = TypeVar("MU", bound="MaybeUnlock")
90-
Response = Coroutine[Any, Any, T]
93+
94+
Response = Coroutine[Any, Any, T] # pyright: ignore [reportExplicitAny]
9195

9296
API_VERSION: int = 10
97+
BM = TypeVar("BM", bound=type["BaseModel"])
9398

9499

95100
async def json_or_text(response: aiohttp.ClientResponse) -> dict[str, Any] | str:
@@ -157,7 +162,7 @@ def __exit__(
157162

158163
# For some reason, the Discord voice websocket expects this header to be
159164
# completely lowercase while aiohttp respects spec and does it as case-insensitive
160-
aiohttp.hdrs.WEBSOCKET = "websocket" # type: ignore
165+
aiohttp.hdrs.WEBSOCKET = "websocket" # type: ignore # pyright: ignore [reportAttributeAccessIssue]
161166

162167

163168
class HTTPClient:
@@ -215,14 +220,36 @@ async def ws_connect(self, url: str, *, compress: int = 0) -> Any:
215220

216221
return await self.__session.ws_connect(url, **kwargs)
217222

223+
@overload
218224
async def request(
219225
self,
220226
route: Route,
221227
*,
222228
files: Sequence[File] | None = None,
223229
form: Iterable[dict[str, Any]] | None = None,
230+
model: None = None,
224231
**kwargs: Any,
225-
) -> Any:
232+
) -> Any: ...
233+
234+
@overload
235+
async def request(
236+
self,
237+
route: Route,
238+
*,
239+
files: None = ...,
240+
form: None = ...,
241+
model: BM,
242+
**kwargs: Any,
243+
) -> BM: ...
244+
async def request(
245+
self,
246+
route: Route,
247+
*,
248+
files: Sequence[File] | None = None,
249+
form: Iterable[dict[str, Any]] | None = None,
250+
model: BM | None = None,
251+
**kwargs: Any,
252+
) -> Any | BM:
226253
bucket = route.bucket
227254
method = route.method
228255
url = route.url
@@ -318,6 +345,10 @@ async def request(
318345
# the request was successful so just return the text/json
319346
if 300 > response.status >= 200:
320347
_log.debug("%s %s has received %s", method, url, data)
348+
if model:
349+
return model(
350+
**data
351+
) # pyright: ignore [reportCallIssue]
321352
return data
322353

323354
# we are being rate limited
@@ -409,7 +440,7 @@ async def close(self) -> None:
409440

410441
# login management
411442

412-
async def static_login(self, token: str) -> user.User:
443+
async def static_login(self, token: str) -> models.User:
413444
# Necessary to get aiohttp to stop complaining about session creation
414445
self.__session = aiohttp.ClientSession(
415446
connector=self.connector, ws_response_class=DiscordClientWebSocketResponse
@@ -418,7 +449,7 @@ async def static_login(self, token: str) -> user.User:
418449
self.token = token
419450

420451
try:
421-
data = await self.request(Route("GET", "/users/@me"))
452+
data = await self.request(Route("GET", "/users/@me"), model=models.User)
422453
except HTTPException as exc:
423454
self.token = old_token
424455
if exc.status == 401:
@@ -3173,5 +3204,7 @@ async def get_bot_gateway(
31733204
value = "{0}?encoding={1}&v={2}"
31743205
return data["shards"], value.format(data["url"], encoding, API_VERSION)
31753206

3176-
def get_user(self, user_id: Snowflake) -> Response[user.User]:
3177-
return self.request(Route("GET", "/users/{user_id}", user_id=user_id))
3207+
def get_user(self, user_id: Snowflake) -> Response[models.User]:
3208+
return self.request(
3209+
Route("GET", "/users/{user_id}", user_id=user_id), model=models.User
3210+
)

0 commit comments

Comments
 (0)