Skip to content
This repository was archived by the owner on Mar 1, 2022. It is now read-only.

Commit 95f994d

Browse files
authored
Fix bug in reaction remove
1 parent 0882a2f commit 95f994d

File tree

1 file changed

+123
-44
lines changed

1 file changed

+123
-44
lines changed

discord/ext/buttons/buttons.py

Lines changed: 123 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,25 @@
11
import asyncio
22
import discord
33
import inspect
4+
from concurrent.futures import TimeoutError
45
from discord.ext import commands
56
from functools import partial
67
from typing import Union
78

9+
__all__ = ('Session', 'Paginator', 'button', 'inverse_button',)
10+
11+
12+
class Button:
13+
__slots__ = ('_callback', '_inverse_callback', 'emoji', 'position', 'try_remove')
14+
15+
def __init__(self, **kwargs):
16+
self._callback = kwargs.get('callback')
17+
self._inverse_callback = kwargs.get('inverse_callback')
18+
19+
self.emoji = kwargs.get('emoji')
20+
self.position = kwargs.get('position')
21+
self.try_remove = kwargs.get('try_remove', True)
22+
823

924
class Session:
1025
"""Interactive session class, which uses reactions as buttons.
@@ -27,16 +42,32 @@ def __init__(self, *, timeout: int = 180, try_remove: bool = True):
2742
self.timeout = timeout
2843
self.buttons = self._buttons
2944

45+
self._defaults = {}
46+
3047
def __init_subclass__(cls, **kwargs):
3148
pass
3249

3350
def _gather_buttons(self):
3451
for _, member in inspect.getmembers(self):
35-
if hasattr(member, '__button__'): # Check if the member is a button...
36-
self._buttons[member.__button__[2], member.__button__[0]] = member.__button__[1] # pos, key, value
52+
if hasattr(member, '__button__'):
53+
button = member.__button__
54+
55+
sorted_ = self.sort_buttons(buttons=self._buttons)
56+
try:
57+
button_ = sorted_[button.emoji]
58+
except KeyError:
59+
self._buttons[button.position, button.emoji] = button
60+
continue
61+
62+
if button._inverse_callback:
63+
button_._inverse_callback = button._inverse_callback
64+
else:
65+
button_._callback = button._callback
66+
67+
self._buttons[button.position, button.emoji] = button_
3768

3869
def sort_buttons(self, *, buttons: dict = None):
39-
if not buttons:
70+
if buttons is None:
4071
buttons = self.buttons
4172

4273
return {k[1]: v for k, v in sorted(buttons.items(), key=lambda t: t[0])}
@@ -65,26 +96,49 @@ async def start(self, ctx, page=None):
6596
async def _session(self, ctx):
6697
self.buttons = self.sort_buttons()
6798

68-
for reaction in self.buttons.keys():
69-
ctx.bot.loop.create_task(self._add_reaction(reaction))
99+
ctx.bot.loop.create_task(self._add_reactions(self.buttons.keys()))
70100

101+
await self._session_loop(ctx)
102+
103+
async def _session_loop(self, ctx):
71104
while True:
105+
_add = asyncio.ensure_future(ctx.bot.wait_for('raw_reaction_add', check=lambda _: self.check(_)(ctx)))
106+
_remove = asyncio.ensure_future(ctx.bot.wait_for('raw_reaction_remove', check=lambda _: self.check(_)(ctx)))
107+
108+
done, pending = await asyncio.wait((_add, _remove), return_when=asyncio.FIRST_COMPLETED, timeout=self.timeout)
109+
110+
for future in pending:
111+
future.cancel()
112+
113+
if not done:
114+
return ctx.bot.loop.create_task(self.cancel(ctx))
115+
72116
try:
73-
payload = await ctx.bot.wait_for('raw_reaction_add', timeout=self.timeout,
74-
check=lambda _: self.check(_)(ctx))
75-
except asyncio.TimeoutError:
117+
result = done.pop()
118+
payload = result.result()
119+
120+
if result == _add:
121+
action = True
122+
else:
123+
action = False
124+
except Exception:
76125
return ctx.bot.loop.create_task(self.cancel(ctx))
77126

78-
if self._try_remove:
127+
emoji = self.get_emoji_as_string(payload.emoji)
128+
button = self.buttons[emoji]
129+
130+
if self._try_remove and button.try_remove:
79131
try:
80132
await self.page.remove_reaction(payload.emoji, ctx.guild.get_member(payload.user_id))
81133
except discord.HTTPException:
82134
pass
83135

84-
emoji = self.get_emoji_as_string(payload.emoji)
85-
action = self.buttons[emoji]
86-
87-
await action(self, ctx)
136+
if action and button in self._defaults.values():
137+
await button._callback(ctx)
138+
elif action and button._callback:
139+
await button._callback(self, ctx)
140+
elif not action and button._inverse_callback:
141+
await button._inverse_callback(self, ctx)
88142

89143
@property
90144
def is_cancelled(self):
@@ -101,8 +155,12 @@ async def teardown(self, ctx):
101155
self._session_task.cancel()
102156
await self.page.delete()
103157

104-
async def _add_reaction(self, reaction):
105-
await self.page.add_reaction(reaction)
158+
async def _add_reactions(self, reactions):
159+
for reaction in reactions:
160+
try:
161+
await self.page.add_reaction(reaction)
162+
except discord.NotFound:
163+
pass
106164

107165
def get_emoji_as_string(self, emoji):
108166
return f'{emoji.name}{":" + str(emoji.id) if emoji.is_custom_emoji() else ""}'
@@ -161,11 +219,11 @@ def __init__(self, *, title: str = '', length: int = 10, entries: list = None,
161219
color: Union[int, discord.Colour] = discord.Embed.Empty, use_defaults: bool = True, embed: bool = True,
162220
joiner: str = '\n', timeout: int = 180, thumbnail: str = None):
163221
super().__init__()
164-
self._defaults = {(0, '⏮'): partial(self._default_indexer, 'start'),
165-
(1, '◀'): partial(self._default_indexer, -1),
166-
(2, '⏹'): partial(self._default_indexer, 'stop'),
167-
(3, '▶'): partial(self._default_indexer, +1),
168-
(4, '⏭'): partial(self._default_indexer, 'end')}
222+
self._defaults = {(0, '⏮'): Button(emoji='⏮', position=0, callback=partial(self._default_indexer, 'start')),
223+
(1, '◀'): Button(emoji='◀', position=1, callback=partial(self._default_indexer, -1)),
224+
(2, '⏹'): Button(emoji='⏹', position=2, callback=partial(self._default_indexer, 'stop')),
225+
(3, '▶'): Button(emoji='▶', position=3, callback=partial(self._default_indexer, +1)),
226+
(4, '⏭'): Button(emoji='⏭', position=4, callback=partial(self._default_indexer, 'end'))}
169227

170228
self.buttons = {}
171229

@@ -245,29 +303,9 @@ async def _session(self, ctx):
245303

246304
self.buttons = self.sort_buttons()
247305

248-
for reaction in self.buttons.keys():
249-
ctx.bot.loop.create_task(self._add_reaction(reaction))
250-
251-
while True:
252-
try:
253-
payload = await ctx.bot.wait_for('raw_reaction_add', timeout=self.timeout,
254-
check=lambda _: self.check(_)(ctx))
255-
except asyncio.TimeoutError:
256-
return ctx.bot.loop.create_task(self.cancel(ctx))
257-
258-
if self._try_remove:
259-
try:
260-
await self.page.remove_reaction(payload.emoji, ctx.guild.get_member(payload.user_id))
261-
except discord.HTTPException:
262-
pass
263-
264-
emoji = self.get_emoji_as_string(payload.emoji)
265-
action = self.buttons[emoji]
306+
ctx.bot.loop.create_task(self._add_reactions(self.buttons.keys()))
266307

267-
if action in self._defaults.values():
268-
await action(ctx)
269-
else:
270-
await action(self, ctx)
308+
await self._session_loop(ctx)
271309

272310
async def _default_indexer(self, control, ctx):
273311
previous = self._index
@@ -294,7 +332,7 @@ async def _default_indexer(self, control, ctx):
294332
await self.page.edit(content=self._pages[self._index])
295333

296334

297-
def button(emoji: str, *, position: int = 666):
335+
def button(emoji: str, *, try_remove=True, position: int = 666):
298336
"""A decorator that adds a button to your interactive session class.
299337
300338
Parameters
@@ -315,7 +353,48 @@ def deco(func):
315353
if not asyncio.iscoroutinefunction(func):
316354
raise TypeError('Button callback must be a coroutine.')
317355

318-
func.__button__ = (emoji, func, position)
356+
if hasattr(func, '__button__'):
357+
button = func.__button__
358+
button._callback = func
359+
360+
return func
361+
362+
func.__button__ = Button(emoji=emoji, callback=func, position=position, try_remove=try_remove)
363+
return func
364+
365+
return deco
366+
367+
368+
def inverse_button(emoji: str = None, *, try_remove=False, position: int = 666):
369+
"""A decorator that adds an inverted button to your interactive session class.
370+
371+
The inverse button will work when a reaction is unpressed.
372+
373+
Parameters
374+
-----------
375+
emoji: str
376+
The emoji to use as a button. This could be a unicode endpoint or in name:id format,
377+
for custom emojis.
378+
position: int
379+
The position to inject the button into.
380+
381+
Raises
382+
-------
383+
TypeError
384+
The button callback is not a coroutine.
385+
"""
386+
387+
def deco(func):
388+
if not asyncio.iscoroutinefunction(func):
389+
raise TypeError('Button callback must be a coroutine.')
390+
391+
if hasattr(func, '__button__'):
392+
button = func.__button__
393+
button._inverse_callback = func
394+
395+
return func
396+
397+
func.__button__ = Button(emoji=emoji, inverse_callback=func, position=position, try_remove=try_remove)
319398
return func
320399

321400
return deco

0 commit comments

Comments
 (0)