25 | 25 | from typing import Any, Union |
26 | 26 |
27 | 27 | from huggingface_hub.utils import is_jinja_available # noqa: F401 |
28 | | -from packaging import version |
29 | 28 | from packaging.version import Version, parse |
30 | 29 |
31 | 30 | from . import logging |
52 | 51 |
53 | 52 | STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} |
54 | 53 |
55 | | -_torch_version = "N/A" |
56 | | -if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: |
57 | | - _torch_available = importlib.util.find_spec("torch") is not None |
58 | | - if _torch_available: |
| 54 | +_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ) |
| 55 | + |
| 56 | + |
| 57 | +def _is_package_available(pkg_name: str): |
| 58 | + pkg_exists = importlib.util.find_spec(pkg_name) is not None |
| 59 | + pkg_version = "N/A" |
| 60 | + |
| 61 | + if pkg_exists: |
59 | 62 | try: |
60 | | - _torch_version = importlib_metadata.version("torch") |
61 | | - logger.info(f"PyTorch version {_torch_version} available.") |
62 | | - except importlib_metadata.PackageNotFoundError: |
63 | | - _torch_available = False |
| 63 | + pkg_version = importlib_metadata.version(pkg_name) |
| 64 | + logger.debug(f"Successfully imported {pkg_name} version {pkg_version}") |
| 65 | + except (ImportError, importlib_metadata.PackageNotFoundError): |
| 66 | + pkg_exists = False |
| 67 | + |
| 68 | + return pkg_exists, pkg_version |
| 69 | + |
| 70 | + |
| 71 | +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: |
| 72 | + _torch_available, _torch_version = _is_package_available("torch") |
| 73 | + |
64 | 74 | else: |
65 | 75 | logger.info("Disabling PyTorch because USE_TORCH is set") |
66 | 76 | _torch_available = False |
67 | 77 |
68 | | -_torch_xla_available = importlib.util.find_spec("torch_xla") is not None |
69 | | -if _torch_xla_available: |
70 | | - try: |
71 | | - _torch_xla_version = importlib_metadata.version("torch_xla") |
72 | | - logger.info(f"PyTorch XLA version {_torch_xla_version} available.") |
73 | | - except ImportError: |
74 | | - _torch_xla_available = False |
75 | | - |
76 | | -# check whether torch_npu is available |
77 | | -_torch_npu_available = importlib.util.find_spec("torch_npu") is not None |
78 | | -if _torch_npu_available: |
79 | | - try: |
80 | | - _torch_npu_version = importlib_metadata.version("torch_npu") |
81 | | - logger.info(f"torch_npu version {_torch_npu_version} available.") |
82 | | - except ImportError: |
83 | | - _torch_npu_available = False |
84 | | - |
85 | 78 | _jax_version = "N/A" |
86 | 79 | _flax_version = "N/A" |
87 | 80 | if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: |
97 | 90 | _flax_available = False |
98 | 91 |
99 | 92 | if USE_SAFETENSORS in ENV_VARS_TRUE_AND_AUTO_VALUES: |
100 | | - _safetensors_available = importlib.util.find_spec("safetensors") is not None |
101 | | - if _safetensors_available: |
102 | | - try: |
103 | | - _safetensors_version = importlib_metadata.version("safetensors") |
104 | | - logger.info(f"Safetensors version {_safetensors_version} available.") |
105 | | - except importlib_metadata.PackageNotFoundError: |
106 | | - _safetensors_available = False |
| 93 | + _safetensors_available, _safetensors_version = _is_package_available("safetensors") |
| 94 | + |
107 | 95 | else: |
108 | 96 | logger.info("Disabling Safetensors because USE_TF is set") |
109 | 97 | _safetensors_available = False |
110 | 98 |
111 | | -_transformers_available = importlib.util.find_spec("transformers") is not None |
112 | | -try: |
113 | | - _transformers_version = importlib_metadata.version("transformers") |
114 | | - logger.debug(f"Successfully imported transformers version {_transformers_version}") |
115 | | -except importlib_metadata.PackageNotFoundError: |
116 | | - _transformers_available = False |
117 | | - |
118 | | -_hf_hub_available = importlib.util.find_spec("huggingface_hub") is not None |
119 | | -try: |
120 | | - _hf_hub_version = importlib_metadata.version("huggingface_hub") |
121 | | - logger.debug(f"Successfully imported huggingface_hub version {_hf_hub_version}") |
122 | | -except importlib_metadata.PackageNotFoundError: |
123 | | - _hf_hub_available = False |
124 | | - |
125 | | - |
126 | | -_inflect_available = importlib.util.find_spec("inflect") is not None |
127 | | -try: |
128 | | - _inflect_version = importlib_metadata.version("inflect") |
129 | | - logger.debug(f"Successfully imported inflect version {_inflect_version}") |
130 | | -except importlib_metadata.PackageNotFoundError: |
131 | | - _inflect_available = False |
132 | | - |
133 | | - |
134 | | -_unidecode_available = importlib.util.find_spec("unidecode") is not None |
135 | | -try: |
136 | | - _unidecode_version = importlib_metadata.version("unidecode") |
137 | | - logger.debug(f"Successfully imported unidecode version {_unidecode_version}") |
138 | | -except importlib_metadata.PackageNotFoundError: |
139 | | - _unidecode_available = False |
140 | | - |
141 | 99 | _onnxruntime_version = "N/A" |
142 | 100 | _onnx_available = importlib.util.find_spec("onnxruntime") is not None |
143 | 101 | if _onnx_available: |
186 | 144 | except importlib_metadata.PackageNotFoundError: |
187 | 145 | _opencv_available = False |
188 | 146 |
189 | | -_scipy_available = importlib.util.find_spec("scipy") is not None |
190 | | -try: |
191 | | - _scipy_version = importlib_metadata.version("scipy") |
192 | | - logger.debug(f"Successfully imported scipy version {_scipy_version}") |
193 | | -except importlib_metadata.PackageNotFoundError: |
194 | | - _scipy_available = False |
195 | | - |
196 | | -_librosa_available = importlib.util.find_spec("librosa") is not None |
197 | | -try: |
198 | | - _librosa_version = importlib_metadata.version("librosa") |
199 | | - logger.debug(f"Successfully imported librosa version {_librosa_version}") |
200 | | -except importlib_metadata.PackageNotFoundError: |
201 | | - _librosa_available = False |
202 | | - |
203 | | -_accelerate_available = importlib.util.find_spec("accelerate") is not None |
204 | | -try: |
205 | | - _accelerate_version = importlib_metadata.version("accelerate") |
206 | | - logger.debug(f"Successfully imported accelerate version {_accelerate_version}") |
207 | | -except importlib_metadata.PackageNotFoundError: |
208 | | - _accelerate_available = False |
209 | | - |
210 | | -_xformers_available = importlib.util.find_spec("xformers") is not None |
211 | | -try: |
212 | | - _xformers_version = importlib_metadata.version("xformers") |
213 | | - if _torch_available: |
214 | | - _torch_version = importlib_metadata.version("torch") |
215 | | - if version.Version(_torch_version) < version.Version("1.12"): |
216 | | - raise ValueError("xformers is installed in your environment and requires PyTorch >= 1.12") |
217 | | - |
218 | | - logger.debug(f"Successfully imported xformers version {_xformers_version}") |
219 | | -except importlib_metadata.PackageNotFoundError: |
220 | | - _xformers_available = False |
221 | | - |
222 | | -_k_diffusion_available = importlib.util.find_spec("k_diffusion") is not None |
223 | | -try: |
224 | | - _k_diffusion_version = importlib_metadata.version("k_diffusion") |
225 | | - logger.debug(f"Successfully imported k-diffusion version {_k_diffusion_version}") |
226 | | -except importlib_metadata.PackageNotFoundError: |
227 | | - _k_diffusion_available = False |
228 | | - |
229 | | -_note_seq_available = importlib.util.find_spec("note_seq") is not None |
230 | | -try: |
231 | | - _note_seq_version = importlib_metadata.version("note_seq") |
232 | | - logger.debug(f"Successfully imported note-seq version {_note_seq_version}") |
233 | | -except importlib_metadata.PackageNotFoundError: |
234 | | - _note_seq_available = False |
235 | | - |
236 | | -_wandb_available = importlib.util.find_spec("wandb") is not None |
237 | | -try: |
238 | | - _wandb_version = importlib_metadata.version("wandb") |
239 | | - logger.debug(f"Successfully imported wandb version {_wandb_version }") |
240 | | -except importlib_metadata.PackageNotFoundError: |
241 | | - _wandb_available = False |
242 | | - |
243 | | - |
244 | | -_tensorboard_available = importlib.util.find_spec("tensorboard") |
245 | | -try: |
246 | | - _tensorboard_version = importlib_metadata.version("tensorboard") |
247 | | - logger.debug(f"Successfully imported tensorboard version {_tensorboard_version}") |
248 | | -except importlib_metadata.PackageNotFoundError: |
249 | | - _tensorboard_available = False |
250 | | - |
251 | | - |
252 | | -_compel_available = importlib.util.find_spec("compel") |
253 | | -try: |
254 | | - _compel_version = importlib_metadata.version("compel") |
255 | | - logger.debug(f"Successfully imported compel version {_compel_version}") |
256 | | -except importlib_metadata.PackageNotFoundError: |
257 | | - _compel_available = False |
258 | | - |
259 | | - |
260 | | -_ftfy_available = importlib.util.find_spec("ftfy") is not None |
261 | | -try: |
262 | | - _ftfy_version = importlib_metadata.version("ftfy") |
263 | | - logger.debug(f"Successfully imported ftfy version {_ftfy_version}") |
264 | | -except importlib_metadata.PackageNotFoundError: |
265 | | - _ftfy_available = False |
266 | | - |
267 | | - |
268 | 147 | _bs4_available = importlib.util.find_spec("bs4") is not None |
269 | 148 | try: |
270 | 149 | # importlib metadata under different name |
273 | 152 | except importlib_metadata.PackageNotFoundError: |
274 | 153 | _bs4_available = False |
275 | 154 |
276 | | -_torchsde_available = importlib.util.find_spec("torchsde") is not None |
277 | | -try: |
278 | | - _torchsde_version = importlib_metadata.version("torchsde") |
279 | | - logger.debug(f"Successfully imported torchsde version {_torchsde_version}") |
280 | | -except importlib_metadata.PackageNotFoundError: |
281 | | - _torchsde_available = False |
282 | | - |
283 | 155 | _invisible_watermark_available = importlib.util.find_spec("imwatermark") is not None |
284 | 156 | try: |
285 | 157 | _invisible_watermark_version = importlib_metadata.version("invisible-watermark") |
286 | 158 | logger.debug(f"Successfully imported invisible-watermark version {_invisible_watermark_version}") |
287 | 159 | except importlib_metadata.PackageNotFoundError: |
288 | 160 | _invisible_watermark_available = False |
289 | 161 |
290 | | - |
291 | | -_peft_available = importlib.util.find_spec("peft") is not None |
292 | | -try: |
293 | | - _peft_version = importlib_metadata.version("peft") |
294 | | - logger.debug(f"Successfully imported peft version {_peft_version}") |
295 | | -except importlib_metadata.PackageNotFoundError: |
296 | | - _peft_available = False |
297 | | - |
298 | | -_torchvision_available = importlib.util.find_spec("torchvision") is not None |
299 | | -try: |
300 | | - _torchvision_version = importlib_metadata.version("torchvision") |
301 | | - logger.debug(f"Successfully imported torchvision version {_torchvision_version}") |
302 | | -except importlib_metadata.PackageNotFoundError: |
303 | | - _torchvision_available = False |
304 | | - |
305 | | -_sentencepiece_available = importlib.util.find_spec("sentencepiece") is not None |
306 | | -try: |
307 | | - _sentencepiece_version = importlib_metadata.version("sentencepiece") |
308 | | - logger.info(f"Successfully imported sentencepiece version {_sentencepiece_version}") |
309 | | -except importlib_metadata.PackageNotFoundError: |
310 | | - _sentencepiece_available = False |
311 | | - |
312 | | -_matplotlib_available = importlib.util.find_spec("matplotlib") is not None |
313 | | -try: |
314 | | - _matplotlib_version = importlib_metadata.version("matplotlib") |
315 | | - logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}") |
316 | | -except importlib_metadata.PackageNotFoundError: |
317 | | - _matplotlib_available = False |
318 | | - |
319 | | -_timm_available = importlib.util.find_spec("timm") is not None |
320 | | -if _timm_available: |
321 | | - try: |
322 | | - _timm_version = importlib_metadata.version("timm") |
323 | | - logger.info(f"Timm version {_timm_version} available.") |
324 | | - except importlib_metadata.PackageNotFoundError: |
325 | | - _timm_available = False |
326 | | - |
327 | | - |
328 | | -def is_timm_available(): |
329 | | - return _timm_available |
330 | | - |
331 | | - |
332 | | -_bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None |
333 | | -try: |
334 | | - _bitsandbytes_version = importlib_metadata.version("bitsandbytes") |
335 | | - logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}") |
336 | | -except importlib_metadata.PackageNotFoundError: |
337 | | - _bitsandbytes_available = False |
338 | | - |
339 | | -_is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ) |
340 | | - |
341 | | -_imageio_available = importlib.util.find_spec("imageio") is not None |
342 | | -if _imageio_available: |
343 | | - try: |
344 | | - _imageio_version = importlib_metadata.version("imageio") |
345 | | - logger.debug(f"Successfully imported imageio version {_imageio_version}") |
346 | | - |
347 | | - except importlib_metadata.PackageNotFoundError: |
348 | | - _imageio_available = False |
349 | | - |
350 | | -_is_gguf_available = importlib.util.find_spec("gguf") is not None |
351 | | -if _is_gguf_available: |
352 | | - try: |
353 | | - _gguf_version = importlib_metadata.version("gguf") |
354 | | - logger.debug(f"Successfully import gguf version {_gguf_version}") |
355 | | - except importlib_metadata.PackageNotFoundError: |
356 | | - _is_gguf_available = False |
357 | | - |
358 | | - |
359 | | -_is_torchao_available = importlib.util.find_spec("torchao") is not None |
360 | | -if _is_torchao_available: |
361 | | - try: |
362 | | - _torchao_version = importlib_metadata.version("torchao") |
363 | | - logger.debug(f"Successfully import torchao version {_torchao_version}") |
364 | | - except importlib_metadata.PackageNotFoundError: |
365 | | - _is_torchao_available = False |
366 | | - |
367 | | - |
368 | | -_is_optimum_quanto_available = importlib.util.find_spec("optimum") is not None |
369 | | -if _is_optimum_quanto_available: |
| 162 | +_torch_xla_available, _torch_xla_version = _is_package_available("torch_xla") |
| 163 | +_torch_npu_available, _torch_npu_version = _is_package_available("torch_npu") |
| 164 | +_transformers_available, _transformers_version = _is_package_available("transformers") |
| 165 | +_hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub") |
| 166 | +_inflect_available, _inflect_version = _is_package_available("inflect") |
| 167 | +_unidecode_available, _unidecode_version = _is_package_available("unidecode") |
| 168 | +_k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion") |
| 169 | +_note_seq_available, _note_seq_version = _is_package_available("note_seq") |
| 170 | +_wandb_available, _wandb_version = _is_package_available("wandb") |
| 171 | +_tensorboard_available, _tensorboard_version = _is_package_available("tensorboard") |
| 172 | +_compel_available, _compel_version = _is_package_available("compel") |
| 173 | +_sentencepiece_available, _sentencepiece_version = _is_package_available("sentencepiece") |
| 174 | +_torchsde_available, _torchsde_version = _is_package_available("torchsde") |
| 175 | +_peft_available, _peft_version = _is_package_available("peft") |
| 176 | +_torchvision_available, _torchvision_version = _is_package_available("torchvision") |
| 177 | +_matplotlib_available, _matplotlib_version = _is_package_available("matplotlib") |
| 178 | +_timm_available, _timm_version = _is_package_available("timm") |
| 179 | +_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes") |
| 180 | +_imageio_available, _imageio_version = _is_package_available("imageio") |
| 181 | +_ftfy_available, _ftfy_version = _is_package_available("ftfy") |
| 182 | +_scipy_available, _scipy_version = _is_package_available("scipy") |
| 183 | +_librosa_available, _librosa_version = _is_package_available("librosa") |
| 184 | +_accelerate_available, _accelerate_version = _is_package_available("accelerate") |
| 185 | +_xformers_available, _xformers_version = _is_package_available("xformers") |
| 186 | +_gguf_available, _gguf_version = _is_package_available("gguf") |
| 187 | +_torchao_available, _torchao_version = _is_package_available("torchao") |
| 188 | +_bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes") |
| 189 | + |
| 190 | + |
| 191 | +_optimum_quanto_available = importlib.util.find_spec("optimum") is not None |
| 192 | +if _optimum_quanto_available: |
370 | 193 | try: |
371 | 194 | _optimum_quanto_version = importlib_metadata.version("optimum_quanto") |
372 | 195 | logger.debug(f"Successfully import optimum-quanto version {_optimum_quanto_version}") |
373 | 196 | except importlib_metadata.PackageNotFoundError: |
374 | | - _is_optimum_quanto_available = False |
| 197 | + _optimum_quanto_available = False |
375 | 198 |
376 | 199 |
377 | 200 | def is_torch_available(): |
@@ -495,15 +318,19 @@ def is_imageio_available(): |
495 | 318 |
496 | 319 |
497 | 320 | def is_gguf_available(): |
498 | | - return _is_gguf_available |
| 321 | + return _gguf_available |
499 | 322 |
500 | 323 |
501 | 324 | def is_torchao_available(): |
502 | | - return _is_torchao_available |
| 325 | + return _torchao_available |
503 | 326 |
504 | 327 |
505 | 328 | def is_optimum_quanto_available(): |
506 | | - return _is_optimum_quanto_available |
| 329 | + return _optimum_quanto_available |
| 330 | + |
| 331 | + |
| 332 | +def is_timm_available(): |
| 333 | + return _timm_available |
507 | 334 |
508 | 335 |
509 | 336 | # docstyle-ignore |
@@ -863,7 +690,7 @@ def is_gguf_version(operation: str, version: str): |
863 | 690 | version (`str`): |
864 | 691 | A version string |
865 | 692 | """ |
866 | | - if not _is_gguf_available: |
| 693 | + if not _gguf_available: |
867 | 694 | return False |
868 | 695 | return compare_versions(parse(_gguf_version), operation, version) |
869 | 696 |
|
@@ -893,7 +720,7 @@ def is_optimum_quanto_version(operation: str, version: str): |
893 | 720 | version (`str`): |
894 | 721 | A version string |
895 | 722 | """ |
896 | | - if not _is_optimum_quanto_available: |
| 723 | + if not _optimum_quanto_available: |
897 | 724 | return False |
898 | 725 | return compare_versions(parse(_optimum_quanto_version), operation, version) |
899 | 726 |