32
32
import traceback
33
33
from functools import partial
34
34
from itertools import groupby
35
- from typing import TYPE_CHECKING , Any , Callable , ClassVar , Iterator , Sequence
35
+ from typing import TYPE_CHECKING , Any , Callable , ClassVar , Iterator , Sequence , TypeVar
36
36
37
37
from ..components import ActionRow as ActionRowComponent
38
38
from ..components import Button as ButtonComponent
51
51
from ..state import ConnectionState
52
52
from ..types .components import Component as ComponentPayload
53
53
54
+ V = TypeVar ("V" , bound = "View" , covariant = True )
55
+
54
56
55
57
def _walk_all_components (components : list [Component ]) -> Iterator [Component ]:
56
58
for item in components :
@@ -60,7 +62,7 @@ def _walk_all_components(components: list[Component]) -> Iterator[Component]:
60
62
yield item
61
63
62
64
63
- def _component_to_item (component : Component ) -> Item :
65
+ def _component_to_item (component : Component ) -> Item [ V ] :
64
66
if isinstance (component , ButtonComponent ):
65
67
from .button import Button
66
68
@@ -75,7 +77,7 @@ def _component_to_item(component: Component) -> Item:
75
77
class _ViewWeights :
76
78
__slots__ = ("weights" ,)
77
79
78
- def __init__ (self , children : list [Item ]):
80
+ def __init__ (self , children : list [Item [ V ] ]):
79
81
self .weights : list [int ] = [0 , 0 , 0 , 0 , 0 ]
80
82
81
83
key = lambda i : sys .maxsize if i .row is None else i .row
@@ -84,14 +86,14 @@ def __init__(self, children: list[Item]):
84
86
for item in group :
85
87
self .add_item (item )
86
88
87
- def find_open_space (self , item : Item ) -> int :
89
+ def find_open_space (self , item : Item [ V ] ) -> int :
88
90
for index , weight in enumerate (self .weights ):
89
91
if weight + item .width <= 5 :
90
92
return index
91
93
92
94
raise ValueError ("could not find open space for item" )
93
95
94
- def add_item (self , item : Item ) -> None :
96
+ def add_item (self , item : Item [ V ] ) -> None :
95
97
if item .row is not None :
96
98
total = self .weights [item .row ] + item .width
97
99
if total > 5 :
@@ -103,7 +105,7 @@ def add_item(self, item: Item) -> None:
103
105
self .weights [index ] += item .width
104
106
item ._rendered_row = index
105
107
106
- def remove_item (self , item : Item ) -> None :
108
+ def remove_item (self , item : Item [ V ] ) -> None :
107
109
if item ._rendered_row is not None :
108
110
self .weights [item ._rendered_row ] -= item .width
109
111
item ._rendered_row = None
@@ -161,15 +163,15 @@ def __init_subclass__(cls) -> None:
161
163
162
164
def __init__ (
163
165
self ,
164
- * items : Item ,
166
+ * items : Item [ V ] ,
165
167
timeout : float | None = 180.0 ,
166
168
disable_on_timeout : bool = False ,
167
169
):
168
170
self .timeout = timeout
169
171
self .disable_on_timeout = disable_on_timeout
170
- self .children : list [Item ] = []
172
+ self .children : list [Item [ V ] ] = []
171
173
for func in self .__view_children_items__ :
172
- item : Item = func .__discord_ui_model_type__ (** func .__discord_ui_model_kwargs__ )
174
+ item : Item [ V ] = func .__discord_ui_model_type__ (** func .__discord_ui_model_kwargs__ )
173
175
item .callback = partial (func , self , item )
174
176
item ._view = self
175
177
setattr (self , func .__name__ , item )
@@ -209,7 +211,7 @@ async def __timeout_task_impl(self) -> None:
209
211
await asyncio .sleep (self .__timeout_expiry - now )
210
212
211
213
def to_components (self ) -> list [dict [str , Any ]]:
212
- def key (item : Item ) -> int :
214
+ def key (item : Item [ V ] ) -> int :
213
215
return item ._rendered_row or 0
214
216
215
217
children = sorted (self .children , key = key )
@@ -261,7 +263,7 @@ def _expires_at(self) -> float | None:
261
263
return time .monotonic () + self .timeout
262
264
return None
263
265
264
- def add_item (self , item : Item ) -> None :
266
+ def add_item (self , item : Item [ V ] ) -> None :
265
267
"""Adds an item to the view.
266
268
267
269
Parameters
@@ -289,7 +291,7 @@ def add_item(self, item: Item) -> None:
289
291
item ._view = self
290
292
self .children .append (item )
291
293
292
- def remove_item (self , item : Item ) -> None :
294
+ def remove_item (self , item : Item [ V ] ) -> None :
293
295
"""Removes an item from the view.
294
296
295
297
Parameters
@@ -310,7 +312,7 @@ def clear_items(self) -> None:
310
312
self .children .clear ()
311
313
self .__weights .clear ()
312
314
313
- def get_item (self , custom_id : str ) -> Item | None :
315
+ def get_item (self , custom_id : str ) -> Item [ V ] | None :
314
316
"""Get an item from the view with the given custom ID. Alias for `utils.find(lambda i: i.custom_id == custom_id, self.children)`.
315
317
316
318
Parameters
@@ -384,7 +386,7 @@ async def on_check_failure(self, interaction: Interaction) -> None:
384
386
The interaction that occurred.
385
387
"""
386
388
387
- async def on_error (self , error : Exception , item : Item , interaction : Interaction ) -> None :
389
+ async def on_error (self , error : Exception , item : Item [ V ] , interaction : Interaction ) -> None :
388
390
"""|coro|
389
391
390
392
A callback that is called when an item's callback or :meth:`interaction_check`
@@ -404,7 +406,7 @@ async def on_error(self, error: Exception, item: Item, interaction: Interaction)
404
406
print (f"Ignoring exception in view { self } for item { item } :" , file = sys .stderr )
405
407
traceback .print_exception (error .__class__ , error , error .__traceback__ , file = sys .stderr )
406
408
407
- async def _scheduled_task (self , item : Item , interaction : Interaction ):
409
+ async def _scheduled_task (self , item : Item [ V ] , interaction : Interaction ):
408
410
try :
409
411
if self .timeout :
410
412
self .__timeout_expiry = time .monotonic () + self .timeout
@@ -434,7 +436,7 @@ def _dispatch_timeout(self):
434
436
self .__stopped .set_result (True )
435
437
asyncio .create_task (self .on_timeout (), name = f"discord-ui-view-timeout-{ self .id } " )
436
438
437
- def _dispatch_item (self , item : Item , interaction : Interaction ):
439
+ def _dispatch_item (self , item : Item [ V ] , interaction : Interaction ):
438
440
if self .__stopped .done ():
439
441
return
440
442
@@ -448,12 +450,12 @@ def _dispatch_item(self, item: Item, interaction: Interaction):
448
450
449
451
def refresh (self , components : list [Component ]):
450
452
# This is pretty hacky at the moment
451
- old_state : dict [tuple [int , str ], Item ] = {
453
+ old_state : dict [tuple [int , str ], Item [ V ] ] = {
452
454
(item .type .value , item .custom_id ): item
453
455
for item in self .children
454
456
if item .is_dispatchable () # type: ignore
455
457
}
456
- children : list [Item ] = [item for item in self .children if not item .is_dispatchable ()]
458
+ children : list [Item [ V ] ] = [item for item in self .children if not item .is_dispatchable ()]
457
459
for component in _walk_all_components (components ):
458
460
try :
459
461
older = old_state [(component .type .value , component .custom_id )] # type: ignore
@@ -515,7 +517,7 @@ async def wait(self) -> bool:
515
517
"""
516
518
return await self .__stopped
517
519
518
- def disable_all_items (self , * , exclusions : list [Item ] | None = None ) -> None :
520
+ def disable_all_items (self , * , exclusions : list [Item [ V ] ] | None = None ) -> None :
519
521
"""
520
522
Disables all items in the view.
521
523
@@ -528,7 +530,7 @@ def disable_all_items(self, *, exclusions: list[Item] | None = None) -> None:
528
530
if exclusions is None or child not in exclusions :
529
531
child .disabled = True
530
532
531
- def enable_all_items (self , * , exclusions : list [Item ] | None = None ) -> None :
533
+ def enable_all_items (self , * , exclusions : list [Item [ V ] ] | None = None ) -> None :
532
534
"""
533
535
Enables all items in the view.
534
536
@@ -553,7 +555,7 @@ def message(self, value):
553
555
class ViewStore :
554
556
def __init__ (self , state : ConnectionState ):
555
557
# (component_type, message_id, custom_id): (View, Item)
556
- self ._views : dict [tuple [int , int | None , str ], tuple [View , Item ]] = {}
558
+ self ._views : dict [tuple [int , int | None , str ], tuple [View , Item [ V ] ]] = {}
557
559
# message_id: View
558
560
self ._synced_message_views : dict [int , View ] = {}
559
561
self ._state : ConnectionState = state
0 commit comments