-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
496 lines (409 loc) · 21.4 KB
/
main.py
File metadata and controls
496 lines (409 loc) · 21.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
import discord
from discord.ext import commands
import openai
import os
import json # <-- Required for character sheet logic
from dotenv import load_dotenv
import asyncio
import re
import tiktoken
from collections import defaultdict
import time
import logging
from logging.handlers import RotatingFileHandler
from datetime import datetime, timedelta
import string
from openai import OpenAI # Make sure to add this at the top with other imports
from cogs.gameplay import roll_dice, create_default_character_sheet
from flask import Flask
from threading import Thread
# --- Keeping Brian Alive ---
app = Flask('')


@app.route('/')
def health_check():
    """Health-check endpoint served at the root URL.

    Returns:
        A ``(body, status)`` pair: a fixed status string and HTTP 200.
    """
    alive_message = "Brian's heart is beating. He's alive!"
    status_code = 200
    return alive_message, status_code
def run_flask():
    """Run the Flask health-check app, listening on every interface.

    The port is read from the ``PORT`` environment variable and falls back
    to 8080 when that variable is not set.
    """
    # Hosting platforms such as Railway/Heroku inject PORT; default otherwise.
    listen_port = int(os.getenv('PORT', 8080))
    app.run(port=listen_port, host='0.0.0.0')
# --- Logging Setup ---
def setup_logging():
    """Configure the root logger to write INFO-and-above records to a file.

    Ensures the ``logs`` directory exists and attaches a
    ``RotatingFileHandler`` for ``logs/brian_bot.log`` (5 MiB per file, five
    backups, UTF-8). Calling the function again does not attach a second
    handler for the same file, so records are never written twice.

    Returns:
        The configured root logger.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs('logs', exist_ok=True)
    log_path = os.path.join('logs', 'brian_bot.log')

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # RotatingFileHandler stores the absolute path in ``baseFilename``.
    target_path = os.path.abspath(log_path)
    already_attached = any(
        isinstance(handler, RotatingFileHandler) and handler.baseFilename == target_path
        for handler in logger.handlers
    )
    if not already_attached:
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        file_handler = RotatingFileHandler(
            log_path, maxBytes=5 * 1024 * 1024, backupCount=5, encoding='utf-8'
        )
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
    return logger
logger = setup_logging()

# --- Load Environment Variables ---
# load_dotenv() runs before any value is read below.
load_dotenv()

# Credentials (None when the variable is absent).
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
DISCORD_TOKEN = os.environ.get("DISCORD_TOKEN")

# Storage, channel and command settings, each with a built-in default.
DATA_DIR = os.environ.get('DATA_DIR', 'characters')
SESSION_NOTES_CHANNEL = os.environ.get('SESSION_NOTES_CHANNEL', 'session-notes')
COMMAND_PREFIX = os.environ.get('COMMAND_PREFIX', '!')

# Model settings.
MODEL_NAME = os.environ.get('MODEL_NAME', 'gpt-4')
MAX_TOKENS_FOR_RESPONSE = int(os.environ.get('MAX_TOKENS_FOR_RESPONSE', '1500'))

# Rate limiting: allowed requests per window, and the window length in seconds.
RATE_LIMIT_MENTIONS = int(os.environ.get('RATE_LIMIT_MENTIONS', '5'))
RATE_LIMIT_COMMANDS = int(os.environ.get('RATE_LIMIT_COMMANDS', '10'))
RATE_LIMIT_WINDOW = int(os.environ.get('RATE_LIMIT_WINDOW', '60'))
def parse_id_list(env_var: str) -> list:
    """Turn a comma-separated string of integer IDs into a list of ints.

    Args:
        env_var: The raw string value (not the variable name). A falsy value
            such as ``""`` or ``None`` yields an empty list.

    Returns:
        The IDs in their original order; blank entries are skipped.

    Raises:
        ValueError: If a non-blank entry is not a valid integer.
    """
    if not env_var:
        return []
    parsed_ids = []
    for piece in env_var.split(','):
        token = piece.strip()
        if token:
            parsed_ids.append(int(token))
    return parsed_ids
# Parse role and channel IDs from environment variables
SEARCHABLE_CHANNEL_IDS = parse_id_list(os.getenv('SEARCHABLE_CHANNEL_IDS', ''))
ALLOWED_ROLES = parse_id_list(os.getenv('ALLOWED_ROLES', ''))
ADMIN_ROLES = parse_id_list(os.getenv('ADMIN_ROLES', ''))

# --- OpenAI & Bot Initialization ---
if not OPENAI_API_KEY or not DISCORD_TOKEN:
    logger.critical("FATAL: DISCORD_TOKEN or OPENAI_API_KEY not found in .env file!")
    # Exit with a non-zero status so a supervisor sees the failure. The bare
    # exit() helper comes from the site module and would report success (0).
    raise SystemExit(1)

try:
    logger.info("Initializing OpenAI client...")
    openai_client = OpenAI(api_key=OPENAI_API_KEY, timeout=30)
    logger.info("OpenAI client initialized successfully")
except Exception as e:
    logger.error(f"Failed to initialize OpenAI client: {str(e)}")
    raise SystemExit(1) from e

# Start from the default intents and explicitly enable the ones used here.
intents = discord.Intents.default()
intents.messages = True
intents.message_content = True
intents.guilds = True
bot = commands.Bot(command_prefix=COMMAND_PREFIX, intents=intents)

# --- Configuration & Global State ---
# BRIAN_SYSTEM_PROMPT starts empty and is filled in by on_ready() from the
# instructions file named below.
BRIAN_SYSTEM_PROMPT = ""
INSTRUCTIONS_FILE_NAME = "brian_instructions.txt"
# --- Rate Limiting Setup ---
class RateLimiter:
    """Per-user request limiter over a rolling time window.

    Each user may make at most ``max_requests`` accepted requests within any
    ``time_window``-second span. Timestamps come from ``time.monotonic()`` so
    wall-clock adjustments (NTP corrections, DST changes) cannot stretch or
    shrink the window.
    """

    def __init__(self, max_requests: int, time_window: int):
        """Create a limiter.

        Args:
            max_requests: Accepted requests allowed per user per window.
            time_window: Window length in seconds.
        """
        self.max_requests = max_requests
        self.time_window = time_window  # in seconds
        # user_id -> monotonic timestamps of that user's accepted requests
        self.requests = defaultdict(list)

    def is_rate_limited(self, user_id: int) -> bool:
        """Record a request for *user_id* unless the user is over the limit.

        Returns:
            ``True`` if the request must be rejected (nothing is recorded),
            ``False`` if it was accepted and counted.
        """
        now = time.monotonic()
        # Keep only the requests that still fall inside the window.
        recent = [stamp for stamp in self.requests[user_id]
                  if now - stamp < self.time_window]
        self.requests[user_id] = recent
        if len(recent) >= self.max_requests:
            return True
        recent.append(now)
        return False
# One limiter for @mentions and one for commands; both use the same window.
mention_limiter = RateLimiter(time_window=RATE_LIMIT_WINDOW, max_requests=RATE_LIMIT_MENTIONS)
command_limiter = RateLimiter(time_window=RATE_LIMIT_WINDOW, max_requests=RATE_LIMIT_COMMANDS)
# --- Input Validation ---
def sanitize_input(text: str) -> str:
    """Strip control characters from user text and cap its length.

    ASCII characters are kept exactly when they appear in
    ``string.printable`` (letters, digits, punctuation, space, tab, newline,
    carriage return, vertical tab, form feed). Non-ASCII characters such as
    accented letters, CJK text and emoji are kept too, except control
    characters (Unicode category ``Cc``) and lone surrogates (``Cs``).
    Previously every non-ASCII character was dropped.

    Args:
        text: Raw user input; a falsy value yields ``""``.

    Returns:
        The cleaned text, truncated to at most 1000 characters.
    """
    # Local import: unicodedata is not among this file's top-level imports.
    import unicodedata

    if not text:
        return ""

    kept_chars = []
    for char in text:
        if char in string.printable:
            kept_chars.append(char)
        elif ord(char) > 127 and unicodedata.category(char) not in ('Cc', 'Cs'):
            kept_chars.append(char)

    # Cap the length to prevent abuse while leaving normal conversation intact.
    return ''.join(kept_chars)[:1000]
def validate_channel_name(name: str) -> bool:
    """Return whether *name* has the channel-name format this bot accepts.

    A name is accepted when, after lower-casing, it is non-empty and consists
    only of ASCII letters, digits and hyphens. ``re.fullmatch`` is used
    instead of ``^...$`` because ``$`` also matches just before a trailing
    newline, which let names such as ``"general\\n"`` through.

    NOTE(review): Discord itself may allow more characters (for example
    underscores) — confirm whether this check should be widened.
    """
    return re.fullmatch(r'[a-z0-9-]+', name.lower()) is not None
def has_permission(member: discord.Member, required_roles: list) -> bool:
    """Tell whether *member* holds at least one of the required roles.

    Args:
        member: The member whose ``roles`` are inspected.
        required_roles: Role IDs that grant access. An empty list means the
            check is open to everyone.

    Returns:
        ``True`` if no roles are required or any of the member's role IDs is
        in *required_roles*; ``False`` otherwise.
    """
    # No restriction configured: everybody passes.
    if not required_roles:
        return True
    for role in member.roles:
        if role.id in required_roles:
            return True
    return False
def validate_api_key(api_key: str) -> bool:
    """Apply a loose length/prefix check to an API key or token.

    A value passes when it either starts with ``sk-`` and is exactly 51
    characters long, or is at least 59 characters long regardless of prefix.
    Empty or missing values fail.
    """
    if not api_key:
        return False
    key_length = len(api_key)
    # Shape the original author associated with OpenAI keys.
    looks_like_openai_key = api_key.startswith('sk-') and key_length == 51
    # Shape the original author associated with Discord tokens.
    looks_like_discord_token = key_length >= 59
    return looks_like_openai_key or looks_like_discord_token
# --- Helper Functions ---
def perform_roll(dice_string: str):
    """Roll *dice_string* via ``roll_dice`` and describe the outcome in Brian's voice.

    Returns:
        A formatted message showing the rolls, the signed modifier (omitted
        when it is zero) and the total. If anything goes wrong while rolling
        or formatting, a fixed "confused" message is returned instead.
    """
    try:
        rolls, modifier, total = roll_dice(dice_string)
        if modifier > 0:
            mod_str = f" + {modifier}"
        elif modifier < 0:
            mod_str = f" - {abs(modifier)}"
        else:
            mod_str = ""
        return f"Brian rolls `{dice_string}`...\n**Result:** `{rolls}`{mod_str} = **{total}**"
    except Exception:
        return f"Brian confused by `{dice_string}`. Is not good dice."
# --- Bot Events ---
@bot.event
async def on_ready():
    """Log connection details and load Brian's system prompt.

    The prompt text is read from ``INSTRUCTIONS_FILE_NAME`` into the global
    ``BRIAN_SYSTEM_PROMPT``. Any failure is logged and re-raised.
    """
    global BRIAN_SYSTEM_PROMPT
    try:
        logger.info("=== Bot Starting Up ===")
        logger.info(f"Logged in as {bot.user} (ID: {bot.user.id})")
        logger.info(f"Connected to {len(bot.guilds)} guilds:")
        for guild in bot.guilds:
            logger.info(f"- {guild.name} (ID: {guild.id})")

        # Load personality
        try:
            logger.info(f"Attempting to load instructions from '{INSTRUCTIONS_FILE_NAME}'")
            BRIAN_SYSTEM_PROMPT = _read_instructions_file()
            logger.info(f"Successfully loaded instructions from '{INSTRUCTIONS_FILE_NAME}'")
        except Exception as e:
            logger.error(f"FATAL: Error reading '{INSTRUCTIONS_FILE_NAME}': {str(e)}")
            raise

        logger.info("=== Bot is ready to receive messages ===")
        print(f"Logged in as {bot.user}. Brian is operational.")
    except Exception as e:
        logger.error(f"FATAL ERROR in on_ready: {str(e)}", exc_info=True)
        raise


def _read_instructions_file():
    """Return the full text of the instructions file.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    if not os.path.exists(INSTRUCTIONS_FILE_NAME):
        raise FileNotFoundError(f"File not found: {INSTRUCTIONS_FILE_NAME}")
    with open(INSTRUCTIONS_FILE_NAME, 'r', encoding='utf-8') as f:
        return f.read()
# Maximum number of characters sent in a single Discord message.
_DISCORD_MESSAGE_LIMIT = 2000


def _split_for_discord(text, limit=_DISCORD_MESSAGE_LIMIT):
    """Split *text* into consecutive pieces of at most *limit* characters."""
    return [text[start:start + limit] for start in range(0, len(text), limit)]


def _load_character_sheet_prompt(message, char_file_path):
    """Build the system-prompt section describing the author's character sheet.

    Returns an empty string when the sheet file is missing or cannot be read
    or parsed, so a damaged sheet degrades the answer instead of crashing the
    message handler.
    """
    if not os.path.exists(char_file_path):
        return ""
    try:
        with open(char_file_path, 'r', encoding='utf-8') as f:
            character_data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        logger.error(f"Could not load character sheet '{char_file_path}': {e}")
        return ""
    character_json_string = json.dumps(character_data, indent=2)
    return f"\n# YOUR FRIEND'S DATA\nYou are talking to {message.author.display_name}. This is their character sheet. Use it to answer any questions they have about their stats, items, or abilities.\n\n```json\n{character_json_string}\n```"


def _request_chat_completion(payload):
    """Blocking chat-completion request; returns the reply text ('' if none)."""
    response = openai_client.chat.completions.create(
        model=MODEL_NAME, messages=payload, max_tokens=MAX_TOKENS_FOR_RESPONSE, temperature=0.7
    )
    # The content can be None; normalise so the regex handling below is safe.
    return response.choices[0].message.content or ""


@bot.event
async def on_message(message):
    """Handle every incoming message.

    Order of processing:
      1. Ignore messages written by bots.
      2. Let the command framework process prefixed commands.
      3. Run a ``!roll <dice>`` that appears mid-sentence.
      4. If the bot is mentioned, answer through the chat model, then carry
         out any ``@COIN``, ``@ROLL`` or ``@REACT_EMOJI`` directive found in
         the model's reply.
    """
    if message.author.bot:
        return

    # First, process commands that start with the prefix
    await bot.process_commands(message)

    # If a command was already processed, don't do anything else
    if message.content.startswith(bot.command_prefix):
        return

    # Allow users to use !roll mid-sentence: look for it anywhere in the text.
    roll_in_message = re.search(r'!roll\s+((?:\d+d\d+|\d+)(?:[+\-]\d+)?)', message.content, re.IGNORECASE)
    if roll_in_message:
        dice_string = roll_in_message.group(1)
        roll_command = bot.get_command('roll')
        if roll_command:
            logger.info(f"Found mid-message roll from {message.author.name}: {dice_string}")
            # Manually invoke the command from the cog
            ctx = await bot.get_context(message)
            gameplay_cog = bot.get_cog('Gameplay')
            if gameplay_cog:
                await roll_command.callback(gameplay_cog, ctx, dice_string=dice_string)
            return  # Stop processing to avoid treating it as a mention

    # Everything below is the AI conversation path.
    if not bot.user.mentioned_in(message):
        return

    if mention_limiter.is_rate_limited(message.author.id):
        await message.reply("I'm getting too many requests right now. Please wait a moment.")
        return

    async with message.channel.typing():
        # Create default character sheet if it doesn't exist
        char_file_path = f"{DATA_DIR}/{message.author.id}.json"
        if not os.path.exists(char_file_path):
            try:
                create_default_character_sheet(message.author.id)
                await message.channel.send("I've created a default character sheet for you! Use `!sheet` to view it or `!sheet file` to download it as a template.")
            except Exception as e:
                logger.error(f"Error creating default character sheet: {str(e)}")
                await message.channel.send("I had trouble creating your character sheet. Please try again later.")
                return

        # Load character sheet for context
        system_prompt_content = BRIAN_SYSTEM_PROMPT + _load_character_sheet_prompt(message, char_file_path)

        # Recent channel history; only this bot's own messages count as "assistant".
        history_messages = []
        async for hist_msg in message.channel.history(limit=10):
            role = "assistant" if hist_msg.author.id == bot.user.id else "user"
            history_messages.append({"role": role, "content": f"{hist_msg.author.display_name}: {sanitize_input(hist_msg.content)}"})
        history_messages.reverse()  # history() yields newest first

        payload = [{"role": "system", "content": system_prompt_content}, *history_messages]

        try:
            # The OpenAI client call is synchronous; run it in a worker thread
            # so the event loop is not blocked while waiting for the reply.
            loop = asyncio.get_running_loop()
            final_reply_to_send = await loop.run_in_executor(None, _request_chat_completion, payload)

            # --- Handle Secret Actions ---
            roll_result_str = None

            # Check for @COIN
            coin_match = re.search(r"@COIN='(.*?)'", final_reply_to_send)
            if coin_match:
                final_reply_to_send = final_reply_to_send.replace(coin_match.group(0), "").strip()
                action = coin_match.group(1).lower().strip()
                logger.info(f"AI wants to perform coin action: {action}")
                # Get the gameplay cog to call its methods
                gameplay_cog = bot.get_cog('Gameplay')
                if gameplay_cog:
                    ctx = await bot.get_context(message)
                    # Call the main !coin command with the AI's action as its argument
                    await gameplay_cog.coin.callback(gameplay_cog, ctx, args=action)

            # Check for @ROLL
            roll_match = re.search(r"@ROLL='(.*?)'", final_reply_to_send)
            if roll_match:
                final_reply_to_send = final_reply_to_send.replace(roll_match.group(0), "").strip()
                dice_to_roll = roll_match.group(1).strip()
                logger.info(f"AI wants to roll dice: {dice_to_roll}")
                roll_result_str = perform_roll(dice_to_roll)

            # Check for @REACT_EMOJI
            react_match = re.search(r"@REACT_EMOJI='(.*?)'", final_reply_to_send)
            if react_match:
                final_reply_to_send = final_reply_to_send.replace(react_match.group(0), "").strip()
                emoji_to_react_with = react_match.group(1).strip()
                if emoji_to_react_with:
                    try:
                        await message.add_reaction(emoji_to_react_with)
                    except discord.HTTPException as e:
                        # A rejected reaction must not cost the user the reply.
                        logger.warning(f"Could not add reaction '{emoji_to_react_with}': {e}")

            # --- Send the final message ---
            if final_reply_to_send:
                # Strip bot name prefix from the beginning of the response if it exists
                final_reply_to_send = re.sub(r'^(?:Brain|Brian):\s*', '', final_reply_to_send, flags=re.IGNORECASE)
                # Long replies go out in several messages: the first as a
                # reply, the rest as plain follow-ups.
                for index, chunk in enumerate(_split_for_discord(final_reply_to_send)):
                    if index == 0:
                        await message.reply(chunk)
                    else:
                        await message.channel.send(chunk)

            # If there was a roll, send it as a follow-up message
            if roll_result_str:
                await message.channel.send(roll_result_str)

        except Exception as e:
            logger.error(f"OpenAI API call failed: {e}", exc_info=True)
            await message.reply("I am currently experiencing an issue with my neural interface. Please try again later.")
# --- Bot Commands ---
@bot.command(name='find')
async def find_message(ctx, *, query: str):
    """Searches across specified channels for a query."""
    # Flow: rate-limit check -> sanitise the query -> scan the last 200
    # messages of every configured channel concurrently -> report up to five
    # hits. The help docstring above is user-visible and left unchanged.
    if command_limiter.is_rate_limited(ctx.author.id):
        await ctx.send("You're using this command too frequently. Please wait a moment before trying again.")
        return

    # Channel lookups below need a guild; in a DM ctx.guild is None.
    if ctx.guild is None:
        await ctx.send("This command can only be used in a server.")
        return

    query = sanitize_input(query)
    # An empty query is a substring of every message and would match it all.
    if not query.strip():
        await ctx.send("Please provide some text to search for.")
        return

    # The query and quoted messages are user-controlled text: only the
    # invoking user may be pinged, never @everyone, roles or other users.
    safe_mentions = discord.AllowedMentions(everyone=False, roles=False, users=[ctx.author])

    async with ctx.typing():
        if not SEARCHABLE_CHANNEL_IDS:
            await ctx.send(f"{ctx.author.mention}, the `SEARCHABLE_CHANNEL_IDS` list in the script is empty. The bot owner needs to configure this.")
            return

        async def search_channel(channel, query_str):
            """Return (channel, author, content, url) tuples for matching messages."""
            found_in_channel = []
            if not channel or not channel.permissions_for(ctx.guild.me).read_message_history:
                return []
            needle = query_str.lower()
            try:
                async for msg in channel.history(limit=200):
                    if not msg.author.bot and needle in msg.content.lower():
                        found_in_channel.append((channel.name, msg.author.display_name, sanitize_input(msg.content), msg.jump_url))
                return found_in_channel
            except discord.Forbidden:
                return []

        channels_to_search = [ctx.guild.get_channel(ch_id) for ch_id in SEARCHABLE_CHANNEL_IDS]
        tasks = [search_channel(ch, query) for ch in channels_to_search if ch]
        list_of_results = await asyncio.gather(*tasks)
        all_found_messages = [msg for sublist in list_of_results for msg in sublist]

        if not all_found_messages:
            await ctx.send(f"I found no results for **'{query}'** in the archives.", allowed_mentions=safe_mentions)
            return

        header = f"{ctx.author.mention}, I found these results for **'{query}'**:\n\n"
        parts = [header]
        used = len(header)
        for ch_name, author, content, url in all_found_messages[:5]:
            trimmed_content = content[:150] + "..." if len(content) > 150 else content
            line = f"**#{ch_name}** by **{author}**: \"*{trimmed_content}*\" [Jump to Message]({url})\n"
            # Stop before the message would exceed Discord's 2000-character limit.
            if used + len(line) > 2000:
                break
            parts.append(line)
            used += len(line)
        await ctx.send("".join(parts), allowed_mentions=safe_mentions)
async def summarize_logic(ctx, channel_name: str):
    """Shared logic for summarizing any channel.

    Looks the channel up by name in the invoking guild, collects the text of
    its last 100 messages (bot messages and empty messages excluded), asks
    the chat model for a summary and posts it as an embed. Every failure is
    reported to the user with a short message.

    Args:
        ctx: The command context of the invoking user.
        channel_name: Name of the text channel to summarize.
    """
    if command_limiter.is_rate_limited(ctx.author.id):
        await ctx.send("You're using this command too frequently. Please wait a moment before trying again.")
        return

    # Channel lookup needs a guild; in a DM ctx.guild is None.
    if ctx.guild is None:
        await ctx.send("This command can only be used in a server.")
        return

    if not validate_channel_name(channel_name):
        await ctx.send("Invalid channel name format. Channel names can only contain lowercase letters, numbers, and hyphens.")
        return

    target_channel = discord.utils.get(ctx.guild.text_channels, name=channel_name)
    if not target_channel:
        await ctx.send(f"I could not find the channel `#{channel_name}`.")
        return

    if not target_channel.permissions_for(ctx.guild.me).read_message_history:
        await ctx.send(f"I do not have permission to view the history of `#{channel_name}`.")
        return

    try:
        logger.info(f"Fetching messages from channel {channel_name}")
        messages = [msg async for msg in target_channel.history(limit=100)]
        # history() yields newest first; the model should read the
        # conversation in chronological order.
        messages.reverse()
        content = "\n".join(f"{msg.author.display_name}: {sanitize_input(msg.content)}" for msg in messages if msg.content and not msg.author.bot)

        if not content:
            await ctx.send(f"`#{channel_name}` has no recent text to summarize.")
            return

        prompt = f"Summarize the key points and decisions from the following Discord conversation from the '{channel_name}' channel. Be concise and clear:\n\n{content}"

        def _request_summary():
            """Blocking chat-completion request for the summary."""
            return openai_client.chat.completions.create(
                model=MODEL_NAME,
                messages=[{"role": "system", "content": "You are a summarization expert."}, {"role": "user", "content": prompt}],
                max_tokens=500,
                temperature=0.4
            )

        try:
            # The client call is synchronous; run it in a worker thread so
            # the event loop is not blocked while waiting for the reply.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(None, _request_summary)
            summary = response.choices[0].message.content
            embed = discord.Embed(title=f"Summary of #{target_channel.name}", description=summary, color=discord.Color.blue())
            await ctx.send(embed=embed)
        except Exception as e:
            logger.error(f"OpenAI API error during summarization: {str(e)}", exc_info=True)
            await ctx.send("I had trouble summarizing the channel. Please try again later.")

    except Exception as e:
        logger.error(f"Summarize command failed for channel '{channel_name}': {str(e)}", exc_info=True)
        await ctx.send("An error occurred while trying to summarize. Please try again later.")
@bot.command(name='summarize')
async def summarize_command(ctx, channel: discord.TextChannel):
    """Summarizes the last 100 messages of a given channel."""
    # Keep the typing indicator up for as long as the summary is being built.
    typing_indicator = ctx.typing()
    async with typing_indicator:
        target_name = channel.name
        await summarize_logic(ctx, target_name)
@bot.command(name='recap')
async def recap_command(ctx):
    """Provides a summary of the session notes channel."""
    # Show the typing indicator while the summary is produced, the same way
    # the sibling `summarize` command does.
    async with ctx.typing():
        await summarize_logic(ctx, SESSION_NOTES_CHANNEL)
@bot.event
async def on_command_error(ctx, error):
    """Translate command errors into short user-facing messages.

    Unknown commands are ignored silently. The recognised error types each
    get a fixed explanation; anything else is logged with its traceback and
    answered with a generic apology.

    Args:
        ctx: Context of the command that failed.
        error: The exception raised while parsing or running the command.
    """
    if isinstance(error, commands.CommandNotFound):
        return
    elif isinstance(error, commands.MissingRequiredArgument):
        await ctx.send("Missing required argument. Please check the command usage.")
    elif isinstance(error, commands.ChannelNotFound):
        await ctx.send("Channel not found. Please check the channel name and try again.")
    elif isinstance(error, commands.MissingPermissions):
        await ctx.send("You don't have permission to use this command.")
    elif isinstance(error, commands.BotMissingPermissions):
        await ctx.send("I don't have the necessary permissions to perform this action.")
    elif isinstance(error, commands.CommandOnCooldown):
        await ctx.send(f"This command is on cooldown. Try again in {error.retry_after:.2f} seconds.")
    else:
        # This handler does not run inside an ``except`` block, so
        # exc_info=True would find no active exception and log no traceback.
        # Passing the exception instance logs its real traceback.
        logger.error(f"An unexpected command error occurred: {error}", exc_info=error)
        await ctx.send("An unexpected error occurred. Please try again later.")
async def main():
    """Load every cog module from ``./cogs`` and run the bot until it stops.

    Files whose names start with an underscore (such as ``__init__.py``) are
    not extensions and are skipped. A cog that fails to load is logged with
    its traceback and does not prevent the remaining cogs or the bot from
    starting.
    """
    # Load Cogs
    os.makedirs('cogs', exist_ok=True)

    # ``async with bot`` closes the bot's connections when start() returns or raises.
    async with bot:
        # Sorted so cogs load in the same order on every run.
        for filename in sorted(os.listdir('./cogs')):
            if not filename.endswith('.py') or filename.startswith('_'):
                continue
            try:
                await bot.load_extension(f'cogs.{filename[:-3]}')
                logger.info(f"Successfully loaded cog: {filename}")
            except Exception as e:
                logger.error(f"Failed to load cog {filename}: {e}", exc_info=True)

        await bot.start(DISCORD_TOKEN)
if __name__ == "__main__":
    # --- TEMPORARY DEBUGGING ---
    # Prints only whether the secrets are present, never their values.
    print("--- ENVIRONMENT SANITY CHECK ---")
    print(f"DISCORD_TOKEN Loaded: {bool(DISCORD_TOKEN)}")
    print(f"OPENAI_API_KEY Loaded: {bool(OPENAI_API_KEY)}")
    print(f"DATA_DIR Value: {os.getenv('DATA_DIR', 'Not Set')}")
    print("------------------------------")
    # --- END DEBUGGING ---

    try:
        # Create data directory if it doesn't exist
        if not os.path.exists(DATA_DIR):
            # exist_ok guards against the directory appearing between the
            # check above and this call.
            os.makedirs(DATA_DIR, exist_ok=True)
            logger.info(f"Created data directory: {DATA_DIR}")

        # Start the Flask server in a background thread to keep the bot alive.
        # A daemon thread does not keep the process running once the bot exits.
        flask_thread = Thread(target=run_flask, daemon=True)
        flask_thread.start()
        logger.info("Health check server started in background thread.")

        logger.info("Starting bot...")
        asyncio.run(main())
    except Exception as e:
        logger.critical(f"FATAL ERROR during bot startup: {str(e)}", exc_info=True)
        # SystemExit instead of the exit() helper, which the site module
        # provides and which is not guaranteed to exist.
        raise SystemExit(1)