Skip to content
136 changes: 90 additions & 46 deletions discord/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,12 @@
from .ui.view import View
from .user import ClientUser, User
from .utils import MISSING
from .utils.private import SequenceProxy, bytes_to_base64_data, resolve_invite, resolve_template
from .utils.private import (
SequenceProxy,
bytes_to_base64_data,
resolve_invite,
resolve_template,
)
from .voice_client import VoiceClient
from .webhook import Webhook
from .widget import Widget
Expand Down Expand Up @@ -234,8 +239,12 @@ def __init__(
self._banner_module = options.get("banner_module")
# self.ws is set in the connect method
self.ws: DiscordWebSocket = None # type: ignore
self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop
self._listeners: dict[str, list[tuple[asyncio.Future, Callable[..., bool]]]] = {}
self.loop: asyncio.AbstractEventLoop = (
asyncio.get_event_loop() if loop is None else loop
)
self._listeners: dict[
str, list[tuple[asyncio.Future, Callable[..., bool]]]
] = {}
self.shard_id: int | None = options.get("shard_id")
self.shard_count: int | None = options.get("shard_count")

Expand All @@ -253,7 +262,9 @@ def __init__(

self._handlers: dict[str, Callable] = {"ready": self._handle_ready}

self._hooks: dict[str, Callable] = {"before_identify": self._call_before_identify_hook}
self._hooks: dict[str, Callable] = {
"before_identify": self._call_before_identify_hook
}

self._enable_debug_events: bool = options.pop("enable_debug_events", False)
self._connection: ConnectionState = self._get_state(**options)
Expand Down Expand Up @@ -292,7 +303,9 @@ async def __aexit__(

# internals

def _get_websocket(self, guild_id: int | None = None, *, shard_id: int | None = None) -> DiscordWebSocket:
def _get_websocket(
self, guild_id: int | None = None, *, shard_id: int | None = None
) -> DiscordWebSocket:
return self.ws

def _get_state(self, **options: Any) -> ConnectionState:
Expand Down Expand Up @@ -541,7 +554,9 @@ async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None:
print(f"Ignoring exception in {event_method}", file=sys.stderr)
traceback.print_exc()

async def on_view_error(self, error: Exception, item: Item, interaction: Interaction) -> None:
async def on_view_error(
self, error: Exception, item: Item, interaction: Interaction
) -> None:
"""|coro|

The default view error handler provided by the client.
Expand All @@ -553,7 +568,9 @@ async def on_view_error(self, error: Exception, item: Item, interaction: Interac
f"Ignoring exception in view {interaction.view} for item {item}:",
file=sys.stderr,
)
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
traceback.print_exception(
error.__class__, error, error.__traceback__, file=sys.stderr
)

async def on_modal_error(self, error: Exception, interaction: Interaction) -> None:
"""|coro|
Expand All @@ -565,17 +582,23 @@ async def on_modal_error(self, error: Exception, interaction: Interaction) -> No
"""

print(f"Ignoring exception in modal {interaction.modal}:", file=sys.stderr)
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
traceback.print_exception(
error.__class__, error, error.__traceback__, file=sys.stderr
)

# hooks

async def _call_before_identify_hook(self, shard_id: int | None, *, initial: bool = False) -> None:
async def _call_before_identify_hook(
self, shard_id: int | None, *, initial: bool = False
) -> None:
# This hook is an internal hook that actually calls the public one.
# It allows the library to have its own hook without stepping on the
# toes of those who need to override their own hook.
await self.before_identify_hook(shard_id, initial=initial)

async def before_identify_hook(self, shard_id: int | None, *, initial: bool = False) -> None:
async def before_identify_hook(
self, shard_id: int | None, *, initial: bool = False
) -> None:
"""|coro|

A hook that is called before IDENTIFYing a session. This is useful
Expand Down Expand Up @@ -622,14 +645,19 @@ async def login(self, token: str) -> None:
passing status code.
"""
if not isinstance(token, str):
raise TypeError(f"token must be of type str, not {token.__class__.__name__}")
raise TypeError(
f"token must be of type str, not {token.__class__.__name__}"
)

_log.info("logging in using static token")

data = await self.http.static_login(token.strip())
self._connection.user = ClientUser(state=self._connection, data=data)

print_banner(bot_name=self._connection.user.display_name, module=self._banner_module or "discord")
print_banner(
bot_name=self._connection.user.display_name,
module=self._banner_module or "discord",
)
start_logging(self._flavor, debug=self._debug)

async def connect(self, *, reconnect: bool = True) -> None:
Expand Down Expand Up @@ -726,7 +754,9 @@ async def connect(self, *, reconnect: bool = True) -> None:
# This is apparently what the official Discord client does.
if self.ws is None:
continue
ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id)
ws_params.update(
sequence=self.ws.sequence, resume=True, session=self.ws.session_id
)

async def close(self) -> None:
"""|coro|
Expand Down Expand Up @@ -894,7 +924,9 @@ def allowed_mentions(self, value: AllowedMentions | None) -> None:
if value is None or isinstance(value, AllowedMentions):
self._connection.allowed_mentions = value
else:
raise TypeError(f"allowed_mentions must be AllowedMentions not {value.__class__!r}")
raise TypeError(
f"allowed_mentions must be AllowedMentions not {value.__class__!r}"
)

@property
def intents(self) -> Intents:
Expand Down Expand Up @@ -968,7 +1000,9 @@ def get_message(self, id: int, /) -> Message | None:
"""
return self._connection._get_message(id)

def get_partial_messageable(self, id: int, *, type: ChannelType | None = None) -> PartialMessageable:
def get_partial_messageable(
self, id: int, *, type: ChannelType | None = None
) -> PartialMessageable:
"""Returns a partial messageable with the given channel ID.

This is useful if you have a channel_id but don't want to do an API call
Expand Down Expand Up @@ -1130,24 +1164,6 @@ def get_all_members(self) -> Generator[Member]:
for guild in self.guilds:
yield from guild.members

async def get_or_fetch_user(self, id: int, /) -> User | None:
"""|coro|

Looks up a user in the user cache or fetches if not found.

Parameters
----------
id: :class:`int`
The ID to search for.

Returns
-------
Optional[:class:`~discord.User`]
The user or ``None`` if not found.
"""

return await utils.get_or_fetch(obj=self, attr="user", id=id, default=None)

# listeners/waiters

async def wait_until_ready(self) -> None:
Expand Down Expand Up @@ -1315,7 +1331,9 @@ async def my_message(message):
name,
)

def remove_listener(self, func: Coro, name: str | utils.Undefined = MISSING) -> None:
def remove_listener(
self, func: Coro, name: str | utils.Undefined = MISSING
) -> None:
"""Removes a listener from the pool of listeners.

Parameters
Expand All @@ -1335,7 +1353,9 @@ def remove_listener(self, func: Coro, name: str | utils.Undefined = MISSING) ->
except ValueError:
pass

def listen(self, name: str | utils.Undefined = MISSING, once: bool = False) -> Callable[[Coro], Coro]:
def listen(
self, name: str | utils.Undefined = MISSING, once: bool = False
) -> Callable[[Coro], Coro]:
"""A decorator that registers another function as an external
event listener. Basically this allows you to listen to multiple
events from different places e.g. such as :func:`.on_ready`
Expand Down Expand Up @@ -1547,7 +1567,9 @@ def fetch_guilds(

All parameters are optional.
"""
return GuildIterator(self, limit=limit, before=before, after=after, with_counts=with_counts)
return GuildIterator(
self, limit=limit, before=before, after=after, with_counts=with_counts
)

async def fetch_template(self, code: Template | str) -> Template:
"""|coro|
Expand Down Expand Up @@ -1866,7 +1888,9 @@ async def fetch_user(self, user_id: int, /) -> User:
data = await self.http.get_user(user_id)
return User(state=self._connection, data=data)

async def fetch_channel(self, channel_id: int, /) -> GuildChannel | PrivateChannel | Thread:
async def fetch_channel(
self, channel_id: int, /
) -> GuildChannel | PrivateChannel | Thread:
"""|coro|

Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID.
Expand Down Expand Up @@ -1897,7 +1921,9 @@ async def fetch_channel(self, channel_id: int, /) -> GuildChannel | PrivateChann

factory, ch_type = _threaded_channel_factory(data["type"])
if factory is None:
raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(data))
raise InvalidData(
"Unknown channel type {type} for channel ID {id}.".format_map(data)
)

if ch_type in (ChannelType.group, ChannelType.private):
# the factory will be a DMChannel or GroupChannel here
Expand Down Expand Up @@ -1971,7 +1997,10 @@ async def fetch_premium_sticker_packs(self) -> list[StickerPack]:
Retrieving the sticker packs failed.
"""
data = await self.http.list_premium_sticker_packs()
return [StickerPack(state=self._connection, data=pack) for pack in data["sticker_packs"]]
return [
StickerPack(state=self._connection, data=pack)
for pack in data["sticker_packs"]
]

async def create_dm(self, user: Snowflake) -> DMChannel:
"""|coro|
Expand Down Expand Up @@ -2031,7 +2060,9 @@ def add_view(self, view: View, *, message_id: int | None = None) -> None:
raise TypeError(f"expected an instance of View not {view.__class__!r}")

if not view.is_persistent():
raise ValueError("View is not persistent. Items need to have a custom_id set and View must have no timeout")
raise ValueError(
"View is not persistent. Items need to have a custom_id set and View must have no timeout"
)

self._connection.store_view(view, message_id)

Expand All @@ -2057,7 +2088,9 @@ async def fetch_role_connection_metadata_records(
List[:class:`.ApplicationRoleConnectionMetadata`]
The bot's role connection metadata records.
"""
data = await self._connection.http.get_application_role_connection_metadata_records(self.application_id)
data = await self._connection.http.get_application_role_connection_metadata_records(
self.application_id
)
return [ApplicationRoleConnectionMetadata.from_dict(r) for r in data]

async def update_role_connection_metadata_records(
Expand Down Expand Up @@ -2196,8 +2229,13 @@ async def fetch_emojis(self) -> list[AppEmoji]:
List[:class:`AppEmoji`]
The retrieved emojis.
"""
data = await self._connection.http.get_all_application_emojis(self.application_id)
return [self._connection.maybe_store_app_emoji(self.application_id, d) for d in data["items"]]
data = await self._connection.http.get_all_application_emojis(
self.application_id
)
return [
self._connection.maybe_store_app_emoji(self.application_id, d)
for d in data["items"]
]

async def fetch_emoji(self, emoji_id: int, /) -> AppEmoji:
"""|coro|
Expand All @@ -2221,7 +2259,9 @@ async def fetch_emoji(self, emoji_id: int, /) -> AppEmoji:
HTTPException
An error occurred fetching the emoji.
"""
data = await self._connection.http.get_application_emoji(self.application_id, emoji_id)
data = await self._connection.http.get_application_emoji(
self.application_id, emoji_id
)
return self._connection.maybe_store_app_emoji(self.application_id, data)

async def create_emoji(
Expand Down Expand Up @@ -2256,7 +2296,9 @@ async def create_emoji(
"""

img = bytes_to_base64_data(image)
data = await self._connection.http.create_application_emoji(self.application_id, name, img)
data = await self._connection.http.create_application_emoji(
self.application_id, name, img
)
return self._connection.maybe_store_app_emoji(self.application_id, data)

async def delete_emoji(self, emoji: Snowflake) -> None:
Expand All @@ -2275,6 +2317,8 @@ async def delete_emoji(self, emoji: Snowflake) -> None:
An error occurred deleting the emoji.
"""

await self._connection.http.delete_application_emoji(self.application_id, emoji.id)
await self._connection.http.delete_application_emoji(
self.application_id, emoji.id
)
if self._connection.cache_app_emojis and self._connection.get_emoji(emoji.id):
self._connection.remove_emoji(emoji)
14 changes: 10 additions & 4 deletions discord/guild.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ def _sync(self, data: GuildPayload) -> None:
if "channels" in data:
channels = data["channels"]
for c in channels:
factory, ch_type = _guild_channel_factory(c["type"])
factory, _ = _guild_channel_factory(c["type"])
if factory:
self._add_channel(factory(guild=self, data=c, state=self._state)) # type: ignore

Expand Down Expand Up @@ -1020,7 +1020,10 @@ def get_member_named(self, name: str, /) -> Member | None:

# do the actual lookup and return if found
# if it isn't found then we'll do a full name lookup below.
result = utils.find(lambda m: m.name == name[:-5] and discriminator == potential_discriminator, members)
result = utils.find(
lambda m: m.name == name[:-5] and discriminator == potential_discriminator,
members,
)
if result is not None:
return result

Expand Down Expand Up @@ -1885,7 +1888,7 @@ async def fetch_channels(self) -> Sequence[GuildChannel]:
data = await self._state.http.get_all_guild_channels(self.id)

def convert(d):
factory, ch_type = _guild_channel_factory(d["type"])
factory, _ = _guild_channel_factory(d["type"])
if factory is None:
raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(d))

Expand Down Expand Up @@ -3312,7 +3315,10 @@ async def widget(self) -> Widget:
return Widget(state=self._state, data=data)

async def edit_widget(
self, *, enabled: bool | utils.Undefined = MISSING, channel: Snowflake | None | utils.Undefined = MISSING
self,
*,
enabled: bool | utils.Undefined = MISSING,
channel: Snowflake | None | utils.Undefined = MISSING,
) -> None:
"""|coro|

Expand Down
2 changes: 1 addition & 1 deletion discord/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ def parse_channel_update(self, data) -> None:
)

def parse_channel_create(self, data) -> None:
factory, ch_type = _channel_factory(data["type"])
factory, _ = _channel_factory(data["type"])
if factory is None:
_log.debug(
"CHANNEL_CREATE referencing an unknown channel type %s. Discarding.",
Expand Down
Loading
Loading