Skip to content
Open
2 changes: 0 additions & 2 deletions src/diffusers/commands/custom_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,6 @@ def run(self):
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
# with open(CONFIG, "w") as f:
# json.dump(automap, f)
with open("requirements.txt", "w") as f:
f.write("")

def _choose_block(self, candidates, chosen=None):
for cls, base in candidates:
Expand Down
62 changes: 51 additions & 11 deletions src/diffusers/modular_pipelines/modular_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from ..utils import PushToHubMixin, is_accelerate_available, logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from ..utils.import_utils import _is_package_available
from .components_manager import ComponentsManager
from .modular_pipeline_utils import (
ComponentSpec,
Expand Down Expand Up @@ -233,6 +234,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):

config_name = "modular_config.json"
model_name = None
_requirements: Union[List[Tuple[str, str]], Tuple[str, str]] = None

@classmethod
def _get_signature_keys(cls, obj):
Expand Down Expand Up @@ -295,6 +297,19 @@ def from_pretrained(
trust_remote_code: bool = False,
**kwargs,
):
config = cls.load_config(pretrained_model_name_or_path)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not (has_remote_code and trust_remote_code):
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)

if "requirements" in config and config["requirements"] is not None:
_ = _validate_requirements(config["requirements"])

hub_kwargs_names = [
"cache_dir",
"force_download",
Expand All @@ -307,16 +322,6 @@ def from_pretrained(
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}

config = cls.load_config(pretrained_model_name_or_path)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if not has_remote_code and trust_remote_code:
raise ValueError(
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)

class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
Expand All @@ -342,8 +347,13 @@ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs):
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}

self.register_to_config(auto_map=auto_map)

# resolve requirements
requirements = _validate_requirements(getattr(self, "_requirements", None))
if requirements:
self.register_to_config(requirements=requirements)

self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
config = dict(self.config)
self._internal_dict = FrozenDict(config)
Expand Down Expand Up @@ -2529,3 +2539,33 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
return state.get(output)
else:
raise ValueError(f"Output '{output}' is not a valid output type")


def _validate_requirements(reqs):
normalized_reqs = _normalize_requirements(reqs)
if not normalized_reqs:
return []

final: List[Tuple[str, str]] = []
for req, specified_ver in normalized_reqs:
req_available, req_actual_ver = _is_package_available(req)
if not req_available:
raise ValueError(f"{req} was specified in the requirements but wasn't found in the current environment.")
if specified_ver != req_actual_ver:
logger.warning(
f"Version of {req} was specified to be {specified_ver} in the configuration. However, the actual installed version if {req_actual_ver}. Things might work unexpected."
)

final.append((req, specified_ver))

return final


def _normalize_requirements(reqs):
if not reqs:
return []
if isinstance(reqs, tuple) and len(reqs) == 2 and isinstance(reqs[0], str):
req_seq: List[Tuple[str, str]] = [reqs] # single pair
else:
req_seq = reqs
return req_seq
Loading