Skip to content

Commit ea230a1

Browse files
nexy7574Lulalaby
andauthored
Add a native timeout for modal.py (#1434)
* Add a native timeout for modal.py * Fix timeout errors Co-authored-by: Lala Sabathil <[email protected]>
1 parent 4d26ae2 commit ea230a1

File tree

1 file changed

+66
-2
lines changed

1 file changed

+66
-2
lines changed

discord/ui/modal.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
import os
55
import sys
66
import traceback
7+
import time
8+
from functools import partial
79
from itertools import groupby
8-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
10+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Callable
911

1012
from .input_text import InputText
1113

@@ -37,9 +39,14 @@ class Modal:
3739
custom_id: Optional[:class:`str`]
3840
The ID of the modal dialog that gets received during an interaction.
3941
Must be 100 characters or fewer.
42+
timeout: Optional[:class:`float`]
43+
Timeout in seconds from last interaction with the UI before no longer accepting input.
44+
If ``None`` then there is no timeout.
4045
"""
4146

42-
def __init__(self, *children: InputText, title: str, custom_id: Optional[str] = None) -> None:
47+
def __init__(self, *children: InputText, title: str, custom_id: Optional[str] = None,
48+
timeout: Optional[float] = None) -> None:
49+
self.timeout: Optional[float] = timeout
4350
if not isinstance(custom_id, str) and custom_id is not None:
4451
raise TypeError(f"expected custom_id to be str, not {custom_id.__class__.__name__}")
4552
self._custom_id: Optional[str] = custom_id or os.urandom(16).hex()
@@ -50,6 +57,50 @@ def __init__(self, *children: InputText, title: str, custom_id: Optional[str] =
5057
self._weights = _ModalWeights(self._children)
5158
loop = asyncio.get_running_loop()
5259
self._stopped: asyncio.Future[bool] = loop.create_future()
60+
self.__cancel_callback: Optional[Callable[[Modal], None]] = None
61+
self.__timeout_expiry: Optional[float] = None
62+
self.__timeout_task: Optional[asyncio.Task[None]] = None
63+
self.loop = asyncio.get_event_loop()
64+
65+
def _start_listening_from_store(self, store: ModalStore) -> None:
66+
self.__cancel_callback = partial(store.remove_modal)
67+
if self.timeout:
68+
loop = asyncio.get_running_loop()
69+
if self.__timeout_task is not None:
70+
self.__timeout_task.cancel()
71+
72+
self.__timeout_expiry = time.monotonic() + self.timeout
73+
self.__timeout_task = loop.create_task(self.__timeout_task_impl())
74+
75+
async def __timeout_task_impl(self) -> None:
76+
while True:
77+
# Guard just in case someone changes the value of the timeout at runtime
78+
if self.timeout is None:
79+
return
80+
81+
if self.__timeout_expiry is None:
82+
return self._dispatch_timeout()
83+
84+
# Check if we've elapsed our currently set timeout
85+
now = time.monotonic()
86+
if now >= self.__timeout_expiry:
87+
return self._dispatch_timeout()
88+
89+
# Wait N seconds to see if timeout data has been refreshed
90+
await asyncio.sleep(self.__timeout_expiry - now)
91+
92+
@property
93+
def _expires_at(self) -> Optional[float]:
94+
if self.timeout:
95+
return time.monotonic() + self.timeout
96+
return None
97+
98+
def _dispatch_timeout(self):
99+
if self._stopped.done():
100+
return
101+
102+
self._stopped.set_result(True)
103+
self.loop.create_task(self.on_timeout(), name=f"discord-ui-view-timeout-{self.id}")
53104

54105
@property
55106
def title(self) -> str:
@@ -158,6 +209,10 @@ def stop(self) -> None:
158209
"""Stops listening to interaction events from the modal dialog."""
159210
if not self._stopped.done():
160211
self._stopped.set_result(True)
212+
self.__timeout_expiry = None
213+
if self.__timeout_task is not None:
214+
self.__timeout_task.cancel()
215+
self.__timeout_task = None
161216

162217
async def wait(self) -> bool:
163218
"""Waits for the modal dialog to be submitted."""
@@ -187,6 +242,13 @@ async def on_error(self, error: Exception, interaction: Interaction) -> None:
187242
print(f"Ignoring exception in modal {self}:", file=sys.stderr)
188243
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)
189244

245+
async def on_timeout(self) -> None:
246+
"""|coro|
247+
248+
A callback that is called when a modal's timeout elapses without being explicitly stopped.
249+
"""
250+
pass
251+
190252

191253
class _ModalWeights:
192254
__slots__ = ("weights",)
@@ -236,8 +298,10 @@ def __init__(self, state: ConnectionState) -> None:
236298

237299
def add_modal(self, modal: Modal, user_id: int):
238300
self._modals[(user_id, modal.custom_id)] = modal
301+
modal._start_listening_from_store(self)
239302

240303
def remove_modal(self, modal: Modal, user_id):
304+
modal.stop()
241305
self._modals.pop((user_id, modal.custom_id))
242306

243307
async def dispatch(self, user_id: int, custom_id: str, interaction: Interaction):

0 commit comments

Comments
 (0)