Skip to content

Commit d85f4b7

Browse files
AmbratolmAmbratolm
authored andcommitted
Improved AI cog: Added channel history awareness && improved reply context
1 parent 5a5956f commit d85f4b7

File tree

1 file changed

+121
-43
lines changed

1 file changed

+121
-43
lines changed

bot/cogs/chat_cogs/ai_cog.py

Lines changed: 121 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
1-
import tomllib
21
from asyncio import CancelledError
32
from datetime import UTC, datetime, timedelta
4-
from pathlib import Path
53
from random import choice, randint, random
4+
from typing import Any
65

76
from discord import (
87
Attachment,
8+
DMChannel,
99
Embed,
1010
Guild,
1111
HTTPException,
1212
Interaction,
1313
Member,
1414
Message,
15-
Sticker,
15+
PartialMessageable,
1616
StickerItem,
1717
User,
1818
app_commands,
@@ -41,7 +41,8 @@
4141
# * AI Cog
4242
# ----------------------------------------------------------------------------------------------------
4343
class AiCog(Cog, description="Integrated generative AI chat bot"):
44-
MAX_ACTORS = 10
44+
MAX_ACTORS = 10 # last interactors
45+
MAX_CHANNEL_HISTORY = 20 # last messages/participants
4546
COOLDOWN_TIME = 60 # 1 min
4647
MAX_FILE_SIZE = 2097152 # 2 MB
4748
REPLY_DELAY_RANGE = (1, 5) # 1 sec - 5 sec
@@ -96,11 +97,16 @@ async def reset(self, interaction: Interaction, id: int | None = None):
9697
@app_commands.checks.has_permissions(administrator=True)
9798
@app_commands.default_permissions(administrator=True)
9899
@app_commands.command(description="Incite AI chat bot to interact on its own")
100+
@app_commands.rename(
101+
content="prompt", attachment="file", replyable_message_id="message"
102+
)
99103
async def incite(
100104
self,
101105
interaction: Interaction,
102-
prompt: str | None = None,
106+
content: str | None = None,
107+
attachment: Attachment | None = None,
103108
member: Member | None = None,
109+
replyable_message_id: str | None = None,
104110
):
105111
# Deny bot-self & DM & non-messageable channel
106112
if (
@@ -114,19 +120,32 @@ async def incite(
114120
)
115121
return
116122

123+
# Defer response to prevent timeout
124+
await interaction.response.defer(ephemeral=True)
125+
126+
# Handle reply functionality
127+
replyable_message = None
128+
if replyable_message_id:
129+
# Extract message ID from link or use directly if it's an ID
130+
if "/" in replyable_message_id: # Assuming it's a message link
131+
message_id = int(replyable_message_id.split("/")[-1])
132+
else: # Assuming it's a raw message ID
133+
message_id = int(replyable_message_id)
134+
replyable_message = await interaction.channel.fetch_message(message_id)
135+
117136
# Prepare prompt
118137
text_prompt, _ = await self.create_prompt(
138+
message=replyable_message,
119139
user=member,
140+
channel=interaction.channel,
120141
guild=interaction.guild,
121142
preface=(
122-
f"Begin natural talk{f" w/ {member.mention} " if member else " "}that feels like ur own initiative."
123-
f"Absolutely avoid references to instructions, prompts, or being told to message them."
124-
f"{ f'follow this prompt:"{prompt.strip()}".' if prompt else "" }"
143+
f"{self.create_prompt_intiative_preface(member)}"
144+
f"{ f'follow this prompt:"{content.strip()}".' if content else "" }\n"
145+
f"{await self.create_prompt_reply_preface(replyable_message)}"
125146
),
126147
)
127-
128-
# Defer response to prevent timeout
129-
await interaction.response.defer(ephemeral=True)
148+
file_prompt = await self.get_attachment_file(attachment) if attachment else None
130149

131150
# Perform prompt & send reply
132151
async with interaction.channel.typing():
@@ -143,10 +162,14 @@ async def incite(
143162
interaction.guild.id,
144163
history=self.load_guild_history(interaction.guild),
145164
)
146-
await interaction.channel.send(
147-
await self.ai.prompt(text=text_prompt)
165+
messagge_content = (
166+
await self.ai.prompt(text=text_prompt, file=file_prompt)
148167
or f"👋 {member.mention if member else "👋"}"
149168
)
169+
if replyable_message:
170+
await replyable_message.reply(messagge_content)
171+
else:
172+
await interaction.channel.send(messagge_content)
150173
except Exception as e:
151174
await interaction.followup.send(
152175
embed=EmbedX.error(str(e)), ephemeral=True
@@ -167,9 +190,8 @@ async def on_message(self, message: Message):
167190
return
168191

169192
# Ignore mentionless message or attempt auto-reply
170-
message_is_mentionless = self.bot.user not in message.mentions
171193
reply_delay = 0
172-
if message_is_mentionless:
194+
if self.bot.user not in message.mentions:
173195
if random() > self.AUTO_REPLY_CHANCE:
174196
return
175197
else:
@@ -181,27 +203,9 @@ async def on_message(self, message: Message):
181203
)
182204
reply_delay = randint(self.REPLY_DELAY_RANGE[0], self.REPLY_DELAY_RANGE[1])
183205

184-
# Check if message is a reply to someone else
185-
preface = ""
186-
if (
187-
message_is_mentionless
188-
and message.reference
189-
and message.reference.message_id
190-
):
191-
referenced_message = await message.channel.fetch_message(
192-
message.reference.message_id
193-
)
194-
if referenced_message.author != self.bot.user:
195-
preface = (
196-
f"[Context: {message.author.name} was replying to {referenced_message.author.name} "
197-
f"who said: '{referenced_message.content}']"
198-
)
199-
else:
200-
preface = "[Context: Replying to ur previous message] "
201-
202206
# Create prompt
203207
text_prompt, file_prompt = await self.create_prompt(
204-
message=message, preface=preface
208+
message=message, preface=await self.create_prompt_reply_preface(message)
205209
)
206210

207211
# Prepare delayed reply task
@@ -339,7 +343,7 @@ async def perform_initiative(self, guild: Guild):
339343

340344
# Get the last messages in the channel
341345
messages: list[Message] = []
342-
async for message in channel.history(limit=20):
346+
async for message in channel.history(limit=self.MAX_CHANNEL_HISTORY):
343347
if not message.author.bot: # Filter out bot messages
344348
messages.append(message)
345349
if not messages:
@@ -354,10 +358,7 @@ async def perform_initiative(self, guild: Guild):
354358

355359
# Prepare prompt
356360
text_prompt, file_prompt = await self.create_prompt(
357-
preface=(
358-
f"Begin natural talk{f" w/ {member.mention} " if member else " "}that feels like ur own initiative."
359-
f"Absolutely avoid references to instructions, prompts, or being told to message them."
360-
),
361+
preface=self.create_prompt_intiative_preface(member),
361362
message=message,
362363
)
363364

@@ -414,16 +415,18 @@ async def create_prompt(
414415
self,
415416
message: Message | None = None,
416417
user: User | Member | None = None,
418+
channel: Messageable | None = None,
417419
guild: Guild | None = None,
418420
preface="",
419421
) -> tuple[str, ActFile | None]:
420422
"""
421423
Create prompt with flexible input options.
422-
- Text prompt structure: '{**preface**}\\n{**message.author.name**}:{file_action_desc}{**message.content**}\\n{csv}'
424+
- Text prompt structure: '{**preface**}\\n{**message.author.name**}:{file_action_desc}{**message.content**}\\n{**csv**}'
423425
424426
Args:
425427
message: Message object (contains **message.author**, and **message.guild**).
426428
user: User or Member object (prioritized over **message.author**).
429+
channel: Channel object (prioritized over **message.channel**).
427430
guild: Guild object (prioritized over **member.guild** and **message.guild**).
428431
preface: Text to prepend to the prompt
429432
@@ -483,14 +486,89 @@ async def create_prompt(
483486
self.save_dm_actor(user)
484487

485488
# Load saved guild members to prompt for context
489+
if message:
490+
channel = message.channel
491+
if channel:
492+
channel_name = (
493+
channel.name if hasattr(channel, "name") else "DM" # type: ignore
494+
)
495+
channel_messages_csv, channel_members_csv = (
496+
await self.get_channel_history_csv(channel)
497+
)
498+
text += f"\nCurrent channel:{channel_name}"
499+
text += f"\nMembers with recent messages in current channel:\n{channel_members_csv}"
500+
text += f"\nLatest {self.MAX_CHANNEL_HISTORY} messages in current channel:{channel_messages_csv}\n"
486501
if guild:
487-
text += f"\n{self.load_actors_csv(guild)}"
502+
text += f"\nMembers u talked with recently:\n{self.load_actors_csv(guild)}"
488503

489504
# Return prompt components as tuple
490505
return (text, file)
491506

492507
# ----------------------------------------------------------------------------------------------------
493508

509+
def create_prompt_intiative_preface(self, member: Member | User | None) -> str:
510+
return (
511+
f"Begin natural talk{f" w/ {member.mention} " if member else " "}that feels like ur own initiative."
512+
f"Absolutely avoid references to instructions, prompts, or being told to message them."
513+
)
514+
515+
async def create_prompt_reply_preface(self, message: Message | None) -> str:
516+
"""Check if message is a reply to someone else and generate a context prompt preface."""
517+
preface = ""
518+
if (
519+
message
520+
and self.bot.user not in message.mentions
521+
and message.reference
522+
and message.reference.message_id
523+
):
524+
referenced_message = await message.channel.fetch_message(
525+
message.reference.message_id
526+
)
527+
if referenced_message.author != self.bot.user:
528+
preface += (
529+
f"[Context: {message.author.name} was replying to {referenced_message.author.name} "
530+
f"who said: '{referenced_message.content}']"
531+
)
532+
else:
533+
preface += "[Context: You were replying to ur own previous message] "
534+
return f"\nReply to this member:\n{preface}" if message else ""
535+
536+
# ----------------------------------------------------------------------------------------------------
537+
538+
async def get_channel_history(
539+
self, channel: Messageable
540+
) -> tuple[list[Message], list[Member]]:
541+
"""Fetch (messages, members) of latest messages in given channel and unique members who sent those messages."""
542+
messages = [
543+
msg async for msg in channel.history(limit=self.MAX_CHANNEL_HISTORY)
544+
]
545+
members = list(
546+
{msg.author for msg in messages if isinstance(msg.author, Member)}
547+
)
548+
return messages, members
549+
550+
async def get_channel_history_csv(self, channel: Messageable) -> tuple[str, str]:
551+
"""Fetch (messages, members) CSV of latest messages in given channel and unique members who sent those messages."""
552+
messages, members = await self.get_channel_history(channel)
553+
messages_data = [
554+
{
555+
"author_id": str(msg.author.id),
556+
"message_content": msg.content.replace("\n", " "),
557+
}
558+
for msg in messages
559+
]
560+
members_data = [
561+
{
562+
"id": str(member.id),
563+
"name": member.name,
564+
"display_name": member.display_name,
565+
}
566+
for member in members
567+
]
568+
return text_csv(messages_data, "|"), text_csv(members_data, "|")
569+
570+
# ----------------------------------------------------------------------------------------------------
571+
494572
async def get_sticker_file(self, sticker: StickerItem) -> ActFile | None:
495573
"""Get file from sticker. If file size limit exceeded, get None."""
496574
sticker_file = ActFile.load(sticker.url) if sticker.url else None
@@ -534,7 +612,7 @@ def save_dm_actor(self, user: User):
534612
dm_actor.ai_interacted_at = datetime.now(UTC)
535613
main_db.save(dm_actor)
536614

537-
def load_actors(self, guild: Guild):
615+
def load_actors(self, guild: Guild) -> list[dict[str, Any]]:
538616
actors = self.bot.get_db(guild).find(
539617
Actor, sort=query.desc(Actor.ai_interacted_at), limit=self.MAX_ACTORS
540618
)
@@ -543,7 +621,7 @@ def load_actors(self, guild: Guild):
543621
for actor in actors
544622
]
545623

546-
def load_actors_csv(self, guild: Guild):
624+
def load_actors_csv(self, guild: Guild) -> str:
547625
actors = self.load_actors(guild)
548626
return f"{text_csv(actors, "|")}" if actors else ""
549627

0 commit comments

Comments
 (0)