Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ however, insignificant breaking changes do not guarantee a major version bump, s

# v4.1.1

### Breaking
- Modmail threads are now potentially Discord threads

### Fixed
- `?msglink` now supports threads with multiple recipients. ([PR #3341](https://github.com/modmail-dev/Modmail/pull/3341))
- Fixed persistent notes not working due to discord.py internal change. ([PR #3324](https://github.com/modmail-dev/Modmail/pull/3324))
Expand Down
37 changes: 23 additions & 14 deletions bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,13 +305,21 @@ def log_channel(self) -> typing.Optional[discord.TextChannel]:
logger.debug("LOG_CHANNEL_ID was invalid, removed.")
self.config.remove("log_channel_id")
if self.main_category is not None:
try:
channel = self.main_category.channels[0]
self.config["log_channel_id"] = channel.id
logger.warning("No log channel set, setting #%s to be the log channel.", channel.name)
return channel
except IndexError:
pass
if isinstance(self.main_category, discord.CategoryChannel):
try:
channel = self.main_category.channels[0]
self.config["log_channel_id"] = channel.id
logger.warning("No log channel set, setting #%s to be the log channel.", channel.name)
return channel
except IndexError:
pass
elif isinstance(self.main_category, discord.TextChannel):
self.config["log_channel_id"] = self.main_category.id
logger.warning(
"No log channel set, setting #%s to be the log channel.", self.main_category.name
)
return self.main_category

logger.warning(
"No log channel set, set one with `%ssetup` or `%sconfig set log_channel_id <id>`.",
self.prefix,
Expand Down Expand Up @@ -419,13 +427,13 @@ def using_multiple_server_setup(self) -> bool:
return self.modmail_guild != self.guild

@property
def main_category(self) -> typing.Optional[discord.CategoryChannel]:
def main_category(self) -> typing.Optional[discord.abc.GuildChannel]:
if self.modmail_guild is not None:
category_id = self.config["main_category_id"]
if category_id is not None:
try:
cat = discord.utils.get(self.modmail_guild.categories, id=int(category_id))
if cat is not None:
cat = discord.utils.get(self.modmail_guild.channels, id=int(category_id))
if cat is not None and isinstance(cat, (discord.CategoryChannel, discord.TextChannel)):
return cat
except ValueError:
pass
Expand Down Expand Up @@ -1351,11 +1359,12 @@ async def on_guild_channel_delete(self, channel):
if channel.guild != self.modmail_guild:
return

if self.main_category == channel:
logger.debug("Main category was deleted.")
self.config.remove("main_category_id")
await self.config.update()

if isinstance(channel, discord.CategoryChannel):
if self.main_category == channel:
logger.debug("Main category was deleted.")
self.config.remove("main_category_id")
await self.config.update()
return

if not isinstance(channel, discord.TextChannel):
Expand Down
9 changes: 9 additions & 0 deletions cogs/modmail.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,9 @@ async def unsubscribe(self, ctx, *, user_or_role: Union[discord.Role, User, str.
@checks.thread_only()
async def nsfw(self, ctx):
"""Flags a Modmail thread as NSFW (not safe for work)."""
if isinstance(ctx.channel, discord.Thread):
await ctx.send("Unable to set NSFW status for Discord threads.")
return
await ctx.channel.edit(nsfw=True)
sent_emoji, _ = await self.bot.retrieve_emoji()
await self.bot.add_reaction(ctx.message, sent_emoji)
Expand All @@ -687,6 +690,9 @@ async def nsfw(self, ctx):
@checks.thread_only()
async def sfw(self, ctx):
"""Flags a Modmail thread as SFW (safe for work)."""
if isinstance(ctx.channel, discord.Thread):
await ctx.send("Unable to set NSFW status for Discord threads.")
return
await ctx.channel.edit(nsfw=False)
sent_emoji, _ = await self.bot.retrieve_emoji()
await self.bot.add_reaction(ctx.message, sent_emoji)
Expand Down Expand Up @@ -775,6 +781,9 @@ def format_log_embeds(self, logs, avatar_url):
@commands.cooldown(1, 600, BucketType.channel)
async def title(self, ctx, *, name: str):
"""Sets title for a thread"""
if isinstance(ctx.channel, discord.Thread):
await ctx.send("Unable to set titles for Discord threads.")
return
await ctx.thread.set_title(name)
sent_emoji, _ = await self.bot.retrieve_emoji()
await ctx.message.pin()
Expand Down
96 changes: 63 additions & 33 deletions core/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,17 @@ async def from_channel(cls, manager: "ThreadManager", channel: discord.TextChann

async def get_genesis_message(self) -> discord.Message:
if self._genesis_message is None:
async for m in self.channel.history(limit=5, oldest_first=True):
if m.author == self.bot.user:
if m.embeds and m.embeds[0].fields and m.embeds[0].fields[0].name == "Roles":
self._genesis_message = m
self._genesis_message = await self._get_genesis_message(self.channel, self.bot.user)

return self._genesis_message

@staticmethod
async def _get_genesis_message(channel, own_user) -> discord.Message | None:
async for m in channel.history(limit=5, oldest_first=True):
if m.author == own_user:
if m.embeds and m.embeds[0].fields and m.embeds[0].fields[0].name == "Roles":
return m

async def setup(self, *, creator=None, category=None, initial_message=None):
"""Create the thread channel and other io related initialisation tasks"""
self.bot.dispatch("thread_initiate", self, creator, category, initial_message)
Expand Down Expand Up @@ -434,9 +438,11 @@ async def _close(self, closer, silent=False, delete_channel=True, message=None,
self.channel.id,
{
"open": False,
"title": match_title(self.channel.topic),
"title": match_title(self.channel.topic)
if isinstance(self.channel, discord.TextChannel)
else None,
"closed_at": str(discord.utils.utcnow()),
"nsfw": self.channel.nsfw,
"nsfw": self.channel.nsfw if isinstance(self.channel, discord.TextChannel) else False,
"close_message": message,
"closer": {
"id": str(closer.id),
Expand Down Expand Up @@ -466,7 +472,7 @@ async def _close(self, closer, silent=False, delete_channel=True, message=None,
else:
sneak_peak = "No content"

if self.channel.nsfw:
if isinstance(self.channel, discord.TextChannel) and self.channel.nsfw:
_nsfw = "NSFW-"
else:
_nsfw = ""
Expand Down Expand Up @@ -1230,39 +1236,39 @@ async def _update_users_genesis(self):
await genesis_message.edit(embed=embed)

async def add_users(self, users: typing.List[typing.Union[discord.Member, discord.User]]) -> None:
topic = ""
title, _, _ = parse_channel_topic(self.channel.topic)
if title is not None:
topic += f"Title: {title}\n"

topic += f"User ID: {self._id}"

self._other_recipients += users
self._other_recipients = list(set(self._other_recipients))
if isinstance(self.channel, discord.TextChannel):
topic = ""
title, _, _ = parse_channel_topic(self.channel.topic)
if title is not None:
topic += f"Title: {title}\n"

ids = ",".join(str(i.id) for i in self._other_recipients)
topic += f"User ID: {self._id}"

topic += f"\nOther Recipients: {ids}"
ids = ",".join(str(i.id) for i in self._other_recipients)

await self.channel.edit(topic=topic)
topic += f"\nOther Recipients: {ids}"

await self.channel.edit(topic=topic)
await self._update_users_genesis()

async def remove_users(self, users: typing.List[typing.Union[discord.Member, discord.User]]) -> None:
topic = ""
title, user_id, _ = parse_channel_topic(self.channel.topic)
if title is not None:
topic += f"Title: {title}\n"

topic += f"User ID: {user_id}"

for u in users:
self._other_recipients.remove(u)
if isinstance(self.channel, discord.TextChannel):
topic = ""
title, user_id, _ = parse_channel_topic(self.channel.topic)
if title is not None:
topic += f"Title: {title}\n"

if self._other_recipients:
ids = ",".join(str(i.id) for i in self._other_recipients)
topic += f"\nOther Recipients: {ids}"
topic += f"User ID: {user_id}"

await self.channel.edit(topic=topic)
if self._other_recipients:
ids = ",".join(str(i.id) for i in self._other_recipients)
topic += f"\nOther Recipients: {ids}"

await self.channel.edit(topic=topic)
await self._update_users_genesis()


Expand All @@ -1276,6 +1282,13 @@ def __init__(self, bot):
async def populate_cache(self) -> None:
for channel in self.bot.modmail_guild.text_channels:
await self.find(channel=channel)
for thread in self.bot.modmail_guild.threads:
await self.find(channel=thread)
# handle any threads archived while bot was offline (is this slow? yes. whatever....)
# (maybe this should only iterate until the archived_at timestamp is fine)
if isinstance(self.bot.main_category, discord.TextChannel):
async for thread in self.bot.main_category.archived_threads():
await self.find(channel=thread)

def __len__(self):
return len(self.cache)
Expand All @@ -1290,11 +1303,15 @@ async def find(
self,
*,
recipient: typing.Union[discord.Member, discord.User] = None,
channel: discord.TextChannel = None,
channel: discord.TextChannel | discord.Thread = None,
recipient_id: int = None,
) -> typing.Optional[Thread]:
"""Finds a thread from cache or from discord channel topics."""
if recipient is None and channel is not None and isinstance(channel, discord.TextChannel):
if (
recipient is None
and channel is not None
and isinstance(channel, (discord.TextChannel, discord.Thread))
):
thread = await self._find_from_channel(channel)
if thread is None:
user_id, thread = next(
Expand Down Expand Up @@ -1357,10 +1374,23 @@ async def _find_from_channel(self, channel):
extracts user_id from that.
"""

if not channel.topic:
return None
if isinstance(channel, discord.Thread) or not channel.topic:
# actually check for genesis embed :)
msg = await Thread._get_genesis_message(channel, self.bot.user)
if not msg:
return None

_, user_id, other_ids = parse_channel_topic(channel.topic)
embed = msg.embeds[0]
user_id = int((embed.footer.text or "-1").removeprefix("User ID: ").split(" ", 1)[0])
other_ids = []
for field in embed.fields:
if field.name == "Other Recipients" and field.value:
other_ids = map(
lambda mention: int(mention.removeprefix("<@").removeprefix("!").removesuffix(">")),
field.value.split(" "),
)
else:
_, user_id, other_ids = parse_channel_topic(channel.topic)

if user_id == -1:
return None
Expand Down
18 changes: 11 additions & 7 deletions core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,13 +457,17 @@ async def create_thread_channel(bot, recipient, category, overwrites, *, name=No
errors_raised = errors_raised or []

try:
channel = await bot.modmail_guild.create_text_channel(
name=name,
category=category,
overwrites=overwrites,
topic=f"User ID: {recipient.id}",
reason="Creating a thread channel.",
)
if isinstance(category, discord.TextChannel):
# we ignore `overwrites`... maybe make private threads so it's similar?
channel = await category.create_thread(name=name, reason="Creating a thread channel.")
else:
channel = await bot.modmail_guild.create_text_channel(
name=name,
category=category,
overwrites=overwrites,
topic=f"User ID: {recipient.id}",
reason="Creating a thread channel.",
)
except discord.HTTPException as e:
if (e.text, (category, name)) in errors_raised:
# Just raise the error to prevent infinite recursion after retrying
Expand Down