Skip to content

Commit 024266b

Browse files
authored
Merge pull request #50 from jkoelker/jk/scheduler
feat: add scheduler configuration for civitai checkpoints
2 parents 99870a1 + e7971da commit 024266b

File tree

5 files changed

+447
-25
lines changed

5 files changed

+447
-25
lines changed

config.toml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,21 @@ verify_hashes = true
143143
# - SD 3, SD 3.5 Medium/Large (1024x1024)
144144
# - PixArt Alpha/Sigma, Kolors, Hunyuan DiT, AuraFlow
145145
#
146+
# SCHEDULER OPTIONS:
147+
# The scheduler controls the denoising process. Available options:
148+
# - dpm++_karras: DPM++ with Karras sigmas (recommended for SD/SDXL/Pony)
149+
# - dpm++: DPM++ without Karras sigmas
150+
# - euler_a: Euler Ancestral (good variety, slightly random)
151+
# - euler: Euler (deterministic)
152+
# - heun: Heun (higher quality, 2x slower)
153+
# - ddim: DDIM (classic, deterministic)
154+
# - default: Keep the model's built-in scheduler
155+
#
156+
# Defaults by model type:
157+
# - SD/SDXL/Pony/Illustrious: dpm++_karras
158+
# - Turbo/Lightning/LCM/Hyper: default (keep optimized scheduler)
159+
# - Flux/SD3/PixArt/etc: default (flow-based models)
160+
#
146161
# Example 1: CivitAI checkpoint by model ID (downloads automatically)
147162
# [models.civitai-example]
148163
# type = "civitai"
@@ -181,3 +196,9 @@ verify_hashes = true
181196
# steps = 30
182197
# guidance_scale = 5.0
183198
# supports_negative_prompt = true
199+
#
200+
# Example 6: CivitAI checkpoint with custom scheduler
201+
# [models.realistic-xl]
202+
# type = "civitai"
203+
# civitai_model_id = 12345
204+
# scheduler = "euler_a" # Override default scheduler

src/oneiro/bot.py

Lines changed: 60 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,14 @@
1212
from oneiro.civitai import CivitaiClient, CivitaiError
1313
from oneiro.config import Config
1414
from oneiro.filters import ContentFilter
15-
from oneiro.pipelines import GenerationResult, LoraConfig, LoraSource, PipelineManager
15+
from oneiro.pipelines import (
16+
SCHEDULER_CHOICES,
17+
GenerationResult,
18+
LoraConfig,
19+
LoraSource,
20+
PipelineManager,
21+
)
22+
from oneiro.pipelines.civitai_checkpoint import CivitaiCheckpointPipeline
1623
from oneiro.pipelines.lora import is_lora_compatible, parse_civitai_url
1724
from oneiro.queue import GenerationQueue, QueueStatus
1825

@@ -349,6 +356,13 @@ async def on_raw_reaction_add(payload: discord.RawReactionActionEvent):
349356
required=False,
350357
autocomplete=get_lora_choices,
351358
)
359+
@option(
360+
"scheduler",
361+
str,
362+
description="Scheduler for denoising (dpm++_karras, euler_a, etc.)",
363+
required=False,
364+
choices=SCHEDULER_CHOICES,
365+
)
352366
async def dream(
353367
ctx: discord.ApplicationContext,
354368
prompt: str,
@@ -359,6 +373,7 @@ async def dream(
359373
height: int = 1024,
360374
seed: int = -1,
361375
lora: str | None = None,
376+
scheduler: str | None = None,
362377
):
363378
"""Generate an image from a text prompt."""
364379
global config, pipeline_manager, generation_queue, content_filter, civitai_client
@@ -486,7 +501,6 @@ async def dream(
486501
await ctx.followup.send(f"❌ Invalid LoRA specification: {e}", ephemeral=True)
487502
return
488503

489-
# Build generation request
490504
request: dict[str, Any] = {
491505
"prompt": prompt,
492506
"negative_prompt": negative_prompt,
@@ -497,7 +511,9 @@ async def dream(
497511
"guidance_scale": guidance_scale,
498512
}
499513

500-
# Add LoRA configs to request if any
514+
if scheduler:
515+
request["scheduler"] = scheduler
516+
501517
if lora_configs:
502518
request["loras"] = lora_configs
503519

@@ -566,6 +582,8 @@ async def on_complete(result: GenerationResult | Exception) -> None:
566582
if len(lora_display) > 1024:
567583
lora_display = lora_display[:1021] + "..."
568584
embed.add_field(name="LoRA", value=lora_display, inline=True)
585+
if scheduler:
586+
embed.add_field(name="Scheduler", value=f"`{scheduler}`", inline=True)
569587
embed.set_image(url="attachment://dream.png")
570588
embed.set_footer(
571589
text=f"Requested by {ctx.author.name} • React ❌ to delete",
@@ -642,11 +660,25 @@ async def queue_status(ctx: discord.ApplicationContext):
642660
required=True,
643661
autocomplete=get_model_choices,
644662
)
663+
@option(
664+
"scheduler",
665+
str,
666+
description="Override default scheduler for this model",
667+
required=False,
668+
choices=SCHEDULER_CHOICES,
669+
)
645670
async def model_command(
646671
ctx: discord.ApplicationContext,
647672
model: str,
673+
scheduler: str | None = None,
648674
):
649-
"""Switch the active generation model."""
675+
"""Switch the active diffusion model and optionally override its scheduler.
676+
677+
This command changes which configured model is used for image generation.
678+
If a scheduler is provided and supported by the loaded pipeline, it will be
679+
applied after the model is loaded. The selected model may also be stored as
680+
the default in the persistent configuration if a state path is configured.
681+
"""
650682
global config, pipeline_manager
651683

652684
if pipeline_manager is None or config is None:
@@ -666,6 +698,22 @@ async def model_command(
666698

667699
# Check if already loaded
668700
if pipeline_manager.current_model == model:
701+
if scheduler and pipeline_manager.pipeline is not None:
702+
if isinstance(pipeline_manager.pipeline, CivitaiCheckpointPipeline):
703+
pipeline_manager.pipeline.configure_scheduler(scheduler)
704+
await ctx.respond(
705+
f"✅ Model `{model}` already active, scheduler set to `{scheduler}`.",
706+
ephemeral=True,
707+
)
708+
return
709+
else:
710+
await ctx.respond(
711+
f"✅ Model `{model}` is already active.\n"
712+
f"⚠️ Scheduler override is not supported for this pipeline type.",
713+
ephemeral=True,
714+
)
715+
return
716+
669717
await ctx.respond(
670718
f"✅ Model `{model}` is already active.",
671719
ephemeral=True,
@@ -679,11 +727,17 @@ async def model_command(
679727
loading_msg = await ctx.followup.send(f"⏳ Loading model `{model}`...")
680728
await pipeline_manager.load_model(model)
681729

682-
# Persist model choice for next restart
730+
if scheduler and pipeline_manager.pipeline is not None:
731+
if isinstance(pipeline_manager.pipeline, CivitaiCheckpointPipeline):
732+
pipeline_manager.pipeline.configure_scheduler(scheduler)
733+
683734
if config.state_path:
684735
config.set("defaults", "model", value=model)
685736

686-
await loading_msg.edit(content=f"✅ Switched to model `{model}`")
737+
msg = f"✅ Switched to model `{model}`"
738+
if scheduler:
739+
msg += f" with scheduler `{scheduler}`"
740+
await loading_msg.edit(content=msg)
687741
except Exception as e:
688742
await ctx.followup.send(f"❌ Failed to load model: {e}", ephemeral=True)
689743

src/oneiro/pipelines/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from oneiro.pipelines.base import BasePipeline, GenerationResult
1010
from oneiro.pipelines.civitai_checkpoint import (
1111
CIVITAI_BASE_MODEL_PIPELINE_MAP,
12+
SCHEDULER_CHOICES,
13+
SCHEDULER_MAP,
1214
CivitaiCheckpointPipeline,
1315
PipelineConfig,
1416
get_pipeline_config_for_base_model,
@@ -42,6 +44,8 @@
4244
"CivitaiCheckpointPipeline",
4345
"PipelineConfig",
4446
"CIVITAI_BASE_MODEL_PIPELINE_MAP",
47+
"SCHEDULER_CHOICES",
48+
"SCHEDULER_MAP",
4549
"get_pipeline_config_for_base_model",
4650
"LoraConfig",
4751
"LoraSource",

0 commit comments

Comments (0)