77from oneiro .pipelines .civitai_checkpoint import (
88 CIVITAI_BASE_MODEL_PIPELINE_MAP ,
99 DEFAULT_PIPELINE_CONFIG ,
10+ SCHEDULER_CHOICES ,
11+ SCHEDULER_MAP ,
1012 CivitaiCheckpointPipeline ,
1113 PipelineConfig ,
1214 get_diffusers_pipeline_class ,
@@ -258,8 +260,9 @@ def test_load_file_not_found(self, tmp_path):
258260 with pytest .raises (FileNotFoundError , match = "Checkpoint not found" ):
259261 pipeline .load ({"checkpoint_path" : str (tmp_path / "nonexistent.safetensors" )})
260262
263+ @patch .object (CivitaiCheckpointPipeline , "configure_scheduler" )
261264 @patch ("oneiro.pipelines.civitai_checkpoint.get_diffusers_pipeline_class" )
262- def test_load_with_base_model_override (self , mock_get_class , tmp_path ):
265+ def test_load_with_base_model_override (self , mock_get_class , mock_config_sched , tmp_path ):
263266 """load() uses base_model override from config."""
264267 # Create dummy checkpoint file
265268 checkpoint = tmp_path / "model.safetensors"
@@ -307,8 +310,9 @@ def test_load_with_pipeline_class_override(self, mock_get_class, tmp_path):
307310 assert pipeline ._pipeline_config .default_guidance_scale == 5.0
308311 mock_get_class .assert_called_once_with ("CustomPipeline" )
309312
313+ @patch .object (CivitaiCheckpointPipeline , "configure_scheduler" )
310314 @patch ("oneiro.pipelines.civitai_checkpoint.get_diffusers_pipeline_class" )
311- def test_load_enables_cpu_offload (self , mock_get_class , tmp_path ):
315+ def test_load_enables_cpu_offload (self , mock_get_class , mock_config_sched , tmp_path ):
312316 """load() enables CPU offload when configured."""
313317 checkpoint = tmp_path / "model.safetensors"
314318 checkpoint .write_bytes (b"dummy" )
@@ -329,8 +333,9 @@ def test_load_enables_cpu_offload(self, mock_get_class, tmp_path):
329333
330334 mock_pipe .enable_model_cpu_offload .assert_called_once ()
331335
336+ @patch .object (CivitaiCheckpointPipeline , "configure_scheduler" )
332337 @patch ("oneiro.pipelines.civitai_checkpoint.get_diffusers_pipeline_class" )
333- def test_load_enables_vae_optimizations (self , mock_get_class , tmp_path ):
338+ def test_load_enables_vae_optimizations (self , mock_get_class , mock_config_sched , tmp_path ):
334339 """load() enables VAE tiling and slicing."""
335340 checkpoint = tmp_path / "model.safetensors"
336341 checkpoint .write_bytes (b"dummy" )
@@ -620,3 +625,212 @@ def test_generate_handles_img2img(self):
620625 assert call_kwargs ["strength" ] == 0.5
621626 assert "width" not in call_kwargs
622627 assert "height" not in call_kwargs
628+
629+ def test_generate_with_scheduler_override (self ):
630+ pipeline = CivitaiCheckpointPipeline ()
631+ pipeline ._pipeline_config = PipelineConfig (
632+ pipeline_class = "StableDiffusionXLPipeline" ,
633+ default_scheduler = "dpm++_karras" ,
634+ )
635+
636+ mock_pipe = MagicMock ()
637+ mock_scheduler = MagicMock ()
638+ mock_scheduler .config = {}
639+ mock_pipe .scheduler = mock_scheduler
640+ mock_image = MagicMock ()
641+ mock_image .width = 1024
642+ mock_image .height = 1024
643+ mock_pipe .return_value .images = [mock_image ]
644+ pipeline .pipe = mock_pipe
645+
646+ mock_euler = MagicMock ()
647+ with (
648+ patch ("oneiro.pipelines.civitai_checkpoint.torch" ),
649+ patch ("diffusers.EulerAncestralDiscreteScheduler" , mock_euler ),
650+ ):
651+ pipeline .generate ("test prompt" , scheduler = "euler_a" )
652+
653+ mock_euler .from_config .assert_called_once ()
654+
655+
656+ class TestSchedulerMap :
657+ def test_scheduler_choices_matches_map_keys (self ):
658+ assert set (SCHEDULER_CHOICES ) == set (SCHEDULER_MAP .keys ())
659+
660+ def test_default_entry_has_none_class (self ):
661+ class_name , kwargs = SCHEDULER_MAP ["default" ]
662+ assert class_name is None
663+ assert kwargs == {}
664+
665+ def test_dpm_karras_entry (self ):
666+ class_name , kwargs = SCHEDULER_MAP ["dpm++_karras" ]
667+ assert class_name == "DPMSolverMultistepScheduler"
668+ assert kwargs ["algorithm_type" ] == "sde-dpmsolver++"
669+ assert kwargs ["use_karras_sigmas" ] is True
670+
671+ def test_dpm_entry (self ):
672+ class_name , kwargs = SCHEDULER_MAP ["dpm++" ]
673+ assert class_name == "DPMSolverMultistepScheduler"
674+ assert kwargs ["algorithm_type" ] == "sde-dpmsolver++"
675+ assert kwargs ["use_karras_sigmas" ] is False
676+
677+ def test_euler_a_entry (self ):
678+ class_name , kwargs = SCHEDULER_MAP ["euler_a" ]
679+ assert class_name == "EulerAncestralDiscreteScheduler"
680+ assert kwargs == {}
681+
682+ def test_euler_entry (self ):
683+ class_name , kwargs = SCHEDULER_MAP ["euler" ]
684+ assert class_name == "EulerDiscreteScheduler"
685+ assert kwargs == {}
686+
687+ def test_heun_entry (self ):
688+ class_name , kwargs = SCHEDULER_MAP ["heun" ]
689+ assert class_name == "HeunDiscreteScheduler"
690+ assert kwargs == {}
691+
692+ def test_ddim_entry (self ):
693+ class_name , kwargs = SCHEDULER_MAP ["ddim" ]
694+ assert class_name == "DDIMScheduler"
695+ assert kwargs == {}
696+
697+
698+ class TestDefaultSchedulers :
699+ def test_sd15_has_dpm_karras (self ):
700+ config = CIVITAI_BASE_MODEL_PIPELINE_MAP ["SD 1.5" ]
701+ assert config .default_scheduler == "dpm++_karras"
702+
703+ def test_sdxl_has_dpm_karras (self ):
704+ config = CIVITAI_BASE_MODEL_PIPELINE_MAP ["SDXL 1.0" ]
705+ assert config .default_scheduler == "dpm++_karras"
706+
707+ def test_pony_has_dpm_karras (self ):
708+ config = CIVITAI_BASE_MODEL_PIPELINE_MAP ["Pony" ]
709+ assert config .default_scheduler == "dpm++_karras"
710+
711+ def test_illustrious_has_dpm_karras (self ):
712+ config = CIVITAI_BASE_MODEL_PIPELINE_MAP ["Illustrious" ]
713+ assert config .default_scheduler == "dpm++_karras"
714+
715+ def test_sdxl_turbo_keeps_default (self ):
716+ config = CIVITAI_BASE_MODEL_PIPELINE_MAP ["SDXL Turbo" ]
717+ assert config .default_scheduler == "default"
718+
719+ def test_sdxl_lightning_keeps_default (self ):
720+ config = CIVITAI_BASE_MODEL_PIPELINE_MAP ["SDXL Lightning" ]
721+ assert config .default_scheduler == "default"
722+
723+ def test_flux_keeps_default (self ):
724+ config = CIVITAI_BASE_MODEL_PIPELINE_MAP ["Flux.1 Dev" ]
725+ assert config .default_scheduler == "default"
726+
727+ def test_sd3_keeps_default (self ):
728+ config = CIVITAI_BASE_MODEL_PIPELINE_MAP ["SD 3" ]
729+ assert config .default_scheduler == "default"
730+
731+
732+ class TestConfigureScheduler :
733+ def test_configure_scheduler_with_none_uses_pipeline_default (self ):
734+ pipeline = CivitaiCheckpointPipeline ()
735+ pipeline ._pipeline_config = PipelineConfig (
736+ pipeline_class = "StableDiffusionXLPipeline" ,
737+ default_scheduler = "dpm++_karras" ,
738+ )
739+ mock_pipe = MagicMock ()
740+ mock_scheduler = MagicMock ()
741+ mock_scheduler .config = {}
742+ mock_pipe .scheduler = mock_scheduler
743+ pipeline .pipe = mock_pipe
744+
745+ mock_dpm = MagicMock ()
746+ with patch ("diffusers.DPMSolverMultistepScheduler" , mock_dpm ):
747+ pipeline .configure_scheduler (None )
748+
749+ mock_dpm .from_config .assert_called_once ()
750+
751+ def test_configure_scheduler_default_string_no_change (self ):
752+ pipeline = CivitaiCheckpointPipeline ()
753+ pipeline ._pipeline_config = PipelineConfig (
754+ pipeline_class = "StableDiffusionXLPipeline" ,
755+ default_scheduler = "default" ,
756+ )
757+ mock_pipe = MagicMock ()
758+ original_scheduler = MagicMock ()
759+ mock_pipe .scheduler = original_scheduler
760+ pipeline .pipe = mock_pipe
761+
762+ pipeline .configure_scheduler ("default" )
763+
764+ assert mock_pipe .scheduler is original_scheduler
765+
766+ def test_configure_scheduler_unknown_warns (self , capsys ):
767+ pipeline = CivitaiCheckpointPipeline ()
768+ pipeline ._pipeline_config = PipelineConfig (
769+ pipeline_class = "StableDiffusionXLPipeline" ,
770+ )
771+ mock_pipe = MagicMock ()
772+ original_scheduler = MagicMock ()
773+ mock_pipe .scheduler = original_scheduler
774+ pipeline .pipe = mock_pipe
775+
776+ pipeline .configure_scheduler ("unknown_scheduler" )
777+
778+ captured = capsys .readouterr ()
779+ assert "Unknown scheduler" in captured .out
780+ assert mock_pipe .scheduler is original_scheduler
781+
782+ def test_configure_scheduler_euler_a (self ):
783+ pipeline = CivitaiCheckpointPipeline ()
784+ pipeline ._pipeline_config = PipelineConfig (
785+ pipeline_class = "StableDiffusionXLPipeline" ,
786+ )
787+ mock_pipe = MagicMock ()
788+ mock_scheduler = MagicMock ()
789+ mock_scheduler .config = {}
790+ mock_pipe .scheduler = mock_scheduler
791+ pipeline .pipe = mock_pipe
792+
793+ mock_euler = MagicMock ()
794+ with patch ("diffusers.EulerAncestralDiscreteScheduler" , mock_euler ):
795+ pipeline .configure_scheduler ("euler_a" )
796+
797+ mock_euler .from_config .assert_called_once_with ({})
798+
799+ def test_configure_scheduler_with_kwargs (self ):
800+ pipeline = CivitaiCheckpointPipeline ()
801+ pipeline ._pipeline_config = PipelineConfig (
802+ pipeline_class = "StableDiffusionXLPipeline" ,
803+ )
804+ mock_pipe = MagicMock ()
805+ mock_scheduler = MagicMock ()
806+ mock_scheduler .config = {"some" : "config" }
807+ mock_pipe .scheduler = mock_scheduler
808+ pipeline .pipe = mock_pipe
809+
810+ mock_dpm = MagicMock ()
811+ with patch ("diffusers.DPMSolverMultistepScheduler" , mock_dpm ):
812+ pipeline .configure_scheduler ("dpm++_karras" )
813+
814+ mock_dpm .from_config .assert_called_once_with (
815+ {"some" : "config" },
816+ algorithm_type = "sde-dpmsolver++" ,
817+ use_karras_sigmas = True ,
818+ )
819+
820+ def test_configure_scheduler_skips_redundant_reconfiguration (self ):
821+ pipeline = CivitaiCheckpointPipeline ()
822+ pipeline ._pipeline_config = PipelineConfig (
823+ pipeline_class = "StableDiffusionXLPipeline" ,
824+ )
825+ mock_pipe = MagicMock ()
826+ mock_scheduler = MagicMock ()
827+ mock_scheduler .config = {}
828+ mock_pipe .scheduler = mock_scheduler
829+ pipeline .pipe = mock_pipe
830+
831+ mock_euler = MagicMock ()
832+ with patch ("diffusers.EulerAncestralDiscreteScheduler" , mock_euler ):
833+ pipeline .configure_scheduler ("euler_a" )
834+ pipeline .configure_scheduler ("euler_a" )
835+
836+ mock_euler .from_config .assert_called_once ()
0 commit comments