11import asyncio
22import discord
33import inspect
4+ from concurrent .futures import TimeoutError
45from discord .ext import commands
56from functools import partial
67from typing import Union
78
9+ __all__ = ('Session' , 'Paginator' , 'button' , 'inverse_button' ,)
10+
11+
12+ class Button :
13+ __slots__ = ('_callback' , '_inverse_callback' , 'emoji' , 'position' , 'try_remove' )
14+
15+ def __init__ (self , ** kwargs ):
16+ self ._callback = kwargs .get ('callback' )
17+ self ._inverse_callback = kwargs .get ('inverse_callback' )
18+
19+ self .emoji = kwargs .get ('emoji' )
20+ self .position = kwargs .get ('position' )
21+ self .try_remove = kwargs .get ('try_remove' , True )
22+
823
924class Session :
1025 """Interactive session class, which uses reactions as buttons.
@@ -27,16 +42,32 @@ def __init__(self, *, timeout: int = 180, try_remove: bool = True):
2742 self .timeout = timeout
2843 self .buttons = self ._buttons
2944
45+ self ._defaults = {}
46+
3047 def __init_subclass__ (cls , ** kwargs ):
3148 pass
3249
3350 def _gather_buttons (self ):
3451 for _ , member in inspect .getmembers (self ):
35- if hasattr (member , '__button__' ): # Check if the member is a button...
36- self ._buttons [member .__button__ [2 ], member .__button__ [0 ]] = member .__button__ [1 ] # pos, key, value
52+ if hasattr (member , '__button__' ):
53+ button = member .__button__
54+
55+ sorted_ = self .sort_buttons (buttons = self ._buttons )
56+ try :
57+ button_ = sorted_ [button .emoji ]
58+ except KeyError :
59+ self ._buttons [button .position , button .emoji ] = button
60+ continue
61+
62+ if button ._inverse_callback :
63+ button_ ._inverse_callback = button ._inverse_callback
64+ else :
65+ button_ ._callback = button ._callback
66+
67+ self ._buttons [button .position , button .emoji ] = button_
3768
3869 def sort_buttons (self , * , buttons : dict = None ):
39- if not buttons :
70+ if buttons is None :
4071 buttons = self .buttons
4172
4273 return {k [1 ]: v for k , v in sorted (buttons .items (), key = lambda t : t [0 ])}
@@ -65,26 +96,49 @@ async def start(self, ctx, page=None):
6596 async def _session (self , ctx ):
6697 self .buttons = self .sort_buttons ()
6798
68- for reaction in self .buttons .keys ():
69- ctx .bot .loop .create_task (self ._add_reaction (reaction ))
99+ ctx .bot .loop .create_task (self ._add_reactions (self .buttons .keys ()))
70100
101+ await self ._session_loop (ctx )
102+
103+ async def _session_loop (self , ctx ):
71104 while True :
105+ _add = asyncio .ensure_future (ctx .bot .wait_for ('raw_reaction_add' , check = lambda _ : self .check (_ )(ctx )))
106+ _remove = asyncio .ensure_future (ctx .bot .wait_for ('raw_reaction_remove' , check = lambda _ : self .check (_ )(ctx )))
107+
108+ done , pending = await asyncio .wait ((_add , _remove ), return_when = asyncio .FIRST_COMPLETED , timeout = self .timeout )
109+
110+ for future in pending :
111+ future .cancel ()
112+
113+ if not done :
114+ return ctx .bot .loop .create_task (self .cancel (ctx ))
115+
72116 try :
73- payload = await ctx .bot .wait_for ('raw_reaction_add' , timeout = self .timeout ,
74- check = lambda _ : self .check (_ )(ctx ))
75- except asyncio .TimeoutError :
117+ result = done .pop ()
118+ payload = result .result ()
119+
120+ if result == _add :
121+ action = True
122+ else :
123+ action = False
124+ except Exception :
76125 return ctx .bot .loop .create_task (self .cancel (ctx ))
77126
78- if self ._try_remove :
127+ emoji = self .get_emoji_as_string (payload .emoji )
128+ button = self .buttons [emoji ]
129+
130+ if self ._try_remove and button .try_remove :
79131 try :
80132 await self .page .remove_reaction (payload .emoji , ctx .guild .get_member (payload .user_id ))
81133 except discord .HTTPException :
82134 pass
83135
84- emoji = self .get_emoji_as_string (payload .emoji )
85- action = self .buttons [emoji ]
86-
87- await action (self , ctx )
136+ if action and button in self ._defaults .values ():
137+ await button ._callback (ctx )
138+ elif action and button ._callback :
139+ await button ._callback (self , ctx )
140+ elif not action and button ._inverse_callback :
141+ await button ._inverse_callback (self , ctx )
88142
89143 @property
90144 def is_cancelled (self ):
@@ -101,8 +155,12 @@ async def teardown(self, ctx):
101155 self ._session_task .cancel ()
102156 await self .page .delete ()
103157
104- async def _add_reaction (self , reaction ):
105- await self .page .add_reaction (reaction )
158+ async def _add_reactions (self , reactions ):
159+ for reaction in reactions :
160+ try :
161+ await self .page .add_reaction (reaction )
162+ except discord .NotFound :
163+ pass
106164
107165 def get_emoji_as_string (self , emoji ):
108166 return f'{ emoji .name } { ":" + str (emoji .id ) if emoji .is_custom_emoji () else "" } '
@@ -161,11 +219,11 @@ def __init__(self, *, title: str = '', length: int = 10, entries: list = None,
161219 color : Union [int , discord .Colour ] = discord .Embed .Empty , use_defaults : bool = True , embed : bool = True ,
162220 joiner : str = '\n ' , timeout : int = 180 , thumbnail : str = None ):
163221 super ().__init__ ()
164- self ._defaults = {(0 , '⏮' ): partial (self ._default_indexer , 'start' ),
165- (1 , '◀' ): partial (self ._default_indexer , - 1 ),
166- (2 , '⏹' ): partial (self ._default_indexer , 'stop' ),
167- (3 , '▶' ): partial (self ._default_indexer , + 1 ),
168- (4 , '⏭' ): partial (self ._default_indexer , 'end' )}
222+ self ._defaults = {(0 , '⏮' ): Button ( emoji = '⏮' , position = 0 , callback = partial (self ._default_indexer , 'start' ) ),
223+ (1 , '◀' ): Button ( emoji = '◀' , position = 1 , callback = partial (self ._default_indexer , - 1 ) ),
224+ (2 , '⏹' ): Button ( emoji = '⏹' , position = 2 , callback = partial (self ._default_indexer , 'stop' ) ),
225+ (3 , '▶' ): Button ( emoji = '▶' , position = 3 , callback = partial (self ._default_indexer , + 1 ) ),
226+ (4 , '⏭' ): Button ( emoji = '⏭' , position = 4 , callback = partial (self ._default_indexer , 'end' ) )}
169227
170228 self .buttons = {}
171229
@@ -245,29 +303,9 @@ async def _session(self, ctx):
245303
246304 self .buttons = self .sort_buttons ()
247305
248- for reaction in self .buttons .keys ():
249- ctx .bot .loop .create_task (self ._add_reaction (reaction ))
250-
251- while True :
252- try :
253- payload = await ctx .bot .wait_for ('raw_reaction_add' , timeout = self .timeout ,
254- check = lambda _ : self .check (_ )(ctx ))
255- except asyncio .TimeoutError :
256- return ctx .bot .loop .create_task (self .cancel (ctx ))
257-
258- if self ._try_remove :
259- try :
260- await self .page .remove_reaction (payload .emoji , ctx .guild .get_member (payload .user_id ))
261- except discord .HTTPException :
262- pass
263-
264- emoji = self .get_emoji_as_string (payload .emoji )
265- action = self .buttons [emoji ]
306+ ctx .bot .loop .create_task (self ._add_reactions (self .buttons .keys ()))
266307
267- if action in self ._defaults .values ():
268- await action (ctx )
269- else :
270- await action (self , ctx )
308+ await self ._session_loop (ctx )
271309
272310 async def _default_indexer (self , control , ctx ):
273311 previous = self ._index
@@ -294,7 +332,7 @@ async def _default_indexer(self, control, ctx):
294332 await self .page .edit (content = self ._pages [self ._index ])
295333
296334
297- def button (emoji : str , * , position : int = 666 ):
335+ def button (emoji : str , * , try_remove = True , position : int = 666 ):
298336 """A decorator that adds a button to your interactive session class.
299337
300338 Parameters
@@ -315,7 +353,48 @@ def deco(func):
315353 if not asyncio .iscoroutinefunction (func ):
316354 raise TypeError ('Button callback must be a coroutine.' )
317355
318- func .__button__ = (emoji , func , position )
356+ if hasattr (func , '__button__' ):
357+ button = func .__button__
358+ button ._callback = func
359+
360+ return func
361+
362+ func .__button__ = Button (emoji = emoji , callback = func , position = position , try_remove = try_remove )
363+ return func
364+
365+ return deco
366+
367+
368+ def inverse_button (emoji : str = None , * , try_remove = False , position : int = 666 ):
369+ """A decorator that adds an inverted button to your interactive session class.
370+
371+ The inverse button will work when a reaction is unpressed.
372+
373+ Parameters
374+ -----------
375+ emoji: str
376+ The emoji to use as a button. This could be a unicode endpoint or in name:id format,
377+ for custom emojis.
378+ position: int
379+ The position to inject the button into.
380+
381+ Raises
382+ -------
383+ TypeError
384+ The button callback is not a coroutine.
385+ """
386+
387+ def deco (func ):
388+ if not asyncio .iscoroutinefunction (func ):
389+ raise TypeError ('Button callback must be a coroutine.' )
390+
391+ if hasattr (func , '__button__' ):
392+ button = func .__button__
393+ button ._inverse_callback = func
394+
395+ return func
396+
397+ func .__button__ = Button (emoji = emoji , inverse_callback = func , position = position , try_remove = try_remove )
319398 return func
320399
321400 return deco
0 commit comments