3232from ..utils import PushToHubMixin , is_accelerate_available , logging
3333from ..utils .dynamic_modules_utils import get_class_from_dynamic_module , resolve_trust_remote_code
3434from ..utils .hub_utils import load_or_create_model_card , populate_model_card
35+ from ..utils .import_utils import _is_package_available
3536from .components_manager import ComponentsManager
3637from .modular_pipeline_utils import (
3738 ComponentSpec ,
@@ -231,6 +232,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
231232
232233 config_name = "modular_config.json"
233234 model_name = None
235+ _requirements : Union [List [Tuple [str , str ], Tuple [str , str ]]] = None
234236
235237 @classmethod
236238 def _get_signature_keys (cls , obj ):
@@ -270,6 +272,28 @@ def _get_required_inputs(self):
270272
271273 return input_names
272274
275+ def _get_requirements (self ):
276+ if getattr (self , "_requirements" , None ) is not None :
277+ defined_reqs = self ._requirements
278+ if not isinstance (defined_reqs ):
279+ defined_reqs = [defined_reqs ]
280+
281+ final_reqs = []
282+ for pkg , specified_ver in defined_reqs :
283+ pkg_available , pkg_actual_ver = _is_package_available (pkg )
284+ if not pkg_available :
285+ raise ValueError (
286+ f"{ pkg } was specified in the requirements but wasn't found. Please check your environment."
287+ )
288+ if specified_ver != pkg_actual_ver :
289+ logger .warning (
290+ f"Version for { pkg } was specified to be { specified_ver } whereas the actual version found is { pkg_actual_ver } . Ignore if this is not concerning."
291+ )
292+ final_reqs .append ((pkg , specified_ver ))
293+
294+ else :
295+ return None
296+
273297 @property
274298 def required_inputs (self ) -> List [InputParam ]:
275299 return self ._get_required_inputs ()
@@ -293,6 +317,31 @@ def from_pretrained(
293317 trust_remote_code : Optional [bool ] = None ,
294318 ** kwargs ,
295319 ):
320+ config = cls .load_config (pretrained_model_name_or_path )
321+ has_remote_code = "auto_map" in config and cls .__name__ in config ["auto_map" ]
322+ trust_remote_code = resolve_trust_remote_code (
323+ trust_remote_code , pretrained_model_name_or_path , has_remote_code
324+ )
325+ if not (has_remote_code and trust_remote_code ):
326+ raise ValueError (
327+ "Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
328+ )
329+
330+ if "requirements" in config and config ["requirements" ] is not None :
331+ requirements : Union [List [Tuple [str , str ]], Tuple [str , str ]] = config ["requirements" ]
332+ if not isinstance (requirements , list ):
333+ requirements = [requirements ]
334+ for pkg , fetched_ver in requirements :
335+ pkg_available , pkg_actual_ver = _is_package_available (pkg )
336+ if not pkg_available :
337+ raise ValueError (
338+ f"{ pkg } was specified in the requirements but wasn't found in the current environment."
339+ )
340+ if fetched_ver != pkg_actual_ver :
341+ logger .warning (
342+ f"Version of { pkg } was specified to be { fetched_ver } in the configuration. However, the actual installed version if { pkg_actual_ver } . Things might work unexpected."
343+ )
344+
296345 hub_kwargs_names = [
297346 "cache_dir" ,
298347 "force_download" ,
@@ -305,16 +354,6 @@ def from_pretrained(
305354 ]
306355 hub_kwargs = {name : kwargs .pop (name ) for name in hub_kwargs_names if name in kwargs }
307356
308- config = cls .load_config (pretrained_model_name_or_path )
309- has_remote_code = "auto_map" in config and cls .__name__ in config ["auto_map" ]
310- trust_remote_code = resolve_trust_remote_code (
311- trust_remote_code , pretrained_model_name_or_path , has_remote_code
312- )
313- if not (has_remote_code and trust_remote_code ):
314- raise ValueError (
315- "Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
316- )
317-
318357 class_ref = config ["auto_map" ][cls .__name__ ]
319358 module_file , class_name = class_ref .split ("." )
320359 module_file = module_file + ".py"
@@ -340,8 +379,13 @@ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs):
340379 module = full_mod .rsplit ("." , 1 )[- 1 ].replace ("__dynamic__" , "" )
341380 parent_module = self .save_pretrained .__func__ .__qualname__ .split ("." , 1 )[0 ]
342381 auto_map = {f"{ parent_module } " : f"{ module } .{ cls_name } " }
343-
344382 self .register_to_config (auto_map = auto_map )
383+
384+ # resolve requirements
385+ requirements = self ._get_requirements ()
386+ if requirements is not None :
387+ self .register_to_config (requirements = requirements )
388+
345389 self .save_config (save_directory = save_directory , push_to_hub = push_to_hub , ** kwargs )
346390 config = dict (self .config )
347391 self ._internal_dict = FrozenDict (config )
0 commit comments