Skip to content

Commit 587071f

Browse files
authored
Merge pull request #52 from jkoelker/feat/issue-51-expose-generation-params
feat: expose steps and guidance_scale in /dream and /model commands
2 parents 024266b + 6dc63f6 commit 587071f

File tree

2 files changed: +336 additions, −15 deletions

src/oneiro/bot.py

Lines changed: 101 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@
3434
MIN_LORA_WEIGHT = -2.0
3535
MAX_LORA_WEIGHT = 2.0
3636

37+
# Steps validation limits (for /dream and /model commands)
38+
MIN_STEPS = 1
39+
MAX_STEPS = 100
40+
41+
# Guidance scale validation limits (for /dream and /model commands)
42+
MIN_GUIDANCE_SCALE = 0.0
43+
MAX_GUIDANCE_SCALE = 15.0
44+
3745

3846
def validate_lora_weight(weight: float, lora_name: str) -> None:
3947
"""Validate that a LoRA weight is within acceptable bounds.
@@ -346,6 +354,22 @@ async def on_raw_reaction_add(payload: discord.RawReactionActionEvent):
346354
description="Random seed (-1 for random)",
347355
required=False,
348356
)
357+
@option(
358+
"steps",
359+
int,
360+
description="Number of inference steps (default: model-specific)",
361+
required=False,
362+
min_value=MIN_STEPS,
363+
max_value=MAX_STEPS,
364+
)
365+
@option(
366+
"guidance_scale",
367+
float,
368+
description="CFG scale - prompt adherence (default: model-specific)",
369+
required=False,
370+
min_value=MIN_GUIDANCE_SCALE,
371+
max_value=MAX_GUIDANCE_SCALE,
372+
)
349373
@option(
350374
"lora",
351375
str,
@@ -372,6 +396,8 @@ async def dream(
372396
width: int = 1024,
373397
height: int = 1024,
374398
seed: int = -1,
399+
steps: int | None = None,
400+
guidance_scale: float | None = None,
375401
lora: str | None = None,
376402
scheduler: str | None = None,
377403
):
@@ -408,12 +434,26 @@ async def dream(
408434
current_model = pipeline_manager.current_model or "zimage-turbo"
409435
model_config = config.get("models", current_model, default={}) if config else {}
410436
pipeline_type = model_config.get("type") if model_config else None
411-
steps = model_config.get("steps", 9)
412-
guidance_scale = model_config.get("guidance_scale", 0.0)
437+
438+
# Get model config defaults
439+
model_steps = model_config.get("steps", 9)
440+
model_guidance = model_config.get("guidance_scale", 0.0)
413441

414442
# Handle Qwen's true_cfg_scale
415443
if model_config.get("true_cfg_scale"):
416-
guidance_scale = model_config.get("true_cfg_scale", 4.0)
444+
model_guidance = model_config["true_cfg_scale"]
445+
446+
# Check for model-specific overrides set via /model command
447+
model_overrides = config.get("model_overrides", current_model, default={}) if config else {}
448+
if model_overrides:
449+
if "steps" in model_overrides:
450+
model_steps = model_overrides["steps"]
451+
if "guidance_scale" in model_overrides:
452+
model_guidance = model_overrides["guidance_scale"]
453+
454+
# User-provided values take priority over model defaults
455+
actual_steps = steps if steps is not None else model_steps
456+
actual_guidance = guidance_scale if guidance_scale is not None else model_guidance
417457

418458
# Resolve LoRAs if specified
419459
lora_configs: list[LoraConfig] = []
@@ -507,8 +547,8 @@ async def dream(
507547
"width": width,
508548
"height": height,
509549
"seed": seed,
510-
"steps": steps,
511-
"guidance_scale": guidance_scale,
550+
"steps": actual_steps,
551+
"guidance_scale": actual_guidance,
512552
}
513553

514554
if scheduler:
@@ -575,6 +615,8 @@ async def on_complete(result: GenerationResult | Exception) -> None:
575615
embed.add_field(name="Seed", value=str(result.seed), inline=True)
576616
embed.add_field(name="Time", value=f"{elapsed:.1f}s", inline=True)
577617
embed.add_field(name="Model", value=f"`{current_model}`", inline=True)
618+
embed.add_field(name="Steps", value=str(result.steps), inline=True)
619+
embed.add_field(name="CFG", value=f"{result.guidance_scale:.1f}", inline=True)
578620
if is_img2img:
579621
embed.add_field(name="Strength", value=f"{strength:.2f}", inline=True)
580622
if lora_configs:
@@ -667,10 +709,28 @@ async def queue_status(ctx: discord.ApplicationContext):
667709
required=False,
668710
choices=SCHEDULER_CHOICES,
669711
)
712+
@option(
713+
"steps",
714+
int,
715+
description="Override default steps for this model",
716+
required=False,
717+
min_value=MIN_STEPS,
718+
max_value=MAX_STEPS,
719+
)
720+
@option(
721+
"guidance_scale",
722+
float,
723+
description="Override default CFG scale for this model",
724+
required=False,
725+
min_value=MIN_GUIDANCE_SCALE,
726+
max_value=MAX_GUIDANCE_SCALE,
727+
)
670728
async def model_command(
671729
ctx: discord.ApplicationContext,
672730
model: str,
673731
scheduler: str | None = None,
732+
steps: int | None = None,
733+
guidance_scale: float | None = None,
674734
):
675735
"""Switch the active diffusion model and optionally override its scheduler.
676736
@@ -698,14 +758,13 @@ async def model_command(
698758

699759
# Check if already loaded
700760
if pipeline_manager.current_model == model:
761+
# Model is already active - handle overrides only
762+
overrides_applied = []
763+
701764
if scheduler and pipeline_manager.pipeline is not None:
702765
if isinstance(pipeline_manager.pipeline, CivitaiCheckpointPipeline):
703766
pipeline_manager.pipeline.configure_scheduler(scheduler)
704-
await ctx.respond(
705-
f"✅ Model `{model}` already active, scheduler set to `{scheduler}`.",
706-
ephemeral=True,
707-
)
708-
return
767+
overrides_applied.append(f"scheduler=`{scheduler}`")
709768
else:
710769
await ctx.respond(
711770
f"✅ Model `{model}` is already active.\n"
@@ -714,10 +773,25 @@ async def model_command(
714773
)
715774
return
716775

717-
await ctx.respond(
718-
f"✅ Model `{model}` is already active.",
719-
ephemeral=True,
720-
)
776+
# Save steps/guidance_scale overrides to state
777+
if config.state_path:
778+
if steps is not None:
779+
config.set("model_overrides", model, "steps", value=steps)
780+
overrides_applied.append(f"steps={steps}")
781+
if guidance_scale is not None:
782+
config.set("model_overrides", model, "guidance_scale", value=guidance_scale)
783+
overrides_applied.append(f"guidance_scale={guidance_scale}")
784+
785+
if overrides_applied:
786+
await ctx.respond(
787+
f"✅ Model `{model}` already active. Set: {', '.join(overrides_applied)}",
788+
ephemeral=True,
789+
)
790+
else:
791+
await ctx.respond(
792+
f"✅ Model `{model}` is already active.",
793+
ephemeral=True,
794+
)
721795
return
722796

723797
# Defer for model loading (can be slow)
@@ -733,10 +807,22 @@ async def model_command(
733807

734808
if config.state_path:
735809
config.set("defaults", "model", value=model)
810+
# Save steps/guidance_scale overrides to state
811+
if steps is not None:
812+
config.set("model_overrides", model, "steps", value=steps)
813+
if guidance_scale is not None:
814+
config.set("model_overrides", model, "guidance_scale", value=guidance_scale)
736815

737816
msg = f"✅ Switched to model `{model}`"
817+
overrides = []
738818
if scheduler:
739-
msg += f" with scheduler `{scheduler}`"
819+
overrides.append(f"scheduler=`{scheduler}`")
820+
if steps is not None:
821+
overrides.append(f"steps={steps}")
822+
if guidance_scale is not None:
823+
overrides.append(f"guidance_scale={guidance_scale}")
824+
if overrides:
825+
msg += f" with {', '.join(overrides)}"
740826
await loading_msg.edit(content=msg)
741827
except Exception as e:
742828
await ctx.followup.send(f"❌ Failed to load model: {e}", ephemeral=True)

Commit comments (0)