|
| 1 | +"""Application factory for Oneiro Discord bot.""" |
| 2 | + |
| 3 | +import os |
| 4 | +from pathlib import Path |
| 5 | +from typing import Any |
| 6 | + |
| 7 | +import discord |
| 8 | + |
| 9 | +from oneiro.civitai import CivitaiClient |
| 10 | +from oneiro.config import Config |
| 11 | +from oneiro.discord.commands import register_commands |
| 12 | +from oneiro.discord.handlers import create_config_change_handler, handle_reaction_delete |
| 13 | +from oneiro.filters import ContentFilter |
| 14 | +from oneiro.lora_detector import AutoLoraDetector, create_detector_from_config |
| 15 | +from oneiro.pipelines import PipelineManager |
| 16 | +from oneiro.queue import GenerationQueue |
| 17 | + |
| 18 | + |
| 19 | +class OneiroBot(discord.Bot): |
| 20 | + """Oneiro Discord bot with typed state management.""" |
| 21 | + |
| 22 | + def __init__(self, *args: Any, **kwargs: Any) -> None: |
| 23 | + super().__init__(*args, **kwargs) |
| 24 | + self.config: Config | None = None |
| 25 | + self.pipeline_manager: PipelineManager | None = None |
| 26 | + self.generation_queue: GenerationQueue | None = None |
| 27 | + self.content_filter: ContentFilter | None = None |
| 28 | + self.civitai_client: CivitaiClient | None = None |
| 29 | + self.lora_detector: AutoLoraDetector | None = None |
| 30 | + |
| 31 | + |
| 32 | +def create_app() -> OneiroBot: |
| 33 | + """Create and wire the Discord bot with all services. |
| 34 | +
|
| 35 | + Returns: |
| 36 | + Configured OneiroBot instance ready to run. |
| 37 | + """ |
| 38 | + activity = discord.Activity( |
| 39 | + name="Dreaming...", |
| 40 | + type=discord.ActivityType.custom, |
| 41 | + ) |
| 42 | + |
| 43 | + bot = OneiroBot(activity=activity) |
| 44 | + |
| 45 | + # Register slash commands |
| 46 | + register_commands(bot) |
| 47 | + |
| 48 | + @bot.event |
| 49 | + async def on_ready() -> None: |
| 50 | + """Initialize config, pipeline and queue when bot connects.""" |
| 51 | + print(f"{bot.user} is online!") |
| 52 | + |
| 53 | + # Load configuration |
| 54 | + base_config_path = Path(os.environ.get("CONFIG_PATH", "config.toml")) |
| 55 | + overlay_config_path = os.environ.get("CONFIG_OVERLAY_PATH") |
| 56 | + state_path = os.environ.get("STATE_PATH") |
| 57 | + |
| 58 | + bot.config = Config( |
| 59 | + base_path=base_config_path, |
| 60 | + overlay_path=Path(overlay_config_path) if overlay_config_path else None, |
| 61 | + state_path=Path(state_path) if state_path else None, |
| 62 | + ) |
| 63 | + bot.config.load() |
| 64 | + print(f"Config loaded from {base_config_path}") |
| 65 | + |
| 66 | + # Initialize Civitai client |
| 67 | + bot.civitai_client = CivitaiClient.from_config(bot.config) |
| 68 | + print("Civitai client initialized") |
| 69 | + |
| 70 | + # Initialize content filter |
| 71 | + bot.content_filter = ContentFilter(bot.config) |
| 72 | + print("Content filter initialized") |
| 73 | + |
| 74 | + # Initialize LoRA auto-detector |
| 75 | + bot.lora_detector = create_detector_from_config(bot.config.data) |
| 76 | + print("LoRA auto-detector initialized") |
| 77 | + |
| 78 | + # Initialize pipeline manager with config |
| 79 | + bot.pipeline_manager = PipelineManager(bot.config) |
| 80 | + bot.pipeline_manager.set_civitai_client(bot.civitai_client) |
| 81 | + print("Loading default model...") |
| 82 | + await bot.pipeline_manager.load_model() |
| 83 | + print(f"Model loaded: {bot.pipeline_manager.current_model}") |
| 84 | + |
| 85 | + # Initialize queue with config values |
| 86 | + max_global = bot.config.get("queue", "max_global", default=100) |
| 87 | + max_per_user = bot.config.get("queue", "max_per_user", default=20) |
| 88 | + bot.generation_queue = GenerationQueue(max_global=max_global, max_per_user=max_per_user) |
| 89 | + await bot.generation_queue.start(bot.pipeline_manager) |
| 90 | + print(f"Queue started: {max_global} global, {max_per_user} per user") |
| 91 | + |
| 92 | + # Register config change callback |
| 93 | + bot.config.on_change(create_config_change_handler(bot)) |
| 94 | + |
| 95 | + # Start config file watching |
| 96 | + await bot.config.start_watching() |
| 97 | + |
| 98 | + # Sync slash commands to all guilds for instant availability |
| 99 | + # (global commands can take up to 1 hour to propagate) |
| 100 | + if bot.guilds: |
| 101 | + guild_ids = [g.id for g in bot.guilds] |
| 102 | + await bot.sync_commands(guild_ids=guild_ids) |
| 103 | + print(f"Commands synced to {len(guild_ids)} guild(s): {[g.name for g in bot.guilds]}") |
| 104 | + |
| 105 | + print("Ready to generate images!") |
| 106 | + |
| 107 | + @bot.event |
| 108 | + async def on_raw_reaction_add(payload: discord.RawReactionActionEvent) -> None: |
| 109 | + """Handle ❌ reaction to delete generated images.""" |
| 110 | + await handle_reaction_delete(bot, payload) |
| 111 | + |
| 112 | + return bot |
0 commit comments