Skip to content

Commit e7971da

Browse files
committed
docs: add scheduler configuration documentation
Document available scheduler options in config.toml: - dpm++_karras, dpm++, euler_a, euler, heun, ddim, default - Explain defaults by model type - Add example showing scheduler config usage
1 parent 69a54fa commit e7971da

File tree

2 files changed

+238
-3
lines changed

2 files changed

+238
-3
lines changed

config.toml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,21 @@ verify_hashes = true
143143
# - SD 3, SD 3.5 Medium/Large (1024x1024)
144144
# - PixArt Alpha/Sigma, Kolors, Hunyuan DiT, AuraFlow
145145
#
146+
# SCHEDULER OPTIONS:
147+
# The scheduler controls the denoising process. Available options:
148+
# - dpm++_karras: DPM++ with Karras sigmas (recommended for SD/SDXL/Pony)
149+
# - dpm++: DPM++ without Karras sigmas
150+
# - euler_a: Euler Ancestral (good variety, slightly random)
151+
# - euler: Euler (deterministic)
152+
# - heun: Heun (higher quality, 2x slower)
153+
# - ddim: DDIM (classic, deterministic)
154+
# - default: Keep the model's built-in scheduler
155+
#
156+
# Defaults by model type:
157+
# - SD/SDXL/Pony/Illustrious: dpm++_karras
158+
# - Turbo/Lightning/LCM/Hyper: default (keep optimized scheduler)
159+
# - Flux/SD3/PixArt/etc: default (flow-based models)
160+
#
146161
# Example 1: CivitAI checkpoint by model ID (downloads automatically)
147162
# [models.civitai-example]
148163
# type = "civitai"
@@ -181,3 +196,9 @@ verify_hashes = true
181196
# steps = 30
182197
# guidance_scale = 5.0
183198
# supports_negative_prompt = true
199+
#
200+
# Example 6: CivitAI checkpoint with custom scheduler
201+
# [models.realistic-xl]
202+
# type = "civitai"
203+
# civitai_model_id = 12345
204+
# scheduler = "euler_a" # Override default scheduler

tests/test_civitai_checkpoint.py

Lines changed: 217 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from oneiro.pipelines.civitai_checkpoint import (
88
CIVITAI_BASE_MODEL_PIPELINE_MAP,
99
DEFAULT_PIPELINE_CONFIG,
10+
SCHEDULER_CHOICES,
11+
SCHEDULER_MAP,
1012
CivitaiCheckpointPipeline,
1113
PipelineConfig,
1214
get_diffusers_pipeline_class,
@@ -258,8 +260,9 @@ def test_load_file_not_found(self, tmp_path):
258260
with pytest.raises(FileNotFoundError, match="Checkpoint not found"):
259261
pipeline.load({"checkpoint_path": str(tmp_path / "nonexistent.safetensors")})
260262

263+
@patch.object(CivitaiCheckpointPipeline, "configure_scheduler")
261264
@patch("oneiro.pipelines.civitai_checkpoint.get_diffusers_pipeline_class")
262-
def test_load_with_base_model_override(self, mock_get_class, tmp_path):
265+
def test_load_with_base_model_override(self, mock_get_class, mock_config_sched, tmp_path):
263266
"""load() uses base_model override from config."""
264267
# Create dummy checkpoint file
265268
checkpoint = tmp_path / "model.safetensors"
@@ -307,8 +310,9 @@ def test_load_with_pipeline_class_override(self, mock_get_class, tmp_path):
307310
assert pipeline._pipeline_config.default_guidance_scale == 5.0
308311
mock_get_class.assert_called_once_with("CustomPipeline")
309312

313+
@patch.object(CivitaiCheckpointPipeline, "configure_scheduler")
310314
@patch("oneiro.pipelines.civitai_checkpoint.get_diffusers_pipeline_class")
311-
def test_load_enables_cpu_offload(self, mock_get_class, tmp_path):
315+
def test_load_enables_cpu_offload(self, mock_get_class, mock_config_sched, tmp_path):
312316
"""load() enables CPU offload when configured."""
313317
checkpoint = tmp_path / "model.safetensors"
314318
checkpoint.write_bytes(b"dummy")
@@ -329,8 +333,9 @@ def test_load_enables_cpu_offload(self, mock_get_class, tmp_path):
329333

330334
mock_pipe.enable_model_cpu_offload.assert_called_once()
331335

336+
@patch.object(CivitaiCheckpointPipeline, "configure_scheduler")
332337
@patch("oneiro.pipelines.civitai_checkpoint.get_diffusers_pipeline_class")
333-
def test_load_enables_vae_optimizations(self, mock_get_class, tmp_path):
338+
def test_load_enables_vae_optimizations(self, mock_get_class, mock_config_sched, tmp_path):
334339
"""load() enables VAE tiling and slicing."""
335340
checkpoint = tmp_path / "model.safetensors"
336341
checkpoint.write_bytes(b"dummy")
@@ -620,3 +625,212 @@ def test_generate_handles_img2img(self):
620625
assert call_kwargs["strength"] == 0.5
621626
assert "width" not in call_kwargs
622627
assert "height" not in call_kwargs
628+
629+
def test_generate_with_scheduler_override(self):
630+
pipeline = CivitaiCheckpointPipeline()
631+
pipeline._pipeline_config = PipelineConfig(
632+
pipeline_class="StableDiffusionXLPipeline",
633+
default_scheduler="dpm++_karras",
634+
)
635+
636+
mock_pipe = MagicMock()
637+
mock_scheduler = MagicMock()
638+
mock_scheduler.config = {}
639+
mock_pipe.scheduler = mock_scheduler
640+
mock_image = MagicMock()
641+
mock_image.width = 1024
642+
mock_image.height = 1024
643+
mock_pipe.return_value.images = [mock_image]
644+
pipeline.pipe = mock_pipe
645+
646+
mock_euler = MagicMock()
647+
with (
648+
patch("oneiro.pipelines.civitai_checkpoint.torch"),
649+
patch("diffusers.EulerAncestralDiscreteScheduler", mock_euler),
650+
):
651+
pipeline.generate("test prompt", scheduler="euler_a")
652+
653+
mock_euler.from_config.assert_called_once()
654+
655+
656+
class TestSchedulerMap:
657+
def test_scheduler_choices_matches_map_keys(self):
658+
assert set(SCHEDULER_CHOICES) == set(SCHEDULER_MAP.keys())
659+
660+
def test_default_entry_has_none_class(self):
661+
class_name, kwargs = SCHEDULER_MAP["default"]
662+
assert class_name is None
663+
assert kwargs == {}
664+
665+
def test_dpm_karras_entry(self):
666+
class_name, kwargs = SCHEDULER_MAP["dpm++_karras"]
667+
assert class_name == "DPMSolverMultistepScheduler"
668+
assert kwargs["algorithm_type"] == "sde-dpmsolver++"
669+
assert kwargs["use_karras_sigmas"] is True
670+
671+
def test_dpm_entry(self):
672+
class_name, kwargs = SCHEDULER_MAP["dpm++"]
673+
assert class_name == "DPMSolverMultistepScheduler"
674+
assert kwargs["algorithm_type"] == "sde-dpmsolver++"
675+
assert kwargs["use_karras_sigmas"] is False
676+
677+
def test_euler_a_entry(self):
678+
class_name, kwargs = SCHEDULER_MAP["euler_a"]
679+
assert class_name == "EulerAncestralDiscreteScheduler"
680+
assert kwargs == {}
681+
682+
def test_euler_entry(self):
683+
class_name, kwargs = SCHEDULER_MAP["euler"]
684+
assert class_name == "EulerDiscreteScheduler"
685+
assert kwargs == {}
686+
687+
def test_heun_entry(self):
688+
class_name, kwargs = SCHEDULER_MAP["heun"]
689+
assert class_name == "HeunDiscreteScheduler"
690+
assert kwargs == {}
691+
692+
def test_ddim_entry(self):
693+
class_name, kwargs = SCHEDULER_MAP["ddim"]
694+
assert class_name == "DDIMScheduler"
695+
assert kwargs == {}
696+
697+
698+
class TestDefaultSchedulers:
699+
def test_sd15_has_dpm_karras(self):
700+
config = CIVITAI_BASE_MODEL_PIPELINE_MAP["SD 1.5"]
701+
assert config.default_scheduler == "dpm++_karras"
702+
703+
def test_sdxl_has_dpm_karras(self):
704+
config = CIVITAI_BASE_MODEL_PIPELINE_MAP["SDXL 1.0"]
705+
assert config.default_scheduler == "dpm++_karras"
706+
707+
def test_pony_has_dpm_karras(self):
708+
config = CIVITAI_BASE_MODEL_PIPELINE_MAP["Pony"]
709+
assert config.default_scheduler == "dpm++_karras"
710+
711+
def test_illustrious_has_dpm_karras(self):
712+
config = CIVITAI_BASE_MODEL_PIPELINE_MAP["Illustrious"]
713+
assert config.default_scheduler == "dpm++_karras"
714+
715+
def test_sdxl_turbo_keeps_default(self):
716+
config = CIVITAI_BASE_MODEL_PIPELINE_MAP["SDXL Turbo"]
717+
assert config.default_scheduler == "default"
718+
719+
def test_sdxl_lightning_keeps_default(self):
720+
config = CIVITAI_BASE_MODEL_PIPELINE_MAP["SDXL Lightning"]
721+
assert config.default_scheduler == "default"
722+
723+
def test_flux_keeps_default(self):
724+
config = CIVITAI_BASE_MODEL_PIPELINE_MAP["Flux.1 Dev"]
725+
assert config.default_scheduler == "default"
726+
727+
def test_sd3_keeps_default(self):
728+
config = CIVITAI_BASE_MODEL_PIPELINE_MAP["SD 3"]
729+
assert config.default_scheduler == "default"
730+
731+
732+
class TestConfigureScheduler:
733+
def test_configure_scheduler_with_none_uses_pipeline_default(self):
734+
pipeline = CivitaiCheckpointPipeline()
735+
pipeline._pipeline_config = PipelineConfig(
736+
pipeline_class="StableDiffusionXLPipeline",
737+
default_scheduler="dpm++_karras",
738+
)
739+
mock_pipe = MagicMock()
740+
mock_scheduler = MagicMock()
741+
mock_scheduler.config = {}
742+
mock_pipe.scheduler = mock_scheduler
743+
pipeline.pipe = mock_pipe
744+
745+
mock_dpm = MagicMock()
746+
with patch("diffusers.DPMSolverMultistepScheduler", mock_dpm):
747+
pipeline.configure_scheduler(None)
748+
749+
mock_dpm.from_config.assert_called_once()
750+
751+
def test_configure_scheduler_default_string_no_change(self):
752+
pipeline = CivitaiCheckpointPipeline()
753+
pipeline._pipeline_config = PipelineConfig(
754+
pipeline_class="StableDiffusionXLPipeline",
755+
default_scheduler="default",
756+
)
757+
mock_pipe = MagicMock()
758+
original_scheduler = MagicMock()
759+
mock_pipe.scheduler = original_scheduler
760+
pipeline.pipe = mock_pipe
761+
762+
pipeline.configure_scheduler("default")
763+
764+
assert mock_pipe.scheduler is original_scheduler
765+
766+
def test_configure_scheduler_unknown_warns(self, capsys):
767+
pipeline = CivitaiCheckpointPipeline()
768+
pipeline._pipeline_config = PipelineConfig(
769+
pipeline_class="StableDiffusionXLPipeline",
770+
)
771+
mock_pipe = MagicMock()
772+
original_scheduler = MagicMock()
773+
mock_pipe.scheduler = original_scheduler
774+
pipeline.pipe = mock_pipe
775+
776+
pipeline.configure_scheduler("unknown_scheduler")
777+
778+
captured = capsys.readouterr()
779+
assert "Unknown scheduler" in captured.out
780+
assert mock_pipe.scheduler is original_scheduler
781+
782+
def test_configure_scheduler_euler_a(self):
783+
pipeline = CivitaiCheckpointPipeline()
784+
pipeline._pipeline_config = PipelineConfig(
785+
pipeline_class="StableDiffusionXLPipeline",
786+
)
787+
mock_pipe = MagicMock()
788+
mock_scheduler = MagicMock()
789+
mock_scheduler.config = {}
790+
mock_pipe.scheduler = mock_scheduler
791+
pipeline.pipe = mock_pipe
792+
793+
mock_euler = MagicMock()
794+
with patch("diffusers.EulerAncestralDiscreteScheduler", mock_euler):
795+
pipeline.configure_scheduler("euler_a")
796+
797+
mock_euler.from_config.assert_called_once_with({})
798+
799+
def test_configure_scheduler_with_kwargs(self):
800+
pipeline = CivitaiCheckpointPipeline()
801+
pipeline._pipeline_config = PipelineConfig(
802+
pipeline_class="StableDiffusionXLPipeline",
803+
)
804+
mock_pipe = MagicMock()
805+
mock_scheduler = MagicMock()
806+
mock_scheduler.config = {"some": "config"}
807+
mock_pipe.scheduler = mock_scheduler
808+
pipeline.pipe = mock_pipe
809+
810+
mock_dpm = MagicMock()
811+
with patch("diffusers.DPMSolverMultistepScheduler", mock_dpm):
812+
pipeline.configure_scheduler("dpm++_karras")
813+
814+
mock_dpm.from_config.assert_called_once_with(
815+
{"some": "config"},
816+
algorithm_type="sde-dpmsolver++",
817+
use_karras_sigmas=True,
818+
)
819+
820+
def test_configure_scheduler_skips_redundant_reconfiguration(self):
821+
pipeline = CivitaiCheckpointPipeline()
822+
pipeline._pipeline_config = PipelineConfig(
823+
pipeline_class="StableDiffusionXLPipeline",
824+
)
825+
mock_pipe = MagicMock()
826+
mock_scheduler = MagicMock()
827+
mock_scheduler.config = {}
828+
mock_pipe.scheduler = mock_scheduler
829+
pipeline.pipe = mock_pipe
830+
831+
mock_euler = MagicMock()
832+
with patch("diffusers.EulerAncestralDiscreteScheduler", mock_euler):
833+
pipeline.configure_scheduler("euler_a")
834+
pipeline.configure_scheduler("euler_a")
835+
836+
mock_euler.from_config.assert_called_once()

0 commit comments

Comments
 (0)