Skip to content

Commit 1de4402

Browse files
committed
up
1 parent 024c2b9 commit 1de4402

File tree

2 files changed

+95
-32
lines changed

2 files changed

+95
-32
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,14 @@
3232
from ..utils import PushToHubMixin, is_accelerate_available, logging
3333
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
3434
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
35-
from ..utils.import_utils import _is_package_available
3635
from .components_manager import ComponentsManager
3736
from .modular_pipeline_utils import (
3837
ComponentSpec,
3938
ConfigSpec,
4039
InputParam,
4140
InsertableDict,
4241
OutputParam,
42+
_validate_requirements,
4343
format_components,
4444
format_configs,
4545
make_doc_string,
@@ -240,7 +240,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
240240

241241
config_name = "modular_config.json"
242242
model_name = None
243-
_requirements: Union[List[Tuple[str, str]], Tuple[str, str]] = None
243+
_requirements: Optional[Dict[str, str]] = None
244244

245245
@classmethod
246246
def _get_signature_keys(cls, obj):
@@ -1143,6 +1143,14 @@ def doc(self):
11431143
expected_configs=self.expected_configs,
11441144
)
11451145

1146+
@property
1147+
def _requirements(self) -> Dict[str, str]:
1148+
requirements = {}
1149+
for block_name, block in self.sub_blocks.items():
1150+
if getattr(block, "_requirements", None):
1151+
requirements[block_name] = block._requirements
1152+
return requirements
1153+
11461154

11471155
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
11481156
"""
@@ -2547,33 +2555,3 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
25472555
return state.get(output)
25482556
else:
25492557
raise ValueError(f"Output '{output}' is not a valid output type")
2550-
2551-
2552-
def _validate_requirements(reqs):
2553-
normalized_reqs = _normalize_requirements(reqs)
2554-
if not normalized_reqs:
2555-
return []
2556-
2557-
final: List[Tuple[str, str]] = []
2558-
for req, specified_ver in normalized_reqs:
2559-
req_available, req_actual_ver = _is_package_available(req)
2560-
if not req_available:
2561-
raise ValueError(f"{req} was specified in the requirements but wasn't found in the current environment.")
2562-
if specified_ver != req_actual_ver:
2563-
logger.warning(
2564-
f"Version of {req} was specified to be {specified_ver} in the configuration. However, the actual installed version if {req_actual_ver}. Things might work unexpected."
2565-
)
2566-
2567-
final.append((req, specified_ver))
2568-
2569-
return final
2570-
2571-
2572-
def _normalize_requirements(reqs):
2573-
if not reqs:
2574-
return []
2575-
if isinstance(reqs, tuple) and len(reqs) == 2 and isinstance(reqs[0], str):
2576-
req_seq: List[Tuple[str, str]] = [reqs] # single pair
2577-
else:
2578-
req_seq = reqs
2579-
return req_seq

src/diffusers/modular_pipelines/modular_pipeline_utils.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@
1919
from typing import Any, Dict, List, Literal, Optional, Type, Union
2020

2121
import torch
22+
from packaging.specifiers import InvalidSpecifier, SpecifierSet
2223

2324
from ..configuration_utils import ConfigMixin, FrozenDict
2425
from ..utils import is_torch_available, logging
26+
from ..utils.import_utils import _is_package_available
2527

2628

2729
if is_torch_available():
@@ -670,3 +672,86 @@ def make_doc_string(
670672
output += format_output_params(outputs, indent_level=2)
671673

672674
return output
675+
676+
677+
def _validate_requirements(reqs):
678+
if reqs is None:
679+
normalized_reqs = {}
680+
else:
681+
if not isinstance(reqs, dict):
682+
raise ValueError(
683+
"Requirements must be provided as a dictionary mapping package names to version specifiers."
684+
)
685+
normalized_reqs = _normalize_requirements(reqs)
686+
687+
if not normalized_reqs:
688+
return {}
689+
690+
final: Dict[str, str] = {}
691+
for req, specified_ver in normalized_reqs.items():
692+
req_available, req_actual_ver = _is_package_available(req)
693+
if not req_available:
694+
logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
695+
696+
if specified_ver:
697+
try:
698+
specifier = SpecifierSet(specified_ver)
699+
except InvalidSpecifier as err:
700+
raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
701+
702+
if req_actual_ver == "N/A":
703+
logger.warning(
704+
f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected."
705+
)
706+
elif not specifier.contains(req_actual_ver, prereleases=True):
707+
logger.warning(
708+
f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected."
709+
)
710+
711+
final[req] = specified_ver
712+
713+
return final
714+
715+
716+
def _normalize_requirements(reqs):
717+
if not reqs:
718+
return {}
719+
720+
normalized: "OrderedDict[str, str]" = OrderedDict()
721+
722+
def _accumulate(mapping: Dict[str, Any]):
723+
for pkg, spec in mapping.items():
724+
if isinstance(spec, dict):
725+
# This is recursive because blocks are composable. This way, we can merge requirements
726+
# from multiple blocks.
727+
_accumulate(spec)
728+
continue
729+
730+
pkg_name = str(pkg).strip()
731+
if not pkg_name:
732+
raise ValueError("Requirement package name cannot be empty.")
733+
734+
spec_str = "" if spec is None else str(spec).strip()
735+
if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
736+
spec_str = f"=={spec_str}"
737+
738+
existing_spec = normalized.get(pkg_name)
739+
if existing_spec is not None:
740+
if not existing_spec and spec_str:
741+
normalized[pkg_name] = spec_str
742+
elif existing_spec and spec_str and existing_spec != spec_str:
743+
try:
744+
combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
745+
except InvalidSpecifier:
746+
logger.warning(
747+
f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
748+
)
749+
else:
750+
normalized[pkg_name] = str(combined_spec)
751+
continue
752+
753+
normalized[pkg_name] = spec_str
754+
755+
_accumulate(reqs)
756+
757+
return normalized

0 commit comments

Comments
 (0)