Skip to content

Commit 12ceecf

Browse files
committed
feat: implement requirements validation for custom blocks.
1 parent 7a2b78b commit 12ceecf

File tree

2 files changed

+55
-13
lines changed

2 files changed

+55
-13
lines changed

src/diffusers/commands/custom_blocks.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,6 @@ def run(self):
8989
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
9090
# with open(CONFIG, "w") as f:
9191
# json.dump(automap, f)
92-
with open("requirements.txt", "w") as f:
93-
f.write("")
9492

9593
def _choose_block(self, candidates, chosen=None):
9694
for cls, base in candidates:

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from ..utils import PushToHubMixin, is_accelerate_available, logging
3333
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
3434
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
35+
from ..utils.import_utils import _is_package_available
3536
from .components_manager import ComponentsManager
3637
from .modular_pipeline_utils import (
3738
ComponentSpec,
@@ -231,6 +232,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
231232

232233
config_name = "modular_config.json"
233234
model_name = None
235+
_requirements: Union[List[Tuple[str, str], Tuple[str, str]]] = None
234236

235237
@classmethod
236238
def _get_signature_keys(cls, obj):
@@ -270,6 +272,28 @@ def _get_required_inputs(self):
270272

271273
return input_names
272274

275+
def _get_requirements(self):
276+
if getattr(self, "_requirements", None) is not None:
277+
defined_reqs = self._requirements
278+
if not isinstance(defined_reqs):
279+
defined_reqs = [defined_reqs]
280+
281+
final_reqs = []
282+
for pkg, specified_ver in defined_reqs:
283+
pkg_available, pkg_actual_ver = _is_package_available(pkg)
284+
if not pkg_available:
285+
raise ValueError(
286+
f"{pkg} was specified in the requirements but wasn't found. Please check your environment."
287+
)
288+
if specified_ver != pkg_actual_ver:
289+
logger.warning(
290+
f"Version for {pkg} was specified to be {specified_ver} whereas the actual version found is {pkg_actual_ver}. Ignore if this is not concerning."
291+
)
292+
final_reqs.append((pkg, specified_ver))
293+
294+
else:
295+
return None
296+
273297
@property
274298
def required_inputs(self) -> List[InputParam]:
275299
return self._get_required_inputs()
@@ -293,6 +317,31 @@ def from_pretrained(
293317
trust_remote_code: Optional[bool] = None,
294318
**kwargs,
295319
):
320+
config = cls.load_config(pretrained_model_name_or_path)
321+
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
322+
trust_remote_code = resolve_trust_remote_code(
323+
trust_remote_code, pretrained_model_name_or_path, has_remote_code
324+
)
325+
if not (has_remote_code and trust_remote_code):
326+
raise ValueError(
327+
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
328+
)
329+
330+
if "requirements" in config and config["requirements"] is not None:
331+
requirements: Union[List[Tuple[str, str]], Tuple[str, str]] = config["requirements"]
332+
if not isinstance(requirements, list):
333+
requirements = [requirements]
334+
for pkg, fetched_ver in requirements:
335+
pkg_available, pkg_actual_ver = _is_package_available(pkg)
336+
if not pkg_available:
337+
raise ValueError(
338+
f"{pkg} was specified in the requirements but wasn't found in the current environment."
339+
)
340+
if fetched_ver != pkg_actual_ver:
341+
logger.warning(
342+
f"Version of {pkg} was specified to be {fetched_ver} in the configuration. However, the actual installed version if {pkg_actual_ver}. Things might work unexpected."
343+
)
344+
296345
hub_kwargs_names = [
297346
"cache_dir",
298347
"force_download",
@@ -305,16 +354,6 @@ def from_pretrained(
305354
]
306355
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
307356

308-
config = cls.load_config(pretrained_model_name_or_path)
309-
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
310-
trust_remote_code = resolve_trust_remote_code(
311-
trust_remote_code, pretrained_model_name_or_path, has_remote_code
312-
)
313-
if not (has_remote_code and trust_remote_code):
314-
raise ValueError(
315-
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
316-
)
317-
318357
class_ref = config["auto_map"][cls.__name__]
319358
module_file, class_name = class_ref.split(".")
320359
module_file = module_file + ".py"
@@ -340,8 +379,13 @@ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs):
340379
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
341380
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
342381
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
343-
344382
self.register_to_config(auto_map=auto_map)
383+
384+
# resolve requirements
385+
requirements = self._get_requirements()
386+
if requirements is not None:
387+
self.register_to_config(requirements=requirements)
388+
345389
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
346390
config = dict(self.config)
347391
self._internal_dict = FrozenDict(config)

0 commit comments

Comments
 (0)