Skip to content

Commit 8a2bd59

Browse files
committed
refactor!: remove View and Model Store & make Guild._from_data into classmethod
1 parent a824153 commit 8a2bd59

File tree

10 files changed

+30
-161
lines changed

10 files changed

+30
-161
lines changed

discord/app/cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@
3535
from ..guild import Guild
3636
from ..poll import Poll
3737
from ..sticker import GuildSticker, Sticker
38-
from ..ui.modal import Modal, ModalStore
39-
from ..ui.view import View, ViewStore
38+
from ..ui.modal import Modal
39+
from ..ui.view import View
4040
from ..user import User
4141
from ..types.user import User as UserPayload
4242
from ..types.emoji import Emoji as EmojiPayload

discord/app/state.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@
7373
from ..stage_instance import StageInstance
7474
from ..sticker import GuildSticker
7575
from ..threads import Thread, ThreadMember
76-
from ..ui.modal import Modal, ModalStore
77-
from ..ui.view import View, ViewStore
76+
from ..ui.modal import Modal
77+
from ..ui.view import View
7878
from ..user import ClientUser, User
7979

8080
if TYPE_CHECKING:

discord/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1584,7 +1584,7 @@ async def fetch_guild(self, guild_id: int, /, *, with_counts=True) -> Guild:
15841584
Getting the guild failed.
15851585
"""
15861586
data = await self.http.get_guild(guild_id, with_counts=with_counts)
1587-
return Guild(data=data, state=self._connection)
1587+
return await Guild._from_data(data=data, state=self._connection)
15881588

15891589
async def create_guild(
15901590
self,
@@ -1633,7 +1633,7 @@ async def create_guild(
16331633
data = await self.http.create_from_template(code, name, icon_base64)
16341634
else:
16351635
data = await self.http.create_guild(name, icon_base64)
1636-
return Guild(data=data, state=self._connection)
1636+
return await Guild._from_data(data=data, state=self._connection)
16371637

16381638
async def fetch_stage_instance(self, channel_id: int, /) -> StageInstance:
16391639
"""|coro|

discord/events/gateway.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ async def __load__(cls, data: dict[str, Any], state: ConnectionState) -> Self:
8282
self.guilds = []
8383

8484
for guild_data in data["guilds"]:
85-
guild = await Guild(data=guild_data, state=state)._from_data(guild_data)
85+
guild = await Guild._from_data(guild_data, state)
8686
self.guilds.append(guild)
8787
await state._add_guild(guild)
8888

@@ -115,7 +115,7 @@ async def __load__(cls, data: GuildPayload, state: ConnectionState) -> Self:
115115
self = cls()
116116
guild = await state._get_guild(int(data["id"]))
117117
if guild is None:
118-
guild = await Guild(data=data, state=state)._from_data(data)
118+
guild = await Guild._from_data(data, state)
119119
await state._add_guild(guild)
120120
self.guild = guild
121121
self.__dict__.update(self.guild.__dict__)

discord/guild.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -298,19 +298,6 @@ class Guild(Hashable):
298298
3: _GuildLimit(emoji=250, stickers=60, bitrate=384e3, filesize=104_857_600),
299299
}
300300

301-
def __init__(self, *, data: GuildPayload, state: ConnectionState):
302-
# NOTE:
303-
# Adding an attribute here and getting an AttributeError saying
304-
# the attr doesn't exist? it has something to do with the order
305-
# of the attr in __slots__
306-
307-
self._channels: dict[int, GuildChannel] = {}
308-
self._members: dict[int, Member] = {}
309-
self._scheduled_events: dict[int, ScheduledEvent] = {}
310-
self._voice_states: dict[int, VoiceState] = {}
311-
self._threads: dict[int, Thread] = {}
312-
self._state: ConnectionState = state
313-
314301
def _add_channel(self, channel: GuildChannel, /) -> None:
315302
self._channels[channel.id] = channel
316303

@@ -448,7 +435,20 @@ def _remove_role(self, role_id: int, /) -> Role:
448435

449436
return role
450437

451-
async def _from_data(self, guild: GuildPayload) -> Self:
438+
@classmethod
439+
async def _from_data(cls, guild: GuildPayload, state: ConnectionState) -> Self:
440+
self = cls()
441+
# NOTE:
442+
# Adding an attribute here and getting an AttributeError saying
443+
# the attr doesn't exist? it has something to do with the order
444+
# of the attr in __slots__
445+
446+
self._channels: dict[int, GuildChannel] = {}
447+
self._members: dict[int, Member] = {}
448+
self._scheduled_events: dict[int, ScheduledEvent] = {}
449+
self._voice_states: dict[int, VoiceState] = {}
450+
self._threads: dict[int, Thread] = {}
451+
self._state: ConnectionState = state
452452
member_count = guild.get("member_count")
453453
# Either the payload includes member_count, or it hasn't been set yet.
454454
# Prevents valid _member_count from suddenly changing to None
@@ -1906,7 +1906,7 @@ async def edit(
19061906
fields["features"] = features
19071907

19081908
data = await http.edit_guild(self.id, reason=reason, **fields)
1909-
return Guild(data=data, state=self._state)
1909+
return Guild._from_data(data=data, state=self._state)
19101910

19111911
async def fetch_channels(self) -> Sequence[GuildChannel]:
19121912
"""|coro|

discord/interactions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ async def load_data(self):
235235
self._guild: Guild | None = None
236236
self._guild_data = data.get("guild")
237237
if self.guild is None and self._guild_data:
238-
self._guild = Guild(data=self._guild_data, state=self._state)
238+
self._guild = await Guild._from_data(data=self._guild_data, state=self._state)
239239

240240
# TODO: there's a potential data loss here
241241
if self.guild_id:

discord/iterators.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -638,10 +638,10 @@ def _get_retrieve(self):
638638
self.retrieve = r
639639
return r > 0
640640

641-
def create_guild(self, data):
641+
async def create_guild(self, data):
642642
from .guild import Guild
643643

644-
return Guild(state=self.state, data=data)
644+
return await Guild._from_data(state=self.state, data=data)
645645

646646
async def fill_guilds(self):
647647
if self._get_retrieve():
@@ -653,7 +653,7 @@ async def fill_guilds(self):
653653
data = filter(self._filter, data)
654654

655655
for element in data:
656-
await self.guilds.put(self.create_guild(element))
656+
await self.guilds.put(await self.create_guild(element))
657657

658658
async def _retrieve_guilds(self, retrieve) -> list[Guild]:
659659
"""Retrieve guilds and update next parameters."""

discord/template.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ async def from_data(cls, state: ConnectionState, data: TemplatePayload) -> None:
156156
source_serialised["id"] = guild_id
157157
state = _PartialTemplateState(state=self._state)
158158
# Guild expects a ConnectionState, we're passing a _PartialTemplateState
159-
self.source_guild = Guild(data=source_serialised, state=state) # type: ignore
159+
self.source_guild = await Guild._from_data(data=source_serialised, state=state) # type: ignore
160160
else:
161161
self.source_guild = guild
162162

@@ -200,7 +200,7 @@ async def create_guild(self, name: str, icon: Any = None) -> Guild:
200200
icon = _bytes_to_base64_data(icon)
201201

202202
data = await self._state.http.create_from_template(self.code, name, icon)
203-
return Guild(data=data, state=self._state)
203+
return await Guild._from_data(data=data, state=self._state)
204204

205205
async def sync(self) -> Template:
206206
"""|coro|

discord/ui/modal.py

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
from .input_text import InputText
1313

1414
__all__ = (
15-
"Modal",
16-
"ModalStore",
15+
"Modal"
1716
)
1817

1918

@@ -69,16 +68,6 @@ def __init__(
6968
self.__timeout_task: asyncio.Task[None] | None = None
7069
self.loop = asyncio.get_event_loop()
7170

72-
def _start_listening_from_store(self, store: ModalStore) -> None:
73-
self.__cancel_callback = partial(store.remove_modal)
74-
if self.timeout:
75-
loop = asyncio.get_running_loop()
76-
if self.__timeout_task is not None:
77-
self.__timeout_task.cancel()
78-
79-
self.__timeout_expiry = time.monotonic() + self.timeout
80-
self.__timeout_task = loop.create_task(self.__timeout_task_impl())
81-
8271
async def __timeout_task_impl(self) -> None:
8372
while True:
8473
# Guard just in case someone changes the value of the timeout at runtime
@@ -305,40 +294,3 @@ def remove_item(self, item: InputText) -> None:
305294

306295
def clear(self) -> None:
307296
self.weights = [0, 0, 0, 0, 0]
308-
309-
310-
class ModalStore:
311-
def __init__(self, state: ConnectionState) -> None:
312-
# (user_id, custom_id) : Modal
313-
self._modals: dict[tuple[int, str], Modal] = {}
314-
self._state: ConnectionState = state
315-
316-
def add_modal(self, modal: Modal, user_id: int):
317-
self._modals[(user_id, modal.custom_id)] = modal
318-
modal._start_listening_from_store(self)
319-
320-
def remove_modal(self, modal: Modal, user_id):
321-
modal.stop()
322-
self._modals.pop((user_id, modal.custom_id))
323-
324-
async def dispatch(self, user_id: int, custom_id: str, interaction: Interaction):
325-
key = (user_id, custom_id)
326-
value = self._modals.get(key)
327-
if value is None:
328-
return
329-
330-
try:
331-
components = [
332-
component
333-
for parent_component in interaction.data["components"]
334-
for component in parent_component["components"]
335-
]
336-
for component in components:
337-
for child in value.children:
338-
if child.custom_id == component["custom_id"]: # type: ignore
339-
child.refresh_state(component)
340-
break
341-
await value.callback(interaction)
342-
self.remove_modal(value, user_id)
343-
except Exception as e:
344-
return await value.on_error(e, interaction)

discord/ui/view.py

Lines changed: 0 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -427,15 +427,6 @@ async def _scheduled_task(self, item: Item, interaction: Interaction):
427427
except Exception as e:
428428
return await self.on_error(e, item, interaction)
429429

430-
def _start_listening_from_store(self, store: ViewStore) -> None:
431-
self.__cancel_callback = partial(store.remove_view)
432-
if self.timeout:
433-
loop = asyncio.get_running_loop()
434-
if self.__timeout_task is not None:
435-
self.__timeout_task.cancel()
436-
437-
self.__timeout_expiry = time.monotonic() + self.timeout
438-
self.__timeout_task = loop.create_task(self.__timeout_task_impl())
439430

440431
def _dispatch_timeout(self):
441432
if self.__stopped.done():
@@ -563,77 +554,3 @@ def message(self):
563554
def message(self, value):
564555
self._message = value
565556

566-
567-
class ViewStore:
568-
def __init__(self, state: ConnectionState):
569-
# (component_type, message_id, custom_id): (View, Item)
570-
self._views: dict[tuple[int, int | None, str], tuple[View, Item]] = {}
571-
# message_id: View
572-
self._synced_message_views: dict[int, View] = {}
573-
self._state: ConnectionState = state
574-
575-
@property
576-
def persistent_views(self) -> Sequence[View]:
577-
views = {
578-
view.id: view
579-
for (_, (view, _)) in self._views.items()
580-
if view.is_persistent()
581-
}
582-
return list(views.values())
583-
584-
def __verify_integrity(self):
585-
to_remove: list[tuple[int, int | None, str]] = []
586-
for k, (view, _) in self._views.items():
587-
if view.is_finished():
588-
to_remove.append(k)
589-
590-
for k in to_remove:
591-
del self._views[k]
592-
593-
def add_view(self, view: View, message_id: int | None = None):
594-
self.__verify_integrity()
595-
596-
view._start_listening_from_store(self)
597-
for item in view.children:
598-
if item.is_dispatchable():
599-
self._views[(item.type.value, message_id, item.custom_id)] = (view, item) # type: ignore
600-
601-
if message_id is not None:
602-
self._synced_message_views[message_id] = view
603-
604-
def remove_view(self, view: View):
605-
for item in view.children:
606-
if item.is_dispatchable():
607-
self._views.pop((item.type.value, item.custom_id), None) # type: ignore
608-
609-
for key, value in self._synced_message_views.items():
610-
if value.id == view.id:
611-
del self._synced_message_views[key]
612-
break
613-
614-
def dispatch(self, component_type: int, custom_id: str, interaction: Interaction):
615-
self.__verify_integrity()
616-
message_id: int | None = interaction.message and interaction.message.id
617-
key = (component_type, message_id, custom_id)
618-
# Fallback to None message_id searches in case a persistent view
619-
# was added without an associated message_id
620-
value = self._views.get(key) or self._views.get(
621-
(component_type, None, custom_id)
622-
)
623-
if value is None:
624-
return
625-
626-
view, item = value
627-
item.refresh_state(interaction)
628-
view._dispatch_item(item, interaction)
629-
630-
def is_message_tracked(self, message_id: int):
631-
return message_id in self._synced_message_views
632-
633-
def remove_message_tracking(self, message_id: int) -> View | None:
634-
return self._synced_message_views.pop(message_id, None)
635-
636-
def update_from_message(self, message_id: int, components: list[ComponentPayload]):
637-
# pre-req: is_message_tracked == true
638-
view = self._synced_message_views[message_id]
639-
view.refresh([_component_factory(d) for d in components])

0 commit comments

Comments
 (0)