diff --git a/discord/ui/view.py b/discord/ui/view.py index c54cb58f13..bbfa353478 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -32,7 +32,7 @@ import traceback from functools import partial from itertools import groupby -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterator, Sequence +from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterator, Sequence, TypeVar from ..components import ActionRow as ActionRowComponent from ..components import Button as ButtonComponent @@ -51,6 +51,8 @@ from ..state import ConnectionState from ..types.components import Component as ComponentPayload +V = TypeVar("V", bound="View", covariant=True) + def _walk_all_components(components: list[Component]) -> Iterator[Component]: for item in components: @@ -60,7 +62,7 @@ def _walk_all_components(components: list[Component]) -> Iterator[Component]: yield item -def _component_to_item(component: Component) -> Item: +def _component_to_item(component: Component) -> Item[V]: if isinstance(component, ButtonComponent): from .button import Button @@ -75,7 +77,7 @@ def _component_to_item(component: Component) -> Item: class _ViewWeights: __slots__ = ("weights",) - def __init__(self, children: list[Item]): + def __init__(self, children: list[Item[V]]): self.weights: list[int] = [0, 0, 0, 0, 0] key = lambda i: sys.maxsize if i.row is None else i.row @@ -84,14 +86,14 @@ def __init__(self, children: list[Item]): for item in group: self.add_item(item) - def find_open_space(self, item: Item) -> int: + def find_open_space(self, item: Item[V]) -> int: for index, weight in enumerate(self.weights): if weight + item.width <= 5: return index raise ValueError("could not find open space for item") - def add_item(self, item: Item) -> None: + def add_item(self, item: Item[V]) -> None: if item.row is not None: total = self.weights[item.row] + item.width if total > 5: @@ -105,7 +107,7 @@ def add_item(self, item: Item) -> None: self.weights[index] += item.width item._rendered_row = index - def remove_item(self, item: Item) -> None: + def remove_item(self, item: Item[V]) -> None: if item._rendered_row is not None: self.weights[item._rendered_row] -= item.width item._rendered_row = None @@ -163,15 +165,15 @@ def __init_subclass__(cls) -> None: def __init__( self, - *items: Item, + *items: Item[V], timeout: float | None = 180.0, disable_on_timeout: bool = False, ): self.timeout = timeout self.disable_on_timeout = disable_on_timeout - self.children: list[Item] = [] + self.children: list[Item[V]] = [] for func in self.__view_children_items__: - item: Item = func.__discord_ui_model_type__( + item: Item[V] = func.__discord_ui_model_type__( **func.__discord_ui_model_kwargs__ ) item.callback = partial(func, self, item) @@ -213,7 +215,7 @@ async def __timeout_task_impl(self) -> None: await asyncio.sleep(self.__timeout_expiry - now) def to_components(self) -> list[dict[str, Any]]: - def key(item: Item) -> int: + def key(item: Item[V]) -> int: return item._rendered_row or 0 children = sorted(self.children, key=key) @@ -267,7 +269,7 @@ def _expires_at(self) -> float | None: return time.monotonic() + self.timeout return None - def add_item(self, item: Item) -> None: + def add_item(self, item: Item[V]) -> None: """Adds an item to the view. Parameters @@ -295,7 +297,7 @@ def add_item(self, item: Item) -> None: item._view = self self.children.append(item) - def remove_item(self, item: Item) -> None: + def remove_item(self, item: Item[V]) -> None: """Removes an item from the view. Parameters @@ -316,7 +318,7 @@ def clear_items(self) -> None: self.children.clear() self.__weights.clear() - def get_item(self, custom_id: str) -> Item | None: + def get_item(self, custom_id: str) -> Item[V] | None: """Get an item from the view with the given custom ID. Alias for `utils.get(view.children, custom_id=custom_id)`. Parameters @@ -391,7 +393,7 @@ async def on_check_failure(self, interaction: Interaction) -> None: """ async def on_error( - self, error: Exception, item: Item, interaction: Interaction + self, error: Exception, item: Item[V], interaction: Interaction ) -> None: """|coro| @@ -414,7 +416,7 @@ async def on_error( error.__class__, error, error.__traceback__, file=sys.stderr ) - async def _scheduled_task(self, item: Item, interaction: Interaction): + async def _scheduled_task(self, item: Item[V], interaction: Interaction): try: if self.timeout: self.__timeout_expiry = time.monotonic() + self.timeout @@ -446,7 +448,7 @@ def _dispatch_timeout(self): self.on_timeout(), name=f"discord-ui-view-timeout-{self.id}" ) - def _dispatch_item(self, item: Item, interaction: Interaction): + def _dispatch_item(self, item: Item[V], interaction: Interaction): if self.__stopped.done(): return @@ -460,10 +462,10 @@ def _dispatch_item(self, item: Item, interaction: Interaction): def refresh(self, components: list[Component]): # This is pretty hacky at the moment - old_state: dict[tuple[int, str], Item] = { + old_state: dict[tuple[int, str], Item[V]] = { (item.type.value, item.custom_id): item for item in self.children if item.is_dispatchable() # type: ignore } - children: list[Item] = [ + children: list[Item[V]] = [ item for item in self.children if not item.is_dispatchable() ] for component in _walk_all_components(components): @@ -529,7 +531,7 @@ async def wait(self) -> bool: """ return await self.__stopped - def disable_all_items(self, *, exclusions: list[Item] | None = None) -> None: + def disable_all_items(self, *, exclusions: list[Item[V]] | None = None) -> None: """ Disables all items in the view. @@ -542,7 +544,7 @@ def disable_all_items(self, *, exclusions: list[Item] | None = None) -> None: if exclusions is None or child not in exclusions: child.disabled = True - def enable_all_items(self, *, exclusions: list[Item] | None = None) -> None: + def enable_all_items(self, *, exclusions: list[Item[V]] | None = None) -> None: """ Enables all items in the view. @@ -567,7 +569,7 @@ def message(self, value): class ViewStore: def __init__(self, state: ConnectionState): # (component_type, message_id, custom_id): (View, Item) - self._views: dict[tuple[int, int | None, str], tuple[View, Item]] = {} + self._views: dict[tuple[int, int | None, str], tuple[View, Item[V]]] = {} # message_id: View self._synced_message_views: dict[int, View] = {} self._state: ConnectionState = state