Skip to content

Commit a633289

Browse files
committed
update
1 parent 06fd427 commit a633289

File tree

3 files changed

+36
-5
lines changed

3 files changed

+36
-5
lines changed

src/diffusers/commands/custom_blocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from . import BaseDiffusersCLICommand
2828

2929

30-
EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"]
30+
EXPECTED_PARENT_CLASSES = ["PipelineBlock"]
3131
CONFIG = "config.json"
3232

3333

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -351,14 +351,17 @@ def from_pretrained(
351351
"token",
352352
]
353353
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
354+
hub_kwargs.update({"trust_remote_code": trust_remote_code})
354355

355356
config = cls.load_config(pretrained_model_name_or_path)
356357
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
357358
trust_remote_code = resolve_trust_remote_code(
358359
trust_remote_code, pretrained_model_name_or_path, has_remote_code
359360
)
360361
if not (has_remote_code and trust_remote_code):
361-
raise ValueError("TODO")
362+
raise ValueError(
363+
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
364+
)
362365

363366
class_ref = config["auto_map"][cls.__name__]
364367
module_file, class_name = class_ref.split(".")
@@ -367,7 +370,6 @@ def from_pretrained(
367370
pretrained_model_name_or_path,
368371
module_file=module_file,
369372
class_name=class_name,
370-
is_modular=True,
371373
**hub_kwargs,
372374
**kwargs,
373375
)
@@ -384,10 +386,39 @@ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs):
384386

385387
full_mod = type(self).__module__
386388
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
387-
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
389+
parent_module = self.__class__.__bases__[0].__name__
388390
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
389391

390392
self.register_to_config(auto_map=auto_map)
393+
394+
_component_specs = {spec.name: deepcopy(spec) for spec in self.expected_components}
395+
_config_specs = {spec.name: deepcopy(spec) for spec in self.expected_configs}
396+
397+
register_components_dict = {}
398+
for name, component_spec in _component_specs.items():
399+
if component_spec.type_hint is not None:
400+
lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint)
401+
else:
402+
lib_name = cls_name = None
403+
load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()}
404+
405+
# Since ModularPipelineBlocks can never have loaded components we set
406+
# first two fields in the config dict to None
407+
register_components_dict[name] = (
408+
None,
409+
None,
410+
{
411+
"type_hint": (lib_name, cls_name),
412+
**load_spec_dict,
413+
},
414+
)
415+
self.register_to_config(**register_components_dict)
416+
417+
default_configs = {}
418+
for name, config_spec in _config_specs.items():
419+
default_configs[name] = config_spec.default
420+
self.register_to_config(**default_configs)
421+
391422
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
392423
config = dict(self.config)
393424
self._internal_dict = FrozenDict(config)

src/diffusers/modular_pipelines/modular_pipeline_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ class ComponentSpec:
9393
config: Optional[FrozenDict] = None
9494
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
9595
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
96-
subfolder: Optional[str] = field(default=None, metadata={"loading": True})
96+
subfolder: Optional[str] = field(default="", metadata={"loading": True})
9797
variant: Optional[str] = field(default=None, metadata={"loading": True})
9898
revision: Optional[str] = field(default=None, metadata={"loading": True})
9999
default_creation_method: Literal["from_config", "from_pretrained"] = "from_pretrained"

0 commit comments

Comments
 (0)