1010from pathlib import Path
1111from typing import TYPE_CHECKING , Any
1212
13- import torch
14-
13+ from oneiro .device import DevicePolicy , OffloadMode
1514from oneiro .pipelines .base import BasePipeline , GenerationResult
1615from oneiro .pipelines .embedding import EmbeddingLoaderMixin , parse_embeddings_from_config
1716from oneiro .pipelines .long_prompt import (
@@ -619,25 +618,26 @@ def _load_from_path(self, checkpoint_path: Path, model_config: dict[str, Any]) -
619618 print (f" Base model: { base_model or 'unknown' } " )
620619 print (f" Pipeline: { self ._pipeline_config .pipeline_class } " )
621620
621+ cpu_offload = model_config .get ("cpu_offload" , True )
622+ self .policy = DevicePolicy .auto_detect (cpu_offload = cpu_offload )
623+
622624 # Get the pipeline class
623625 pipeline_class = get_diffusers_pipeline_class (self ._pipeline_config .pipeline_class )
624626
625627 # Load from single file
626628 self .pipe = pipeline_class .from_single_file (
627629 str (checkpoint_path ),
628- torch_dtype = self ._dtype ,
630+ torch_dtype = self .policy . dtype ,
629631 )
630632
631633 scheduler_override = model_config .get ("scheduler" )
632634 self .configure_scheduler (scheduler_override )
633635
634- # Apply optimizations
635- cpu_offload = model_config .get ("cpu_offload" , True )
636- self ._cpu_offload = cpu_offload and self ._device == "cuda"
637- if self ._cpu_offload :
638- self .pipe .enable_model_cpu_offload ()
639- elif self ._device == "cuda" :
640- self .pipe .to ("cuda" )
636+ self .policy .apply_to_pipeline (self .pipe )
637+ # Track whether offload was applied (for dynamic LoRA handling)
638+ self ._cpu_offload = (
639+ self .policy .offload != OffloadMode .NEVER and self .policy .device == "cuda"
640+ )
641641
642642 # Enable memory optimizations for VAE if available
643643 if hasattr (self .pipe , "vae" ):
@@ -859,8 +859,7 @@ def _run_generation(
859859
860860 result = self .pipe (** gen_kwargs )
861861
862- if torch .cuda .is_available ():
863- torch .cuda .empty_cache ()
862+ DevicePolicy .clear_cache ()
864863
865864 output_image = result .images [0 ]
866865 return GenerationResult (
@@ -887,7 +886,7 @@ def _load_dynamic_loras(self, loras: list[LoraConfig]) -> None:
887886 # Only move pipeline to device manually when CPU offload is not enabled.
888887 # With CPU offload, diffusers manages device placement automatically.
889888 if not self ._cpu_offload :
890- self .pipe .to (self ._device )
889+ self .pipe .to (self .policy . device )
891890
892891 loaded_names : list [str ] = []
893892 loaded_weights : list [float ] = []
@@ -911,7 +910,7 @@ def _restore_static_loras(self) -> None:
911910 return
912911
913912 if not self ._cpu_offload :
914- self .pipe .to (self ._device )
913+ self .pipe .to (self .policy . device )
915914 self .load_loras_sync (self ._static_lora_configs )
916915 print (f"Restored { len (self ._static_lora_configs )} static LoRA(s)" )
917916
0 commit comments