3232import traceback
3333from functools import partial
3434from itertools import groupby
35- from typing import TYPE_CHECKING , Any , Callable , ClassVar , Iterator , Sequence
35+ from typing import TYPE_CHECKING , Any , Callable , ClassVar , Iterator , Sequence , TypeVar
3636
3737from ..components import ActionRow as ActionRowComponent
3838from ..components import Button as ButtonComponent
5151 from ..state import ConnectionState
5252 from ..types .components import Component as ComponentPayload
5353
54+ V = TypeVar ("V" , bound = "View" , covariant = True )
55+
5456
5557def _walk_all_components (components : list [Component ]) -> Iterator [Component ]:
5658 for item in components :
@@ -60,7 +62,7 @@ def _walk_all_components(components: list[Component]) -> Iterator[Component]:
6062 yield item
6163
6264
63- def _component_to_item (component : Component ) -> Item :
65+ def _component_to_item (component : Component ) -> Item [ V ] :
6466 if isinstance (component , ButtonComponent ):
6567 from .button import Button
6668
@@ -75,7 +77,7 @@ def _component_to_item(component: Component) -> Item:
7577class _ViewWeights :
7678 __slots__ = ("weights" ,)
7779
78- def __init__ (self , children : list [Item ]):
80+ def __init__ (self , children : list [Item [ V ] ]):
7981 self .weights : list [int ] = [0 , 0 , 0 , 0 , 0 ]
8082
8183 key = lambda i : sys .maxsize if i .row is None else i .row
@@ -84,14 +86,14 @@ def __init__(self, children: list[Item]):
8486 for item in group :
8587 self .add_item (item )
8688
87- def find_open_space (self , item : Item ) -> int :
89+ def find_open_space (self , item : Item [ V ] ) -> int :
8890 for index , weight in enumerate (self .weights ):
8991 if weight + item .width <= 5 :
9092 return index
9193
9294 raise ValueError ("could not find open space for item" )
9395
94- def add_item (self , item : Item ) -> None :
96+ def add_item (self , item : Item [ V ] ) -> None :
9597 if item .row is not None :
9698 total = self .weights [item .row ] + item .width
9799 if total > 5 :
@@ -103,7 +105,7 @@ def add_item(self, item: Item) -> None:
103105 self .weights [index ] += item .width
104106 item ._rendered_row = index
105107
106- def remove_item (self , item : Item ) -> None :
108+ def remove_item (self , item : Item [ V ] ) -> None :
107109 if item ._rendered_row is not None :
108110 self .weights [item ._rendered_row ] -= item .width
109111 item ._rendered_row = None
@@ -161,15 +163,15 @@ def __init_subclass__(cls) -> None:
161163
162164 def __init__ (
163165 self ,
164- * items : Item ,
166+ * items : Item [ V ] ,
165167 timeout : float | None = 180.0 ,
166168 disable_on_timeout : bool = False ,
167169 ):
168170 self .timeout = timeout
169171 self .disable_on_timeout = disable_on_timeout
170- self .children : list [Item ] = []
172+ self .children : list [Item [ V ] ] = []
171173 for func in self .__view_children_items__ :
172- item : Item = func .__discord_ui_model_type__ (** func .__discord_ui_model_kwargs__ )
174+ item : Item [ V ] = func .__discord_ui_model_type__ (** func .__discord_ui_model_kwargs__ )
173175 item .callback = partial (func , self , item )
174176 item ._view = self
175177 setattr (self , func .__name__ , item )
@@ -209,7 +211,7 @@ async def __timeout_task_impl(self) -> None:
209211 await asyncio .sleep (self .__timeout_expiry - now )
210212
211213 def to_components (self ) -> list [dict [str , Any ]]:
212- def key (item : Item ) -> int :
214+ def key (item : Item [ V ] ) -> int :
213215 return item ._rendered_row or 0
214216
215217 children = sorted (self .children , key = key )
@@ -261,7 +263,7 @@ def _expires_at(self) -> float | None:
261263 return time .monotonic () + self .timeout
262264 return None
263265
264- def add_item (self , item : Item ) -> None :
266+ def add_item (self , item : Item [ V ] ) -> None :
265267 """Adds an item to the view.
266268
267269 Parameters
@@ -289,7 +291,7 @@ def add_item(self, item: Item) -> None:
289291 item ._view = self
290292 self .children .append (item )
291293
292- def remove_item (self , item : Item ) -> None :
294+ def remove_item (self , item : Item [ V ] ) -> None :
293295 """Removes an item from the view.
294296
295297 Parameters
@@ -310,7 +312,7 @@ def clear_items(self) -> None:
310312 self .children .clear ()
311313 self .__weights .clear ()
312314
313- def get_item (self , custom_id : str ) -> Item | None :
315+ def get_item (self , custom_id : str ) -> Item [ V ] | None :
314316 """Get an item from the view with the given custom ID. Alias for `utils.get(view.children, custom_id=custom_id)`.
315317
316318 Parameters
@@ -384,7 +386,7 @@ async def on_check_failure(self, interaction: Interaction) -> None:
384386 The interaction that occurred.
385387 """
386388
387- async def on_error (self , error : Exception , item : Item , interaction : Interaction ) -> None :
389+ async def on_error (self , error : Exception , item : Item [ V ] , interaction : Interaction ) -> None :
388390 """|coro|
389391
390392 A callback that is called when an item's callback or :meth:`interaction_check`
@@ -404,7 +406,7 @@ async def on_error(self, error: Exception, item: Item, interaction: Interaction)
404406 print (f"Ignoring exception in view { self } for item { item } :" , file = sys .stderr )
405407 traceback .print_exception (error .__class__ , error , error .__traceback__ , file = sys .stderr )
406408
407- async def _scheduled_task (self , item : Item , interaction : Interaction ):
409+ async def _scheduled_task (self , item : Item [ V ] , interaction : Interaction ):
408410 try :
409411 if self .timeout :
410412 self .__timeout_expiry = time .monotonic () + self .timeout
@@ -434,7 +436,7 @@ def _dispatch_timeout(self):
434436 self .__stopped .set_result (True )
435437 asyncio .create_task (self .on_timeout (), name = f"discord-ui-view-timeout-{ self .id } " )
436438
437- def _dispatch_item (self , item : Item , interaction : Interaction ):
439+ def _dispatch_item (self , item : Item [ V ] , interaction : Interaction ):
438440 if self .__stopped .done ():
439441 return
440442
@@ -448,12 +450,12 @@ def _dispatch_item(self, item: Item, interaction: Interaction):
448450
449451 def refresh (self , components : list [Component ]):
450452 # This is pretty hacky at the moment
451- old_state : dict [tuple [int , str ], Item ] = {
453+ old_state : dict [tuple [int , str ], Item [ V ] ] = {
452454 (item .type .value , item .custom_id ): item
453455 for item in self .children
454456 if item .is_dispatchable () # type: ignore
455457 }
456- children : list [Item ] = [item for item in self .children if not item .is_dispatchable ()]
458+ children : list [Item [ V ] ] = [item for item in self .children if not item .is_dispatchable ()]
457459 for component in _walk_all_components (components ):
458460 try :
459461 older = old_state [(component .type .value , component .custom_id )] # type: ignore
@@ -515,7 +517,7 @@ async def wait(self) -> bool:
515517 """
516518 return await self .__stopped
517519
518- def disable_all_items (self , * , exclusions : list [Item ] | None = None ) -> None :
520+ def disable_all_items (self , * , exclusions : list [Item [ V ] ] | None = None ) -> None :
519521 """
520522 Disables all items in the view.
521523
@@ -528,7 +530,7 @@ def disable_all_items(self, *, exclusions: list[Item] | None = None) -> None:
528530 if exclusions is None or child not in exclusions :
529531 child .disabled = True
530532
531- def enable_all_items (self , * , exclusions : list [Item ] | None = None ) -> None :
533+ def enable_all_items (self , * , exclusions : list [Item [ V ] ] | None = None ) -> None :
532534 """
533535 Enables all items in the view.
534536
@@ -553,7 +555,7 @@ def message(self, value):
553555class ViewStore :
554556 def __init__ (self , state : ConnectionState ):
555557 # (component_type, message_id, custom_id): (View, Item)
556- self ._views : dict [tuple [int , int | None , str ], tuple [View , Item ]] = {}
558+ self ._views : dict [tuple [int , int | None , str ], tuple [View , Item [ V ] ]] = {}
557559 # message_id: View
558560 self ._synced_message_views : dict [int , View ] = {}
559561 self ._state : ConnectionState = state
0 commit comments