@@ -272,29 +272,6 @@ def _get_required_inputs(self):
272272
273273 return input_names
274274
275- def _get_requirements (self ):
276- if getattr (self , "_requirements" , None ) is not None :
277- defined_reqs = self ._requirements
278- if not isinstance (defined_reqs , list ):
279- defined_reqs = [defined_reqs ]
280-
281- final_reqs = []
282- for pkg , specified_ver in defined_reqs :
283- pkg_available , pkg_actual_ver = _is_package_available (pkg )
284- if not pkg_available :
285- raise ValueError (
286- f"{ pkg } was specified in the requirements but wasn't found. Please check your environment."
287- )
288- if specified_ver != pkg_actual_ver :
289- logger .warning (
290- f"Version for { pkg } was specified to be { specified_ver } whereas the actual version found is { pkg_actual_ver } . Ignore if this is not concerning."
291- )
292- final_reqs .append ((pkg , specified_ver ))
293- return final_reqs
294-
295- else :
296- return None
297-
298275 @property
299276 def required_inputs (self ) -> List [InputParam ]:
300277 return self ._get_required_inputs ()
@@ -329,19 +306,7 @@ def from_pretrained(
329306 )
330307
331308 if "requirements" in config and config ["requirements" ] is not None :
332- requirements : Union [List [Tuple [str , str ]], Tuple [str , str ]] = config ["requirements" ]
333- if not isinstance (requirements , list ):
334- requirements = [requirements ]
335- for pkg , fetched_ver in requirements :
336- pkg_available , pkg_actual_ver = _is_package_available (pkg )
337- if not pkg_available :
338- raise ValueError (
339- f"{ pkg } was specified in the requirements but wasn't found in the current environment."
340- )
341- if fetched_ver != pkg_actual_ver :
342- logger .warning (
343- f"Version of { pkg } was specified to be { fetched_ver } in the configuration. However, the actual installed version if { pkg_actual_ver } . Things might work unexpected."
344- )
309+ _ = _validate_requirements (config ["requirements" ])
345310
346311 hub_kwargs_names = [
347312 "cache_dir" ,
@@ -383,8 +348,8 @@ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs):
383348 self .register_to_config (auto_map = auto_map )
384349
385350 # resolve requirements
386- requirements = self . _get_requirements ( )
387- if requirements is not None :
351+ requirements = _validate_requirements ( getattr ( self , "_requirements" , None ) )
352+ if requirements :
388353 self .register_to_config (requirements = requirements )
389354
390355 self .save_config (save_directory = save_directory , push_to_hub = push_to_hub , ** kwargs )
@@ -2489,3 +2454,33 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
24892454 return state .get (output )
24902455 else :
24912456 raise ValueError (f"Output '{ output } ' is not a valid output type" )
2457+
2458+
2459+ def _validate_requirements (reqs ):
2460+ normalized_reqs = _normalize_requirements (reqs )
2461+ if not normalized_reqs :
2462+ return []
2463+
2464+ final : List [Tuple [str , str ]] = []
2465+ for req , specified_ver in normalized_reqs :
2466+ req_available , req_actual_ver = _is_package_available (req )
2467+ if not req_available :
2468+ raise ValueError (f"{ req } was specified in the requirements but wasn't found in the current environment." )
2469+ if specified_ver != req_actual_ver :
2470+ logger .warning (
2471+ f"Version of { req } was specified to be { specified_ver } in the configuration. However, the actual installed version if { req_actual_ver } . Things might work unexpected."
2472+ )
2473+
2474+ final .append ((req , specified_ver ))
2475+
2476+ return final
2477+
2478+
2479+ def _normalize_requirements (reqs ):
2480+ if not reqs :
2481+ return []
2482+ if isinstance (reqs , tuple ) and len (reqs ) == 2 and isinstance (reqs [0 ], str ):
2483+ req_seq : List [Tuple [str , str ]] = [reqs ] # single pair
2484+ else :
2485+ req_seq = reqs
2486+ return req_seq
0 commit comments