2525from typing import Any , Union
2626
2727from huggingface_hub .utils import is_jinja_available # noqa: F401
28- from packaging import version
2928from packaging .version import Version , parse
3029
3130from . import logging
5251
5352STR_OPERATION_TO_FUNC = {">" : op .gt , ">=" : op .ge , "==" : op .eq , "!=" : op .ne , "<=" : op .le , "<" : op .lt }
5453
55- _torch_version = "N/A"
56- if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES :
57- _torch_available = importlib .util .find_spec ("torch" ) is not None
58- if _torch_available :
54+ _is_google_colab = "google.colab" in sys .modules or any (k .startswith ("COLAB_" ) for k in os .environ )
55+
56+
57+ def _is_package_available (pkg_name : str ):
58+ pkg_exists = importlib .util .find_spec (pkg_name ) is not None
59+ pkg_version = "N/A"
60+
61+ if pkg_exists :
5962 try :
60- _torch_version = importlib_metadata .version ("torch" )
61- logger .info (f"PyTorch version { _torch_version } available." )
62- except importlib_metadata .PackageNotFoundError :
63- _torch_available = False
63+ pkg_version = importlib_metadata .version (pkg_name )
64+ logger .debug (f"Successfully imported { pkg_name } version { pkg_version } " )
65+ except (ImportError , importlib_metadata .PackageNotFoundError ):
66+ pkg_exists = False
67+
68+ return pkg_exists , pkg_version
69+
70+
71+ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES :
72+ _torch_available , _torch_version = _is_package_available ("torch" )
73+
6474else :
6575 logger .info ("Disabling PyTorch because USE_TORCH is set" )
6676 _torch_available = False
6777
68- _torch_xla_available = importlib .util .find_spec ("torch_xla" ) is not None
69- if _torch_xla_available :
70- try :
71- _torch_xla_version = importlib_metadata .version ("torch_xla" )
72- logger .info (f"PyTorch XLA version { _torch_xla_version } available." )
73- except ImportError :
74- _torch_xla_available = False
75-
76- # check whether torch_npu is available
77- _torch_npu_available = importlib .util .find_spec ("torch_npu" ) is not None
78- if _torch_npu_available :
79- try :
80- _torch_npu_version = importlib_metadata .version ("torch_npu" )
81- logger .info (f"torch_npu version { _torch_npu_version } available." )
82- except ImportError :
83- _torch_npu_available = False
84-
8578_jax_version = "N/A"
8679_flax_version = "N/A"
8780if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES :
9790 _flax_available = False
9891
9992if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES :
100- _safetensors_available = importlib .util .find_spec ("safetensors" ) is not None
101- if _safetensors_available :
102- try :
103- _safetensors_version = importlib_metadata .version ("safetensors" )
104- logger .info (f"Safetensors version { _safetensors_version } available." )
105- except importlib_metadata .PackageNotFoundError :
106- _safetensors_available = False
93+ _safetensors_available , _safetensors_version = _is_package_available ("safetensors" )
94+
10795else :
10896 logger .info ("Disabling Safetensors because USE_TF is set" )
10997 _safetensors_available = False
11098
111- _transformers_available = importlib .util .find_spec ("transformers" ) is not None
112- try :
113- _transformers_version = importlib_metadata .version ("transformers" )
114- logger .debug (f"Successfully imported transformers version { _transformers_version } " )
115- except importlib_metadata .PackageNotFoundError :
116- _transformers_available = False
117-
118- _hf_hub_available = importlib .util .find_spec ("huggingface_hub" ) is not None
119- try :
120- _hf_hub_version = importlib_metadata .version ("huggingface_hub" )
121- logger .debug (f"Successfully imported huggingface_hub version { _hf_hub_version } " )
122- except importlib_metadata .PackageNotFoundError :
123- _hf_hub_available = False
124-
125-
126- _inflect_available = importlib .util .find_spec ("inflect" ) is not None
127- try :
128- _inflect_version = importlib_metadata .version ("inflect" )
129- logger .debug (f"Successfully imported inflect version { _inflect_version } " )
130- except importlib_metadata .PackageNotFoundError :
131- _inflect_available = False
132-
133-
134- _unidecode_available = importlib .util .find_spec ("unidecode" ) is not None
135- try :
136- _unidecode_version = importlib_metadata .version ("unidecode" )
137- logger .debug (f"Successfully imported unidecode version { _unidecode_version } " )
138- except importlib_metadata .PackageNotFoundError :
139- _unidecode_available = False
140-
14199_onnxruntime_version = "N/A"
142100_onnx_available = importlib .util .find_spec ("onnxruntime" ) is not None
143101if _onnx_available :
186144except importlib_metadata .PackageNotFoundError :
187145 _opencv_available = False
188146
189- _scipy_available = importlib .util .find_spec ("scipy" ) is not None
190- try :
191- _scipy_version = importlib_metadata .version ("scipy" )
192- logger .debug (f"Successfully imported scipy version { _scipy_version } " )
193- except importlib_metadata .PackageNotFoundError :
194- _scipy_available = False
195-
196- _librosa_available = importlib .util .find_spec ("librosa" ) is not None
197- try :
198- _librosa_version = importlib_metadata .version ("librosa" )
199- logger .debug (f"Successfully imported librosa version { _librosa_version } " )
200- except importlib_metadata .PackageNotFoundError :
201- _librosa_available = False
202-
203- _accelerate_available = importlib .util .find_spec ("accelerate" ) is not None
204- try :
205- _accelerate_version = importlib_metadata .version ("accelerate" )
206- logger .debug (f"Successfully imported accelerate version { _accelerate_version } " )
207- except importlib_metadata .PackageNotFoundError :
208- _accelerate_available = False
209-
210- _xformers_available = importlib .util .find_spec ("xformers" ) is not None
211- try :
212- _xformers_version = importlib_metadata .version ("xformers" )
213- if _torch_available :
214- _torch_version = importlib_metadata .version ("torch" )
215- if version .Version (_torch_version ) < version .Version ("1.12" ):
216- raise ValueError ("xformers is installed in your environment and requires PyTorch >= 1.12" )
217-
218- logger .debug (f"Successfully imported xformers version { _xformers_version } " )
219- except importlib_metadata .PackageNotFoundError :
220- _xformers_available = False
221-
222- _k_diffusion_available = importlib .util .find_spec ("k_diffusion" ) is not None
223- try :
224- _k_diffusion_version = importlib_metadata .version ("k_diffusion" )
225- logger .debug (f"Successfully imported k-diffusion version { _k_diffusion_version } " )
226- except importlib_metadata .PackageNotFoundError :
227- _k_diffusion_available = False
228-
229- _note_seq_available = importlib .util .find_spec ("note_seq" ) is not None
230- try :
231- _note_seq_version = importlib_metadata .version ("note_seq" )
232- logger .debug (f"Successfully imported note-seq version { _note_seq_version } " )
233- except importlib_metadata .PackageNotFoundError :
234- _note_seq_available = False
235-
236- _wandb_available = importlib .util .find_spec ("wandb" ) is not None
237- try :
238- _wandb_version = importlib_metadata .version ("wandb" )
239- logger .debug (f"Successfully imported wandb version { _wandb_version } " )
240- except importlib_metadata .PackageNotFoundError :
241- _wandb_available = False
242-
243-
244- _tensorboard_available = importlib .util .find_spec ("tensorboard" )
245- try :
246- _tensorboard_version = importlib_metadata .version ("tensorboard" )
247- logger .debug (f"Successfully imported tensorboard version { _tensorboard_version } " )
248- except importlib_metadata .PackageNotFoundError :
249- _tensorboard_available = False
250-
251-
252- _compel_available = importlib .util .find_spec ("compel" )
253- try :
254- _compel_version = importlib_metadata .version ("compel" )
255- logger .debug (f"Successfully imported compel version { _compel_version } " )
256- except importlib_metadata .PackageNotFoundError :
257- _compel_available = False
258-
259-
260- _ftfy_available = importlib .util .find_spec ("ftfy" ) is not None
261- try :
262- _ftfy_version = importlib_metadata .version ("ftfy" )
263- logger .debug (f"Successfully imported ftfy version { _ftfy_version } " )
264- except importlib_metadata .PackageNotFoundError :
265- _ftfy_available = False
266-
267-
268147_bs4_available = importlib .util .find_spec ("bs4" ) is not None
269148try :
270149 # importlib metadata under different name
273152except importlib_metadata .PackageNotFoundError :
274153 _bs4_available = False
275154
276- _torchsde_available = importlib .util .find_spec ("torchsde" ) is not None
277- try :
278- _torchsde_version = importlib_metadata .version ("torchsde" )
279- logger .debug (f"Successfully imported torchsde version { _torchsde_version } " )
280- except importlib_metadata .PackageNotFoundError :
281- _torchsde_available = False
282-
283155_invisible_watermark_available = importlib .util .find_spec ("imwatermark" ) is not None
284156try :
285157 _invisible_watermark_version = importlib_metadata .version ("invisible-watermark" )
286158 logger .debug (f"Successfully imported invisible-watermark version { _invisible_watermark_version } " )
287159except importlib_metadata .PackageNotFoundError :
288160 _invisible_watermark_available = False
289161
290-
291- _peft_available = importlib .util .find_spec ("peft" ) is not None
292- try :
293- _peft_version = importlib_metadata .version ("peft" )
294- logger .debug (f"Successfully imported peft version { _peft_version } " )
295- except importlib_metadata .PackageNotFoundError :
296- _peft_available = False
297-
298- _torchvision_available = importlib .util .find_spec ("torchvision" ) is not None
299- try :
300- _torchvision_version = importlib_metadata .version ("torchvision" )
301- logger .debug (f"Successfully imported torchvision version { _torchvision_version } " )
302- except importlib_metadata .PackageNotFoundError :
303- _torchvision_available = False
304-
305- _sentencepiece_available = importlib .util .find_spec ("sentencepiece" ) is not None
306- try :
307- _sentencepiece_version = importlib_metadata .version ("sentencepiece" )
308- logger .info (f"Successfully imported sentencepiece version { _sentencepiece_version } " )
309- except importlib_metadata .PackageNotFoundError :
310- _sentencepiece_available = False
311-
312- _matplotlib_available = importlib .util .find_spec ("matplotlib" ) is not None
313- try :
314- _matplotlib_version = importlib_metadata .version ("matplotlib" )
315- logger .debug (f"Successfully imported matplotlib version { _matplotlib_version } " )
316- except importlib_metadata .PackageNotFoundError :
317- _matplotlib_available = False
318-
319- _timm_available = importlib .util .find_spec ("timm" ) is not None
320- if _timm_available :
321- try :
322- _timm_version = importlib_metadata .version ("timm" )
323- logger .info (f"Timm version { _timm_version } available." )
324- except importlib_metadata .PackageNotFoundError :
325- _timm_available = False
326-
327-
328- def is_timm_available ():
329- return _timm_available
330-
331-
332- _bitsandbytes_available = importlib .util .find_spec ("bitsandbytes" ) is not None
333- try :
334- _bitsandbytes_version = importlib_metadata .version ("bitsandbytes" )
335- logger .debug (f"Successfully imported bitsandbytes version { _bitsandbytes_version } " )
336- except importlib_metadata .PackageNotFoundError :
337- _bitsandbytes_available = False
338-
339- _is_google_colab = "google.colab" in sys .modules or any (k .startswith ("COLAB_" ) for k in os .environ )
340-
341- _imageio_available = importlib .util .find_spec ("imageio" ) is not None
342- if _imageio_available :
343- try :
344- _imageio_version = importlib_metadata .version ("imageio" )
345- logger .debug (f"Successfully imported imageio version { _imageio_version } " )
346-
347- except importlib_metadata .PackageNotFoundError :
348- _imageio_available = False
349-
350- _is_gguf_available = importlib .util .find_spec ("gguf" ) is not None
351- if _is_gguf_available :
352- try :
353- _gguf_version = importlib_metadata .version ("gguf" )
354- logger .debug (f"Successfully import gguf version { _gguf_version } " )
355- except importlib_metadata .PackageNotFoundError :
356- _is_gguf_available = False
357-
358-
359- _is_torchao_available = importlib .util .find_spec ("torchao" ) is not None
360- if _is_torchao_available :
361- try :
362- _torchao_version = importlib_metadata .version ("torchao" )
363- logger .debug (f"Successfully import torchao version { _torchao_version } " )
364- except importlib_metadata .PackageNotFoundError :
365- _is_torchao_available = False
366-
367-
368- _is_optimum_quanto_available = importlib .util .find_spec ("optimum" ) is not None
369- if _is_optimum_quanto_available :
162+ _torch_xla_available , _torch_xla_version = _is_package_available ("torch_xla" )
163+ _torch_npu_available , _torch_npu_version = _is_package_available ("torch_npu" )
164+ _transformers_available , _transformers_version = _is_package_available ("transformers" )
165+ _hf_hub_available , _hf_hub_version = _is_package_available ("huggingface_hub" )
166+ _inflect_available , _inflect_version = _is_package_available ("inflect" )
167+ _unidecode_available , _unidecode_version = _is_package_available ("unidecode" )
168+ _k_diffusion_available , _k_diffusion_version = _is_package_available ("k_diffusion" )
169+ _note_seq_available , _note_seq_version = _is_package_available ("note_seq" )
170+ _wandb_available , _wandb_version = _is_package_available ("wandb" )
171+ _tensorboard_available , _tensorboard_version = _is_package_available ("tensorboard" )
172+ _compel_available , _compel_version = _is_package_available ("compel" )
173+ _sentencepiece_available , _sentencepiece_version = _is_package_available ("sentencepiece" )
174+ _torchsde_available , _torchsde_version = _is_package_available ("torchsde" )
175+ _peft_available , _peft_version = _is_package_available ("peft" )
176+ _torchvision_available , _torchvision_version = _is_package_available ("torchvision" )
177+ _matplotlib_available , _matplotlib_version = _is_package_available ("matplotlib" )
178+ _timm_available , _timm_version = _is_package_available ("timm" )
179+ _bitsandbytes_available , _bitsandbytes_version = _is_package_available ("bitsandbytes" )
180+ _imageio_available , _imageio_version = _is_package_available ("imageio" )
181+ _ftfy_available , _ftfy_version = _is_package_available ("ftfy" )
182+ _scipy_available , _scipy_version = _is_package_available ("scipy" )
183+ _librosa_available , _librosa_version = _is_package_available ("librosa" )
184+ _accelerate_available , _accelerate_version = _is_package_available ("accelerate" )
185+ _xformers_available , _xformers_version = _is_package_available ("xformers" )
186+ _gguf_available , _gguf_version = _is_package_available ("gguf" )
187+ _torchao_available , _torchao_version = _is_package_available ("torchao" )
188+ _bitsandbytes_available , _bitsandbytes_version = _is_package_available ("bitsandbytes" )
189+ _torchao_available , _torchao_version = _is_package_available ("torchao" )
190+
191+ _optimum_quanto_available = importlib .util .find_spec ("optimum" ) is not None
192+ if _optimum_quanto_available :
370193 try :
371194 _optimum_quanto_version = importlib_metadata .version ("optimum_quanto" )
372195 logger .debug (f"Successfully import optimum-quanto version { _optimum_quanto_version } " )
373196 except importlib_metadata .PackageNotFoundError :
374- _is_optimum_quanto_available = False
197+ _optimum_quanto_available = False
375198
376199
377200def is_torch_available ():
@@ -495,15 +318,19 @@ def is_imageio_available():
495318
496319
497320def is_gguf_available ():
498- return _is_gguf_available
321+ return _gguf_available
499322
500323
501324def is_torchao_available ():
502- return _is_torchao_available
325+ return _torchao_available
503326
504327
505328def is_optimum_quanto_available ():
506- return _is_optimum_quanto_available
329+ return _optimum_quanto_available
330+
331+
332+ def is_timm_available ():
333+ return _timm_available
507334
508335
509336# docstyle-ignore
@@ -863,7 +690,7 @@ def is_gguf_version(operation: str, version: str):
863690 version (`str`):
864691 A version string
865692 """
866- if not _is_gguf_available :
693+ if not _gguf_available :
867694 return False
868695 return compare_versions (parse (_gguf_version ), operation , version )
869696
@@ -878,7 +705,7 @@ def is_torchao_version(operation: str, version: str):
878705 version (`str`):
879706 A version string
880707 """
881- if not _is_torchao_available :
708+ if not _torchao_available :
882709 return False
883710 return compare_versions (parse (_torchao_version ), operation , version )
884711
@@ -908,7 +735,7 @@ def is_optimum_quanto_version(operation: str, version: str):
908735 version (`str`):
909736 A version string
910737 """
911- if not _is_optimum_quanto_available :
738+ if not _optimum_quanto_available :
912739 return False
913740 return compare_versions (parse (_optimum_quanto_version ), operation , version )
914741
0 commit comments