@@ -272,29 +272,6 @@ def _get_required_inputs(self):
272272
273273        return  input_names 
274274
275-     def  _get_requirements (self ):
276-         if  getattr (self , "_requirements" , None ) is  not None :
277-             defined_reqs  =  self ._requirements 
278-             if  not  isinstance (defined_reqs , list ):
279-                 defined_reqs  =  [defined_reqs ]
280- 
281-             final_reqs  =  []
282-             for  pkg , specified_ver  in  defined_reqs :
283-                 pkg_available , pkg_actual_ver  =  _is_package_available (pkg )
284-                 if  not  pkg_available :
285-                     raise  ValueError (
286-                         f"{ pkg }  
287-                     )
288-                 if  specified_ver  !=  pkg_actual_ver :
289-                     logger .warning (
290-                         f"Version for { pkg } { specified_ver } { pkg_actual_ver }  
291-                     )
292-                 final_reqs .append ((pkg , specified_ver ))
293-             return  final_reqs 
294- 
295-         else :
296-             return  None 
297- 
298275    @property  
299276    def  required_inputs (self ) ->  List [InputParam ]:
300277        return  self ._get_required_inputs ()
@@ -329,19 +306,7 @@ def from_pretrained(
329306            )
330307
331308        if  "requirements"  in  config  and  config ["requirements" ] is  not None :
332-             requirements : Union [List [Tuple [str , str ]], Tuple [str , str ]] =  config ["requirements" ]
333-             if  not  isinstance (requirements , list ):
334-                 requirements  =  [requirements ]
335-             for  pkg , fetched_ver  in  requirements :
336-                 pkg_available , pkg_actual_ver  =  _is_package_available (pkg )
337-                 if  not  pkg_available :
338-                     raise  ValueError (
339-                         f"{ pkg }  
340-                     )
341-                 if  fetched_ver  !=  pkg_actual_ver :
342-                     logger .warning (
343-                         f"Version of { pkg } { fetched_ver } { pkg_actual_ver }  
344-                     )
309+             _  =  _validate_requirements (config ["requirements" ])
345310
346311        hub_kwargs_names  =  [
347312            "cache_dir" ,
@@ -383,8 +348,8 @@ def save_pretrained(self, save_directory, push_to_hub=False, **kwargs):
383348        self .register_to_config (auto_map = auto_map )
384349
385350        # resolve requirements 
386-         requirements  =  self . _get_requirements ( )
387-         if  requirements   is   not   None :
351+         requirements  =  _validate_requirements ( getattr ( self ,  "_requirements" ,  None ) )
352+         if  requirements :
388353            self .register_to_config (requirements = requirements )
389354
390355        self .save_config (save_directory = save_directory , push_to_hub = push_to_hub , ** kwargs )
@@ -2489,3 +2454,33 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
24892454            return  state .get (output )
24902455        else :
24912456            raise  ValueError (f"Output '{ output }  )
2457+ 
2458+ 
2459+ def  _validate_requirements (reqs ):
2460+     normalized_reqs  =  _normalize_requirements (reqs )
2461+     if  not  normalized_reqs :
2462+         return  []
2463+ 
2464+     final : List [Tuple [str , str ]] =  []
2465+     for  req , specified_ver  in  normalized_reqs :
2466+         req_available , req_actual_ver  =  _is_package_available (req )
2467+         if  not  req_available :
2468+             raise  ValueError (f"{ req }  )
2469+         if  specified_ver  !=  req_actual_ver :
2470+             logger .warning (
2471+                 f"Version of { req } { specified_ver } { req_actual_ver }  
2472+             )
2473+ 
2474+         final .append ((req , specified_ver ))
2475+ 
2476+     return  final 
2477+ 
2478+ 
2479+ def  _normalize_requirements (reqs ):
2480+     if  not  reqs :
2481+         return  []
2482+     if  isinstance (reqs , tuple ) and  len (reqs ) ==  2  and  isinstance (reqs [0 ], str ):
2483+         req_seq : List [Tuple [str , str ]] =  [reqs ]  # single pair 
2484+     else :
2485+         req_seq  =  reqs 
2486+     return  req_seq 
0 commit comments