1010from pathlib import Path
1111from typing import TYPE_CHECKING , Any
1212
13+ import torch
14+ from PIL import Image
15+
1316from oneiro .device import DevicePolicy , OffloadMode
1417from oneiro .pipelines .base import BasePipeline , GenerationResult
1518from oneiro .pipelines .embedding import EmbeddingLoaderMixin , parse_embeddings_from_config
@@ -519,6 +522,7 @@ def __init__(self) -> None:
519522 self ._current_scheduler : str | None = None
520523 self ._static_lora_configs : list [LoraConfig ] = []
521524 self ._cpu_offload : bool = False
525+ self ._has_dynamic_loras : bool = False
522526
523527 def load (self , model_config : dict [str , Any ], full_config : dict [str , Any ] | None = None ) -> None :
524528 """Load checkpoint from config (synchronous, requires checkpoint_path).
@@ -692,6 +696,12 @@ def configure_scheduler(self, scheduler_name: str | None) -> None:
692696 self ._current_scheduler = scheduler_name
693697 print (f" Scheduler: { scheduler_name } " )
694698
def validate_pipeline(self) -> None:
    """Ensure the pipeline and its CivitAI config are ready for generation.

    Raises:
        RuntimeError: If the base-class checks fail, or if the
            pipeline config has not been initialized by ``load()``.
    """
    # Run the base-class checks first (pipeline object itself).
    super().validate_pipeline()
    # This subclass additionally requires a parsed pipeline config.
    if self._pipeline_config is None:
        raise RuntimeError("Pipeline config not initialized")
695705 def generate (
696706 self ,
697707 prompt : str ,
@@ -762,84 +772,70 @@ def generate(
762772 as the actual seed used, prompts, final image size, number of
763773 steps, and guidance scale.
764774 """
765- if self .pipe is None :
766- raise RuntimeError ("Pipeline not loaded" )
775+ # Apply defaults from pipeline config (validation happens in super().generate())
776+ # Note: We need to check _pipeline_config here before applying defaults,
777+ # but full validation happens in validate_pipeline() called by super()
778+ if self ._pipeline_config is not None :
779+ width = width if width is not None else self ._pipeline_config .default_width
780+ height = height if height is not None else self ._pipeline_config .default_height
781+ steps = steps if steps is not None else self ._pipeline_config .default_steps
782+ guidance_scale = (
783+ guidance_scale
784+ if guidance_scale is not None
785+ else self ._pipeline_config .default_guidance_scale
786+ )
767787
768- if self ._pipeline_config is None :
769- raise RuntimeError ("Pipeline config not initialized" )
788+ return super ().generate (
789+ prompt = prompt ,
790+ negative_prompt = negative_prompt ,
791+ width = width or 1024 ,
792+ height = height or 1024 ,
793+ seed = seed ,
794+ steps = steps or 20 ,
795+ guidance_scale = guidance_scale if guidance_scale is not None else 7.0 ,
796+ ** kwargs ,
797+ )
770798
def pre_generate(self, **kwargs: Any) -> None:
    """Pre-generation setup: per-call scheduler override and dynamic LoRAs.

    Pops ``scheduler`` and ``loras`` from *kwargs* when present. A scheduler
    name triggers ``configure_scheduler``; a LoRA list is loaded via
    ``_load_dynamic_loras``. If LoRA loading fails, the static LoRA set is
    restored and the error is re-raised.
    """
    override = kwargs.pop("scheduler", None)
    if override:
        self.configure_scheduler(override)

    requested = kwargs.pop("loras", None)
    self._has_dynamic_loras = False
    if not requested:
        return

    # Flag the dynamic-LoRA context *before* loading so that a failure
    # partway through (e.g. after static LoRAs were unloaded) can be
    # rolled back cleanly.
    self._has_dynamic_loras = True
    try:
        self._load_dynamic_loras(requested)
    except Exception:
        # Best effort: reinstall the static LoRAs, then propagate.
        self._restore_static_loras()
        self._has_dynamic_loras = False
        raise
792- try :
793- return self ._run_generation (
794- prompt = prompt ,
795- negative_prompt = negative_prompt ,
796- width = width ,
797- height = height ,
798- seed = seed ,
799- steps = steps ,
800- guidance_scale = guidance_scale ,
801- ** kwargs ,
802- )
803- finally :
804- if has_dynamic_loras :
805- self ._restore_static_loras ()
806-
807- def _run_generation (
816+ def build_generation_kwargs (
808817 self ,
809818 prompt : str ,
810819 negative_prompt : str | None ,
811- width : int | None ,
812- height : int | None ,
813- seed : int ,
814- steps : int | None ,
815- guidance_scale : float | None ,
820+ width : int ,
821+ height : int ,
822+ steps : int ,
823+ guidance_scale : float ,
824+ generator : "torch.Generator" ,
825+ init_image : "Image.Image | None" ,
826+ strength : float ,
816827 ** kwargs : Any ,
817- ) -> GenerationResult :
828+ ) -> dict [str , Any ]:
829+ """Build generation kwargs with embedding support."""
818830 assert self ._pipeline_config is not None
819- width = width or self ._pipeline_config .default_width
820- height = height or self ._pipeline_config .default_height
821- steps = steps or self ._pipeline_config .default_steps
822- guidance_scale = (
823- guidance_scale
824- if guidance_scale is not None
825- else self ._pipeline_config .default_guidance_scale
826- )
827-
828- actual_seed , generator = self ._prepare_seed (seed )
829831
830- # Handle img2img
831- init_image = self ._load_init_image (kwargs .get ("init_image" ))
832- strength = kwargs .get ("strength" , 0.75 )
833-
834- # Build generation kwargs
835832 gen_kwargs : dict [str , Any ] = {
836833 "num_inference_steps" : steps ,
837834 "guidance_scale" : guidance_scale ,
838835 "generator" : generator ,
839836 }
840837
841838 # Use embedding-based prompt handling for pipelines that support it
842- # (SD 1.x, SD 2.x, SDXL, Flux, SD3) - enables weight syntax like (word:1.5)
843839 if self ._supports_prompt_embeddings ():
844840 gen_kwargs .update (self ._encode_prompts_to_embeddings (prompt , negative_prompt ))
845841 else :
@@ -849,29 +845,21 @@ def _run_generation(
849845 gen_kwargs ["negative_prompt" ] = negative_prompt
850846
851847 if init_image :
852- print (f"CivitAI img2img: '{ prompt [:50 ]} ...' seed= { actual_seed } strength={ strength } " )
848+ print (f"CivitAI img2img: '{ prompt [:50 ]} ...' strength={ strength } " )
853849 gen_kwargs ["image" ] = init_image
854850 gen_kwargs ["strength" ] = strength
855851 else :
856- print (f"CivitAI generating: '{ prompt [:50 ]} ...' seed= { actual_seed } " )
852+ print (f"CivitAI generating: '{ prompt [:50 ]} ...'" )
857853 gen_kwargs ["height" ] = height
858854 gen_kwargs ["width" ] = width
859855
860- result = self .pipe (** gen_kwargs )
861-
862- DevicePolicy .clear_cache ()
856+ return gen_kwargs
863857
864- output_image = result .images [0 ]
865- return GenerationResult (
866- image = output_image ,
867- seed = actual_seed ,
868- prompt = prompt ,
869- negative_prompt = negative_prompt ,
870- width = output_image .width ,
871- height = output_image .height ,
872- steps = steps ,
873- guidance_scale = guidance_scale ,
874- )
def post_generate(self, **kwargs: Any) -> None:
    """Post-generation cleanup: undo any per-call (dynamic) LoRA state.

    No-op unless ``pre_generate`` loaded dynamic LoRAs; in that case the
    static LoRA set is reinstalled and the flag is cleared.
    """
    if not self._has_dynamic_loras:
        return
    self._restore_static_loras()
    self._has_dynamic_loras = False
876864 def _load_dynamic_loras (self , loras : list [LoraConfig ]) -> None :
877865 if self .pipe is None or not loras :
0 commit comments