4848 format_inputs_short ,
4949 format_intermediates_short ,
5050 make_doc_string ,
51+ InsertableOrderedDict
5152)
5253from .components_manager import ComponentsManager
5354from ..utils .dynamic_modules_utils import get_class_from_dynamic_module , resolve_trust_remote_code
6667)
6768
6869
70+
6971@dataclass
7072class PipelineState :
7173 """
@@ -622,7 +624,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
622624 block_trigger_inputs = []
623625
624626 def __init__ (self ):
625- blocks = OrderedDict ()
627+ blocks = InsertableOrderedDict ()
626628 for block_name , block_cls in zip (self .block_names , self .block_classes ):
627629 blocks [block_name ] = block_cls ()
628630 self .blocks = blocks
@@ -958,7 +960,7 @@ def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "SequentialPipelineBlo
958960 return instance
959961
960962 def __init__ (self ):
961- blocks = OrderedDict ()
963+ blocks = InsertableOrderedDict ()
962964 for block_name , block_cls in zip (self .block_names , self .block_classes ):
963965 blocks [block_name ] = block_cls ()
964966 self .blocks = blocks
@@ -1449,7 +1451,7 @@ def outputs(self) -> List[str]:
14491451
14501452
14511453 def __init__ (self ):
1452- blocks = OrderedDict ()
1454+ blocks = InsertableOrderedDict ()
14531455 for block_name , block_cls in zip (self .block_names , self .block_classes ):
14541456 blocks [block_name ] = block_cls ()
14551457 self .blocks = blocks
@@ -1662,6 +1664,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin):
16621664
16631665 """
16641666 config_name = "modular_model_index.json"
1667+ hf_device_map = None
16651668
16661669
16671670 def register_components (self , ** kwargs ):
@@ -2013,7 +2016,26 @@ def load(self, component_names: Optional[List[str]] = None, **kwargs):
20132016 # Register all components at once
20142017 self .register_components (** components_to_register )
20152018
2016- # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
2019+ # Copied from diffusers.pipelines.pipeline_utils.DiffusionPipeline._maybe_raise_error_if_group_offload_active
2020+ def _maybe_raise_error_if_group_offload_active (
2021+ self , raise_error : bool = False , module : Optional [torch .nn .Module ] = None
2022+ ) -> bool :
2023+ from ..hooks .group_offloading import _is_group_offload_enabled
2024+
2025+ components = self .components .values () if module is None else [module ]
2026+ components = [component for component in components if isinstance (component , torch .nn .Module )]
2027+ for component in components :
2028+ if _is_group_offload_enabled (component ):
2029+ if raise_error :
2030+ raise ValueError (
2031+ "You are trying to apply model/sequential CPU offloading to a pipeline that contains components "
2032+ "with group offloading enabled. This is not supported. Please disable group offloading for "
2033+ "components of the pipeline to use other offloading methods."
2034+ )
2035+ return True
2036+ return False
2037+
2038+ # Modified from diffusers.pipelines.pipeline_utils.DiffusionPipeline.to
20172039 def to (self , * args , ** kwargs ) -> Self :
20182040 r"""
20192041 Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
@@ -2050,6 +2072,10 @@ def to(self, *args, **kwargs) -> Self:
20502072 Returns:
20512073 [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`.
20522074 """
2075+ from ..pipelines .pipeline_utils import _check_bnb_status , DiffusionPipeline
2076+ from ..utils import is_accelerate_available , is_accelerate_version , is_hpu_available , is_transformers_version
2077+
2078+
20532079 dtype = kwargs .pop ("dtype" , None )
20542080 device = kwargs .pop ("device" , None )
20552081 silence_dtype_warnings = kwargs .pop ("silence_dtype_warnings" , False )
@@ -2152,8 +2178,7 @@ def module_is_offloaded(module):
21522178 os .environ ["PT_HPU_MAX_COMPOUND_OP_SIZE" ] = "1"
21532179 logger .debug ("Environment variable set: PT_HPU_MAX_COMPOUND_OP_SIZE=1" )
21542180
2155- module_names , _ = self ._get_signature_keys (self )
2156- modules = [getattr (self , n , None ) for n in module_names ]
2181+ modules = self .components .values ()
21572182 modules = [m for m in modules if isinstance (m , torch .nn .Module )]
21582183
21592184 is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
@@ -2431,4 +2456,12 @@ def save_pretrained(self, save_directory: Optional[Union[str, os.PathLike]] = No
24312456
24322457 @property
24332458 def doc (self ):
2434- return self .blocks .doc
2459+ return self .blocks .doc
2460+
2461+ def to (self , * args , ** kwargs ):
2462+ self .loader .to (* args , ** kwargs )
2463+ return self
2464+
2465+ @property
2466+ def components (self ):
2467+ return self .loader .components
0 commit comments