3737else :
3838    import  importlib .metadata  as  importlib_metadata 
3939try :
40-     _package_map  =  importlib_metadata .packages_distributions ()  # load-once to avoid expensive calls 
40+     _package_map  =  (
41+         importlib_metadata .packages_distributions ()
42+     )  # load-once to avoid expensive calls 
4143except  Exception :
4244    _package_map  =  None 
4345
5355DIFFUSERS_SLOW_IMPORT  =  os .environ .get ("DIFFUSERS_SLOW_IMPORT" , "FALSE" ).upper ()
5456DIFFUSERS_SLOW_IMPORT  =  DIFFUSERS_SLOW_IMPORT  in  ENV_VARS_TRUE_VALUES 
5557
56- STR_OPERATION_TO_FUNC  =  {">" : op .gt , ">=" : op .ge , "==" : op .eq , "!=" : op .ne , "<=" : op .le , "<" : op .lt }
57- 
58- _is_google_colab  =  "google.colab"  in  sys .modules  or  any (k .startswith ("COLAB_" ) for  k  in  os .environ )
58+ STR_OPERATION_TO_FUNC  =  {
59+     ">" : op .gt ,
60+     ">=" : op .ge ,
61+     "==" : op .eq ,
62+     "!=" : op .ne ,
63+     "<=" : op .le ,
64+     "<" : op .lt ,
65+ }
66+ 
67+ _is_google_colab  =  "google.colab"  in  sys .modules  or  any (
68+     k .startswith ("COLAB_" ) for  k  in  os .environ 
69+ )
5970
6071
61- def  _is_package_available (pkg_name : str , get_dist_name : bool  =  False ) ->  Tuple [bool , str ]:
72+ def  _is_package_available (
73+     pkg_name : str , get_dist_name : bool  =  False 
74+ ) ->  Tuple [bool , str ]:
6275    global  _package_map 
6376    pkg_exists  =  importlib .util .find_spec (pkg_name ) is  not None 
6477    pkg_version  =  "N/A" 
@@ -69,11 +82,16 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
6982            try :
7083                # Fallback for Python < 3.10 
7184                for  dist  in  importlib_metadata .distributions ():
72-                     _top_level_declared  =  (dist .read_text ("top_level.txt" ) or  "" ).split ()
85+                     _top_level_declared  =  (
86+                         dist .read_text ("top_level.txt" ) or  "" 
87+                     ).split ()
7388                    _infered_opt_names  =  {
74-                         f .parts [0 ] if  len (f .parts ) >  1  else  inspect .getmodulename (f ) for  f  in  (dist .files  or  [])
89+                         f .parts [0 ] if  len (f .parts ) >  1  else  inspect .getmodulename (f )
90+                         for  f  in  (dist .files  or  [])
7591                    } -  {None }
76-                     _top_level_inferred  =  filter (lambda  name : "."  not  in name , _infered_opt_names )
92+                     _top_level_inferred  =  filter (
93+                         lambda  name : "."  not  in name , _infered_opt_names 
94+                     )
7795                    for  pkg  in  _top_level_declared  or  _top_level_inferred :
7896                        _package_map [pkg ].append (dist .metadata ["Name" ])
7997            except  Exception  as  _ :
@@ -99,16 +117,22 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
99117else :
100118    logger .info ("Disabling PyTorch because USE_TORCH is set" )
101119    _torch_available  =  False 
120+     _torch_version  =  "N/A" 
102121
103122_jax_version  =  "N/A" 
104123_flax_version  =  "N/A" 
105124if  USE_JAX  in  ENV_VARS_TRUE_AND_AUTO_VALUES :
106-     _flax_available  =  importlib .util .find_spec ("jax" ) is  not None  and  importlib .util .find_spec ("flax" ) is  not None 
125+     _flax_available  =  (
126+         importlib .util .find_spec ("jax" ) is  not None 
127+         and  importlib .util .find_spec ("flax" ) is  not None 
128+     )
107129    if  _flax_available :
108130        try :
109131            _jax_version  =  importlib_metadata .version ("jax" )
110132            _flax_version  =  importlib_metadata .version ("flax" )
111-             logger .info (f"JAX version { _jax_version } { _flax_version }  )
133+             logger .info (
134+                 f"JAX version { _jax_version } { _flax_version }  
135+             )
112136        except  importlib_metadata .PackageNotFoundError :
113137            _flax_available  =  False 
114138else :
@@ -148,7 +172,9 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
148172            pass 
149173    _onnx_available  =  _onnxruntime_version  is  not None 
150174    if  _onnx_available :
151-         logger .debug (f"Successfully imported onnxruntime version { _onnxruntime_version }  )
175+         logger .debug (
176+             f"Successfully imported onnxruntime version { _onnxruntime_version }  
177+         )
152178
153179# (sayakpaul): importlib.util.find_spec("opencv-python") returns None even when it's installed. 
154180# _opencv_available = importlib.util.find_spec("opencv-python") is not None 
@@ -183,7 +209,9 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
183209_invisible_watermark_available  =  importlib .util .find_spec ("imwatermark" ) is  not None 
184210try :
185211    _invisible_watermark_version  =  importlib_metadata .version ("invisible-watermark" )
186-     logger .debug (f"Successfully imported invisible-watermark version { _invisible_watermark_version }  )
212+     logger .debug (
213+         f"Successfully imported invisible-watermark version { _invisible_watermark_version }  
214+     )
187215except  importlib_metadata .PackageNotFoundError :
188216    _invisible_watermark_available  =  False 
189217
@@ -198,7 +226,9 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
198226_wandb_available , _wandb_version  =  _is_package_available ("wandb" )
199227_tensorboard_available , _tensorboard_version  =  _is_package_available ("tensorboard" )
200228_compel_available , _compel_version  =  _is_package_available ("compel" )
201- _sentencepiece_available , _sentencepiece_version  =  _is_package_available ("sentencepiece" )
229+ _sentencepiece_available , _sentencepiece_version  =  _is_package_available (
230+     "sentencepiece" 
231+ )
202232_torchsde_available , _torchsde_version  =  _is_package_available ("torchsde" )
203233_peft_available , _peft_version  =  _is_package_available ("peft" )
204234_torchvision_available , _torchvision_version  =  _is_package_available ("torchvision" )
@@ -214,11 +244,19 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b
214244_gguf_available , _gguf_version  =  _is_package_available ("gguf" )
215245_torchao_available , _torchao_version  =  _is_package_available ("torchao" )
216246_bitsandbytes_available , _bitsandbytes_version  =  _is_package_available ("bitsandbytes" )
217- _optimum_quanto_available , _optimum_quanto_version  =  _is_package_available ("optimum" , get_dist_name = True )
218- _pytorch_retinaface_available , _pytorch_retinaface_version  =  _is_package_available ("pytorch_retinaface" )
219- _better_profanity_available , _better_profanity_version  =  _is_package_available ("better_profanity" )
247+ _optimum_quanto_available , _optimum_quanto_version  =  _is_package_available (
248+     "optimum" , get_dist_name = True 
249+ )
250+ _pytorch_retinaface_available , _pytorch_retinaface_version  =  _is_package_available (
251+     "pytorch_retinaface" 
252+ )
253+ _better_profanity_available , _better_profanity_version  =  _is_package_available (
254+     "better_profanity" 
255+ )
220256_nltk_available , _nltk_version  =  _is_package_available ("nltk" )
221- _cosmos_guardrail_available , _cosmos_guardrail_version  =  _is_package_available ("cosmos_guardrail" )
257+ _cosmos_guardrail_available , _cosmos_guardrail_version  =  _is_package_available (
258+     "cosmos_guardrail" 
259+ )
222260
223261
224262def  is_torch_available ():
@@ -374,7 +412,10 @@ def is_cosmos_guardrail_available():
374412
375413
376414def  is_hpu_available ():
377-     return  all (importlib .util .find_spec (lib ) for  lib  in  ("habana_frameworks" , "habana_frameworks.torch" ))
415+     return  all (
416+         importlib .util .find_spec (lib )
417+         for  lib  in  ("habana_frameworks" , "habana_frameworks.torch" )
418+     )
378419
379420
380421# docstyle-ignore 
@@ -560,7 +601,10 @@ def is_hpu_available():
560601        ("compel" , (is_compel_available , COMPEL_IMPORT_ERROR )),
561602        ("ftfy" , (is_ftfy_available , FTFY_IMPORT_ERROR )),
562603        ("torchsde" , (is_torchsde_available , TORCHSDE_IMPORT_ERROR )),
563-         ("invisible_watermark" , (is_invisible_watermark_available , INVISIBLE_WATERMARK_IMPORT_ERROR )),
604+         (
605+             "invisible_watermark" ,
606+             (is_invisible_watermark_available , INVISIBLE_WATERMARK_IMPORT_ERROR ),
607+         ),
564608        ("peft" , (is_peft_available , PEFT_IMPORT_ERROR )),
565609        ("safetensors" , (is_safetensors_available , SAFETENSORS_IMPORT_ERROR )),
566610        ("bitsandbytes" , (is_bitsandbytes_available , BITSANDBYTES_IMPORT_ERROR )),
@@ -569,8 +613,14 @@ def is_hpu_available():
569613        ("gguf" , (is_gguf_available , GGUF_IMPORT_ERROR )),
570614        ("torchao" , (is_torchao_available , TORCHAO_IMPORT_ERROR )),
571615        ("quanto" , (is_optimum_quanto_available , QUANTO_IMPORT_ERROR )),
572-         ("pytorch_retinaface" , (is_pytorch_retinaface_available , PYTORCH_RETINAFACE_IMPORT_ERROR )),
573-         ("better_profanity" , (is_better_profanity_available , BETTER_PROFANITY_IMPORT_ERROR )),
616+         (
617+             "pytorch_retinaface" ,
618+             (is_pytorch_retinaface_available , PYTORCH_RETINAFACE_IMPORT_ERROR ),
619+         ),
620+         (
621+             "better_profanity" ,
622+             (is_better_profanity_available , BETTER_PROFANITY_IMPORT_ERROR ),
623+         ),
574624        ("nltk" , (is_nltk_available , NLTK_IMPORT_ERROR )),
575625    ]
576626)
@@ -598,9 +648,10 @@ def requires_backends(obj, backends):
598648            " --upgrade transformers \n ```" 
599649        )
600650
601-     if  name  in  ["StableDiffusionDepth2ImgPipeline" , "StableDiffusionPix2PixZeroPipeline" ] and  is_transformers_version (
602-         "<" , "4.26.0" 
603-     ):
651+     if  name  in  [
652+         "StableDiffusionDepth2ImgPipeline" ,
653+         "StableDiffusionPix2PixZeroPipeline" ,
654+     ] and  is_transformers_version ("<" , "4.26.0" ):
604655        raise  ImportError (
605656            f"You need to install `transformers>=4.26` in order to use { name } \n ```\n  pip install" 
606657            " --upgrade transformers \n ```" 
@@ -620,7 +671,9 @@ def __getattr__(cls, key):
620671
621672
622673# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319 
623- def  compare_versions (library_or_version : Union [str , Version ], operation : str , requirement_version : str ):
674+ def  compare_versions (
675+     library_or_version : Union [str , Version ], operation : str , requirement_version : str 
676+ ):
624677    """ 
625678    Compares a library version to some requirement using a given operation. 
626679
@@ -633,7 +686,9 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
633686            The version to compare the library version against 
634687    """ 
635688    if  operation  not  in STR_OPERATION_TO_FUNC .keys ():
636-         raise  ValueError (f"`operation` must be one of { list (STR_OPERATION_TO_FUNC .keys ())} { operation }  )
689+         raise  ValueError (
690+             f"`operation` must be one of { list (STR_OPERATION_TO_FUNC .keys ())} { operation }  
691+         )
637692    operation  =  STR_OPERATION_TO_FUNC [operation ]
638693    if  isinstance (library_or_version , str ):
639694        library_or_version  =  parse (importlib_metadata .version (library_or_version ))
@@ -837,15 +892,19 @@ class _LazyModule(ModuleType):
837892
838893    # Very heavily inspired by optuna.integration._IntegrationModule 
839894    # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py 
840-     def  __init__ (self , name , module_file , import_structure , module_spec = None , extra_objects = None ):
895+     def  __init__ (
896+         self , name , module_file , import_structure , module_spec = None , extra_objects = None 
897+     ):
841898        super ().__init__ (name )
842899        self ._modules  =  set (import_structure .keys ())
843900        self ._class_to_module  =  {}
844901        for  key , values  in  import_structure .items ():
845902            for  value  in  values :
846903                self ._class_to_module [value ] =  key 
847904        # Needed for autocompletion in an IDE 
848-         self .__all__  =  list (import_structure .keys ()) +  list (chain (* import_structure .values ()))
905+         self .__all__  =  list (import_structure .keys ()) +  list (
906+             chain (* import_structure .values ())
907+         )
849908        self .__file__  =  module_file 
850909        self .__spec__  =  module_spec 
851910        self .__path__  =  [os .path .dirname (module_file )]
0 commit comments