Skip to content

Commit ffbaa89

Browse files
committed
move save_pretrained to the correct place
1 parent e49413d commit ffbaa89

File tree

1 file changed

+16
-19
lines changed

1 file changed

+16
-19
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ def format_value(v):
248248

249249
class ModularPipelineBlocks(ConfigMixin):
250250
"""
251-
Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
251+
Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks, LoopSequentialPipelineBlocks
252252
"""
253253

254254
config_name = "config.json"
@@ -307,6 +307,20 @@ def from_pretrained(
307307
}
308308

309309
return block_cls(**block_kwargs)
310+
311+
def save_pretrained(self, save_directory, push_to_hub = False, **kwargs):
312+
# TODO: factor out this logic.
313+
cls_name = self.__class__.__name__
314+
315+
full_mod = type(self).__module__
316+
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
317+
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
318+
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
319+
320+
self.register_to_config(auto_map=auto_map)
321+
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
322+
config = dict(self.config)
323+
self._internal_dict = FrozenDict(config)
310324

311325
def init_pipeline(self, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
312326
"""
@@ -532,21 +546,6 @@ def add_block_state(self, state: PipelineState, block_state: BlockState):
532546
if current_value is not param: # Using identity comparison to check if object was modified
533547
state.add_intermediate(param_name, param, input_param.kwargs_type)
534548

535-
def save_pretrained(self, save_directory, push_to_hub = False, **kwargs):
536-
# TODO: factor out this logic.
537-
cls_name = self.__class__.__name__
538-
539-
full_mod = type(self).__module__
540-
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
541-
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
542-
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
543-
_component_names = [c.name for c in self.expected_components]
544-
545-
self.register_to_config(auto_map=auto_map, _component_names=_component_names)
546-
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
547-
config = dict(self.config)
548-
self._internal_dict = FrozenDict(config)
549-
550549

551550
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
552551
"""
@@ -2366,9 +2365,7 @@ def __init__(self, blocks: ModularPipelineBlocks, loader: ModularLoader):
23662365
self.loader = loader
23672366

23682367
def __repr__(self):
2369-
blocks_class = self.blocks.__class__.__name__
2370-
loader_class = self.loader.__class__.__name__
2371-
return f"ModularPipeline(blocks={blocks_class}, loader={loader_class})"
2368+
return f"ModularPipeline(\n blocks={repr(self.blocks)},\n loader={repr(self.loader)}\n)"
23722369

23732370
@property
23742371
def default_call_parameters(self) -> Dict[str, Any]:

0 commit comments

Comments
 (0)