Skip to content

Commit 876adc2

Browse files
authored
Merge pull request #3331 from b0nes1/main
Implemented optional duration parameter in slowmode command
2 parents 1546ba6 + 730311f commit 876adc2

File tree

2 files changed

+230
-19
lines changed

2 files changed

+230
-19
lines changed

bot/exts/moderation/slowmode.py

Lines changed: 113 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
from datetime import datetime
12
from typing import Literal
23

4+
from async_rediscache import RedisCache
35
from dateutil.relativedelta import relativedelta
46
from discord import TextChannel, Thread
57
from discord.ext.commands import Cog, Context, group, has_any_role
8+
from pydis_core.utils.channel import get_or_fetch_channel
9+
from pydis_core.utils.scheduling import Scheduler
610

711
from bot.bot import Bot
812
from bot.constants import Channels, Emojis, MODERATION_ROLES
9-
from bot.converters import DurationDelta
13+
from bot.converters import Duration, DurationDelta
1014
from bot.log import get_logger
1115
from bot.utils import time
1216

@@ -26,8 +30,14 @@
2630
class Slowmode(Cog):
2731
"""Commands for getting and setting slowmode delays of text channels."""
2832

33+
# RedisCache[discord.channel.id : f"{delay}, {expiry}"]
34+
# `delay` is the slowmode delay assigned to the text channel.
35+
# `expiry` is a naïve ISO 8601 string which describes when the slowmode should be removed.
36+
slowmode_cache = RedisCache()
37+
2938
def __init__(self, bot: Bot) -> None:
3039
self.bot = bot
40+
self.scheduler = Scheduler(self.__class__.__name__)
3141

3242
@group(name="slowmode", aliases=["sm"], invoke_without_command=True)
3343
async def slowmode_group(self, ctx: Context) -> None:
@@ -42,17 +52,29 @@ async def get_slowmode(self, ctx: Context, channel: MessageHolder) -> None:
4252
channel = ctx.channel
4353

4454
humanized_delay = time.humanize_delta(seconds=channel.slowmode_delay)
45-
46-
await ctx.send(f"The slowmode delay for {channel.mention} is {humanized_delay}.")
55+
original_delay, humanized_original_delay, expiration_timestamp = await self._fetch_sm_cache(channel.id)
56+
if original_delay is not None:
57+
await ctx.send(
58+
f"The slowmode delay for {channel.mention} is {humanized_delay}"
59+
f" and will revert to {humanized_original_delay} {expiration_timestamp}."
60+
)
61+
else:
62+
await ctx.send(f"The slowmode delay for {channel.mention} is {humanized_delay}.")
4763

4864
@slowmode_group.command(name="set", aliases=["s"])
4965
async def set_slowmode(
5066
self,
5167
ctx: Context,
5268
channel: MessageHolder,
5369
delay: DurationDelta | Literal["0s", "0seconds"],
70+
expiry: Duration | None = None
5471
) -> None:
55-
"""Set the slowmode delay for a text channel."""
72+
"""
73+
Set the slowmode delay for a text channel.
74+
75+
Supports temporary slowmodes with the `expiry` argument that automatically
76+
revert to the original delay after expiration.
77+
"""
5678
# Use the channel this command was invoked in if one was not given
5779
if channel is None:
5880
channel = ctx.channel
@@ -62,31 +84,96 @@ async def set_slowmode(
6284
if isinstance(delay, str):
6385
delay = relativedelta(seconds=0)
6486

65-
slowmode_delay = time.relativedelta_to_timedelta(delay).total_seconds()
87+
slowmode_delay = int(time.relativedelta_to_timedelta(delay).total_seconds())
6688
humanized_delay = time.humanize_delta(delay)
6789

6890
# Ensure the delay is within discord's limits
69-
if slowmode_delay <= SLOWMODE_MAX_DELAY:
70-
log.info(f"{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}.")
71-
72-
await channel.edit(slowmode_delay=slowmode_delay)
73-
if channel.id in COMMONLY_SLOWMODED_CHANNELS:
74-
log.info(f"Recording slowmode change in stats for {channel.name}.")
75-
self.bot.stats.gauge(f"slowmode.{COMMONLY_SLOWMODED_CHANNELS[channel.id]}", slowmode_delay)
91+
if slowmode_delay > SLOWMODE_MAX_DELAY:
92+
log.info(
93+
f"{ctx.author} tried to set the slowmode delay of #{channel} to {humanized_delay}, "
94+
"which is not between 0 and 6 hours."
95+
)
7696

7797
await ctx.send(
78-
f"{Emojis.check_mark} The slowmode delay for {channel.mention} is now {humanized_delay}."
98+
f"{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours."
7999
)
100+
return
80101

81-
else:
102+
if expiry is not None:
103+
expiration_timestamp = time.format_relative(expiry)
104+
105+
original_delay, humanized_original_delay, _ = await self._fetch_sm_cache(channel.id)
106+
# Cache the channel's current delay if it has no expiry, otherwise use the cached original delay.
107+
if original_delay is None:
108+
original_delay = channel.slowmode_delay
109+
humanized_original_delay = time.humanize_delta(seconds=original_delay)
110+
else:
111+
self.scheduler.cancel(channel.id)
112+
await self.slowmode_cache.set(channel.id, f"{original_delay}, {expiry}")
113+
114+
self.scheduler.schedule_at(expiry, channel.id, self._revert_slowmode(channel.id))
82115
log.info(
83-
f"{ctx.author} tried to set the slowmode delay of #{channel} to {humanized_delay}, "
84-
"which is not between 0 and 6 hours."
116+
f"{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}"
117+
f" which will revert to {humanized_original_delay} in {time.humanize_delta(expiry)}."
118+
)
119+
await channel.edit(slowmode_delay=slowmode_delay)
120+
await ctx.send(
121+
f"{Emojis.check_mark} The slowmode delay for {channel.mention}"
122+
f" is now {humanized_delay} and will revert to {humanized_original_delay} {expiration_timestamp}."
85123
)
124+
else:
125+
if await self.slowmode_cache.contains(channel.id):
126+
await self.slowmode_cache.delete(channel.id)
127+
self.scheduler.cancel(channel.id)
86128

129+
log.info(f"{ctx.author} set the slowmode delay for #{channel} to {humanized_delay}.")
130+
await channel.edit(slowmode_delay=slowmode_delay)
87131
await ctx.send(
88-
f"{Emojis.cross_mark} The slowmode delay must be between 0 and 6 hours."
132+
f"{Emojis.check_mark} The slowmode delay for {channel.mention} is now {humanized_delay}."
89133
)
134+
if channel.id in COMMONLY_SLOWMODED_CHANNELS:
135+
log.info(f"Recording slowmode change in stats for {channel.name}.")
136+
self.bot.stats.gauge(f"slowmode.{COMMONLY_SLOWMODED_CHANNELS[channel.id]}", slowmode_delay)
137+
138+
async def _reschedule(self) -> None:
139+
log.trace("Rescheduling the expiration of temporary slowmodes from cache.")
140+
for channel_id, cached_data in await self.slowmode_cache.items():
141+
expiration = cached_data.split(", ")[1]
142+
expiration_datetime = datetime.fromisoformat(expiration)
143+
channel = self.bot.get_channel(channel_id)
144+
log.info(f"Rescheduling slowmode expiration for #{channel} ({channel_id}).")
145+
self.scheduler.schedule_at(expiration_datetime, channel_id, self._revert_slowmode(channel_id))
146+
147+
async def _fetch_sm_cache(self, channel_id: int) -> tuple[int | None, str, str]:
148+
"""
149+
Fetch the channel's info from the cache and decode it.
150+
151+
If no cache for the channel, the returned slowmode is None.
152+
"""
153+
cached_data = await self.slowmode_cache.get(channel_id, None)
154+
if not cached_data:
155+
return None, "", ""
156+
157+
original_delay, expiration_time = cached_data.split(", ")
158+
original_delay = int(original_delay)
159+
humanized_original_delay = time.humanize_delta(seconds=original_delay)
160+
expiration_timestamp = time.format_relative(expiration_time)
161+
162+
return original_delay, humanized_original_delay, expiration_timestamp
163+
164+
async def _revert_slowmode(self, channel_id: int) -> None:
165+
original_delay, humanized_original_delay, _ = await self._fetch_sm_cache(channel_id)
166+
channel = await get_or_fetch_channel(self.bot, channel_id)
167+
mod_channel = await get_or_fetch_channel(self.bot, Channels.mods)
168+
log.info(
169+
f"Slowmode in #{channel.name} ({channel.id}) has expired and has reverted to {humanized_original_delay}."
170+
)
171+
await channel.edit(slowmode_delay=original_delay)
172+
await mod_channel.send(
173+
f"{Emojis.check_mark} A previously applied slowmode in {channel.jump_url} ({channel.id})"
174+
f" has expired and has been reverted to {humanized_original_delay}."
175+
)
176+
await self.slowmode_cache.delete(channel.id)
90177

91178
@slowmode_group.command(name="reset", aliases=["r"])
92179
async def reset_slowmode(self, ctx: Context, channel: MessageHolder) -> None:
@@ -97,6 +184,15 @@ async def cog_check(self, ctx: Context) -> bool:
97184
"""Only allow moderators to invoke the commands in this cog."""
98185
return await has_any_role(*MODERATION_ROLES).predicate(ctx)
99186

187+
async def cog_load(self) -> None:
188+
"""Wait for guild to become available and reschedule slowmodes which should expire."""
189+
await self.bot.wait_until_guild_available()
190+
await self._reschedule()
191+
192+
async def cog_unload(self) -> None:
193+
"""Cancel all scheduled tasks."""
194+
self.scheduler.cancel_all()
195+
100196

101197
async def setup(bot: Bot) -> None:
102198
"""Load the Slowmode cog."""

tests/bot/exts/moderation/test_slowmode.py

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
import unittest
1+
import asyncio
2+
import datetime
23
from unittest import mock
34

45
from dateutil.relativedelta import relativedelta
56

67
from bot.constants import Emojis
78
from bot.exts.moderation.slowmode import Slowmode
9+
from tests.base import RedisTestCase
810
from tests.helpers import MockBot, MockContext, MockTextChannel
911

1012

11-
class SlowmodeTests(unittest.IsolatedAsyncioTestCase):
13+
class SlowmodeTests(RedisTestCase):
1214

1315
def setUp(self) -> None:
1416
self.bot = MockBot()
@@ -95,6 +97,119 @@ async def test_reset_slowmode_sets_delay_to_zero(self) -> None:
9597
self.ctx, text_channel, relativedelta(seconds=0)
9698
)
9799

100+
@mock.patch("bot.exts.moderation.slowmode.datetime")
101+
async def test_set_slowmode_with_expiry(self, mock_datetime) -> None:
102+
"""Set slowmode with an expiry"""
103+
fixed_datetime = datetime.datetime(2025, 6, 2, 12, 0, 0, tzinfo=datetime.UTC)
104+
mock_datetime.now.return_value = fixed_datetime
105+
106+
test_cases = (
107+
("python-general", 6, 6000, f"{Emojis.check_mark} The slowmode delay for #python-general is now 6 seconds "
108+
"and will revert to 0 seconds <t:1748871600:R>."),
109+
("mod-spam", 5, 600, f"{Emojis.check_mark} The slowmode delay for #mod-spam is now 5 seconds and will "
110+
"revert to 0 seconds <t:1748866200:R>."),
111+
("changelog", 12, 7200, f"{Emojis.check_mark} The slowmode delay for #changelog is now 12 seconds and will "
112+
"revert to 0 seconds <t:1748872800:R>.")
113+
)
114+
for channel_name, seconds, expiry, result_msg in test_cases:
115+
with self.subTest(
116+
channel_mention=channel_name,
117+
seconds=seconds,
118+
expiry=expiry,
119+
result_msg=result_msg
120+
):
121+
text_channel = MockTextChannel(name=channel_name, slowmode_delay=0)
122+
await self.cog.set_slowmode(
123+
self.cog,
124+
self.ctx,
125+
text_channel,
126+
relativedelta(seconds=seconds),
127+
fixed_datetime + relativedelta(seconds=expiry)
128+
)
129+
text_channel.edit.assert_awaited_once_with(slowmode_delay=float(seconds))
130+
self.ctx.send.assert_called_once_with(result_msg)
131+
self.ctx.reset_mock()
132+
133+
async def test_callback_scheduled(self):
134+
"""Schedule slowmode to be reverted"""
135+
self.cog.scheduler=mock.MagicMock(wraps=self.cog.scheduler)
136+
137+
text_channel = MockTextChannel(name="python-general", slowmode_delay=2, id=123)
138+
expiry = datetime.datetime.now(tz=datetime.UTC) + relativedelta(seconds=10)
139+
await self.cog.set_slowmode(
140+
self.cog,
141+
self.ctx,
142+
text_channel,
143+
relativedelta(seconds=4),
144+
expiry
145+
)
146+
147+
args = (expiry, text_channel.id, mock.ANY)
148+
self.cog.scheduler.schedule_at.assert_called_once_with(*args)
149+
150+
@mock.patch("bot.exts.moderation.slowmode.get_or_fetch_channel")
151+
async def test_revert_slowmode_callback(self, mock_get_or_fetch_channel) -> None:
152+
"""Check that the slowmode is reverted"""
153+
text_channel = MockTextChannel(name="python-general", slowmode_delay=2, id=123, jump_url="#python-general")
154+
mod_channel = MockTextChannel(name="mods", id=999, )
155+
mock_get_or_fetch_channel.side_effect = [text_channel, mod_channel]
156+
157+
await self.cog.set_slowmode(
158+
self.cog,
159+
self.ctx,
160+
text_channel,
161+
relativedelta(seconds=4),
162+
datetime.datetime.now(tz=datetime.UTC) + relativedelta(seconds=10)
163+
)
164+
await self.cog._revert_slowmode(text_channel.id)
165+
text_channel.edit.assert_awaited_with(slowmode_delay=2)
166+
mod_channel.send.assert_called_once_with(
167+
f"{Emojis.check_mark} A previously applied slowmode in {text_channel.jump_url} ({text_channel.id}) "
168+
"has expired and has been reverted to 2 seconds."
169+
)
170+
171+
async def test_reschedule_slowmodes(self) -> None:
172+
"""Does not reschedule if cache is empty"""
173+
self.cog.scheduler.schedule_at = mock.MagicMock()
174+
self.cog._reschedule = mock.AsyncMock()
175+
await self.cog.cog_unload()
176+
await self.cog.cog_load()
177+
178+
self.cog._reschedule.assert_called()
179+
self.cog.scheduler.schedule_at.assert_not_called()
180+
181+
async def test_reschedule_upon_reload(self) -> None:
182+
""" Check that method `_reschedule` is called upon cog reload"""
183+
self.cog._reschedule = mock.AsyncMock(wraps=self.cog._reschedule)
184+
await self.cog.cog_unload()
185+
await self.cog.cog_load()
186+
187+
self.cog._reschedule.assert_called()
188+
189+
async def test_reschedules_slowmodes(self) -> None:
190+
"""Slowmodes are loaded from cache at cog reload and scheduled to be reverted."""
191+
192+
now = datetime.datetime.now(tz=datetime.UTC)
193+
channels = {}
194+
slowmodes = (
195+
(123, (now - datetime.timedelta(minutes=10)), 2), # expiration in the past
196+
(456, (now + datetime.timedelta(minutes=20)), 4), # expiration in the future
197+
)
198+
for channel_id, expiration_datetime, delay in slowmodes:
199+
channel = MockTextChannel(slowmode_delay=delay, id=channel_id)
200+
channels[channel_id] = channel
201+
await self.cog.slowmode_cache.set(channel_id, f"{delay}, {expiration_datetime}")
202+
203+
self.bot.get_channel = mock.MagicMock(side_effect=lambda channel_id: channels.get(channel_id))
204+
await self.cog.cog_unload()
205+
await self.cog.cog_load()
206+
for channel_id in channels:
207+
self.assertIn(channel_id, self.cog.scheduler)
208+
209+
await asyncio.sleep(1) # give scheduled task time to execute
210+
channels[123].edit.assert_awaited_once_with(slowmode_delay=channels[123].slowmode_delay)
211+
channels[456].edit.assert_not_called()
212+
98213
@mock.patch("bot.exts.moderation.slowmode.has_any_role")
99214
@mock.patch("bot.exts.moderation.slowmode.MODERATION_ROLES", new=(1, 2, 3))
100215
async def test_cog_check(self, role_check):

0 commit comments

Comments
 (0)