Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions discord/ui/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import traceback
from functools import partial
from itertools import groupby
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterator, Sequence
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterator, Sequence, TypeVar

from ..components import ActionRow as ActionRowComponent
from ..components import Button as ButtonComponent
Expand All @@ -51,6 +51,8 @@
from ..state import ConnectionState
from ..types.components import Component as ComponentPayload

V = TypeVar("V", bound="View", covariant=True)


def _walk_all_components(components: list[Component]) -> Iterator[Component]:
for item in components:
Expand All @@ -60,7 +62,7 @@ def _walk_all_components(components: list[Component]) -> Iterator[Component]:
yield item


def _component_to_item(component: Component) -> Item:
def _component_to_item(component: Component) -> Item[V]:
if isinstance(component, ButtonComponent):
from .button import Button

Expand All @@ -75,7 +77,7 @@ def _component_to_item(component: Component) -> Item:
class _ViewWeights:
__slots__ = ("weights",)

def __init__(self, children: list[Item]):
def __init__(self, children: list[Item[V]]):
self.weights: list[int] = [0, 0, 0, 0, 0]

key = lambda i: sys.maxsize if i.row is None else i.row
Expand All @@ -84,14 +86,14 @@ def __init__(self, children: list[Item]):
for item in group:
self.add_item(item)

def find_open_space(self, item: Item) -> int:
def find_open_space(self, item: Item[V]) -> int:
for index, weight in enumerate(self.weights):
if weight + item.width <= 5:
return index

raise ValueError("could not find open space for item")

def add_item(self, item: Item) -> None:
def add_item(self, item: Item[V]) -> None:
if item.row is not None:
total = self.weights[item.row] + item.width
if total > 5:
Expand All @@ -105,7 +107,7 @@ def add_item(self, item: Item) -> None:
self.weights[index] += item.width
item._rendered_row = index

def remove_item(self, item: Item) -> None:
def remove_item(self, item: Item[V]) -> None:
if item._rendered_row is not None:
self.weights[item._rendered_row] -= item.width
item._rendered_row = None
Expand Down Expand Up @@ -163,15 +165,15 @@ def __init_subclass__(cls) -> None:

def __init__(
self,
*items: Item,
*items: Item[V],
timeout: float | None = 180.0,
disable_on_timeout: bool = False,
):
self.timeout = timeout
self.disable_on_timeout = disable_on_timeout
self.children: list[Item] = []
self.children: list[Item[V]] = []
for func in self.__view_children_items__:
item: Item = func.__discord_ui_model_type__(
item: Item[V] = func.__discord_ui_model_type__(
**func.__discord_ui_model_kwargs__
)
item.callback = partial(func, self, item)
Expand Down Expand Up @@ -213,7 +215,7 @@ async def __timeout_task_impl(self) -> None:
await asyncio.sleep(self.__timeout_expiry - now)

def to_components(self) -> list[dict[str, Any]]:
def key(item: Item) -> int:
def key(item: Item[V]) -> int:
return item._rendered_row or 0

children = sorted(self.children, key=key)
Expand Down Expand Up @@ -267,7 +269,7 @@ def _expires_at(self) -> float | None:
return time.monotonic() + self.timeout
return None

def add_item(self, item: Item) -> None:
def add_item(self, item: Item[V]) -> None:
"""Adds an item to the view.

Parameters
Expand Down Expand Up @@ -295,7 +297,7 @@ def add_item(self, item: Item) -> None:
item._view = self
self.children.append(item)

def remove_item(self, item: Item) -> None:
def remove_item(self, item: Item[V]) -> None:
"""Removes an item from the view.

Parameters
Expand All @@ -316,7 +318,7 @@ def clear_items(self) -> None:
self.children.clear()
self.__weights.clear()

def get_item(self, custom_id: str) -> Item | None:
def get_item(self, custom_id: str) -> Item[V] | None:
"""Get an item from the view with the given custom ID. Alias for `utils.get(view.children, custom_id=custom_id)`.

Parameters
Expand Down Expand Up @@ -391,7 +393,7 @@ async def on_check_failure(self, interaction: Interaction) -> None:
"""

async def on_error(
self, error: Exception, item: Item, interaction: Interaction
self, error: Exception, item: Item[V], interaction: Interaction
) -> None:
"""|coro|

Expand All @@ -414,7 +416,7 @@ async def on_error(
error.__class__, error, error.__traceback__, file=sys.stderr
)

async def _scheduled_task(self, item: Item, interaction: Interaction):
async def _scheduled_task(self, item: Item[V], interaction: Interaction):
try:
if self.timeout:
self.__timeout_expiry = time.monotonic() + self.timeout
Expand Down Expand Up @@ -446,7 +448,7 @@ def _dispatch_timeout(self):
self.on_timeout(), name=f"discord-ui-view-timeout-{self.id}"
)

def _dispatch_item(self, item: Item, interaction: Interaction):
def _dispatch_item(self, item: Item[V], interaction: Interaction):
if self.__stopped.done():
return

Expand All @@ -460,10 +462,10 @@ def _dispatch_item(self, item: Item, interaction: Interaction):

def refresh(self, components: list[Component]):
# This is pretty hacky at the moment
old_state: dict[tuple[int, str], Item] = {
old_state: dict[tuple[int, str], Item[V]] = {
(item.type.value, item.custom_id): item for item in self.children if item.is_dispatchable() # type: ignore
}
children: list[Item] = [
children: list[Item[V]] = [
item for item in self.children if not item.is_dispatchable()
]
for component in _walk_all_components(components):
Expand Down Expand Up @@ -529,7 +531,7 @@ async def wait(self) -> bool:
"""
return await self.__stopped

def disable_all_items(self, *, exclusions: list[Item] | None = None) -> None:
def disable_all_items(self, *, exclusions: list[Item[V]] | None = None) -> None:
"""
Disables all items in the view.

Expand All @@ -542,7 +544,7 @@ def disable_all_items(self, *, exclusions: list[Item] | None = None) -> None:
if exclusions is None or child not in exclusions:
child.disabled = True

def enable_all_items(self, *, exclusions: list[Item] | None = None) -> None:
def enable_all_items(self, *, exclusions: list[Item[V]] | None = None) -> None:
"""
Enables all items in the view.

Expand All @@ -567,7 +569,7 @@ def message(self, value):
class ViewStore:
def __init__(self, state: ConnectionState):
# (component_type, message_id, custom_id): (View, Item)
self._views: dict[tuple[int, int | None, str], tuple[View, Item]] = {}
self._views: dict[tuple[int, int | None, str], tuple[View, Item[V]]] = {}
# message_id: View
self._synced_message_views: dict[int, View] = {}
self._state: ConnectionState = state
Expand Down