diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 9399ccd2a7a3..97065267b004 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -21,6 +21,7 @@ import os import sys from collections import OrderedDict, defaultdict +from functools import lru_cache as cache from itertools import chain from types import ModuleType from typing import Any, Tuple, Union @@ -673,6 +674,7 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338 +@cache def is_torch_version(operation: str, version: str): """ Compares the current PyTorch version to a given reference with an operation. @@ -686,6 +688,7 @@ def is_torch_version(operation: str, version: str): return compare_versions(parse(_torch_version), operation, version) +@cache def is_torch_xla_version(operation: str, version: str): """ Compares the current torch_xla version to a given reference with an operation. @@ -701,6 +704,7 @@ def is_torch_xla_version(operation: str, version: str): return compare_versions(parse(_torch_xla_version), operation, version) +@cache def is_transformers_version(operation: str, version: str): """ Compares the current Transformers version to a given reference with an operation. @@ -716,6 +720,7 @@ def is_transformers_version(operation: str, version: str): return compare_versions(parse(_transformers_version), operation, version) +@cache def is_hf_hub_version(operation: str, version: str): """ Compares the current Hugging Face Hub version to a given reference with an operation. @@ -731,6 +736,7 @@ def is_hf_hub_version(operation: str, version: str): return compare_versions(parse(_hf_hub_version), operation, version) +@cache def is_accelerate_version(operation: str, version: str): """ Compares the current Accelerate version to a given reference with an operation. @@ -746,6 +752,7 @@ def is_accelerate_version(operation: str, version: str): return compare_versions(parse(_accelerate_version), operation, version) +@cache def is_peft_version(operation: str, version: str): """ Compares the current PEFT version to a given reference with an operation. @@ -761,6 +768,7 @@ def is_peft_version(operation: str, version: str): return compare_versions(parse(_peft_version), operation, version) +@cache def is_bitsandbytes_version(operation: str, version: str): """ Args: @@ -775,6 +783,7 @@ def is_bitsandbytes_version(operation: str, version: str): return compare_versions(parse(_bitsandbytes_version), operation, version) +@cache def is_gguf_version(operation: str, version: str): """ Compares the current Accelerate version to a given reference with an operation. @@ -790,6 +799,7 @@ def is_gguf_version(operation: str, version: str): return compare_versions(parse(_gguf_version), operation, version) +@cache def is_torchao_version(operation: str, version: str): """ Compares the current torchao version to a given reference with an operation. @@ -805,6 +815,7 @@ def is_torchao_version(operation: str, version: str): return compare_versions(parse(_torchao_version), operation, version) +@cache def is_k_diffusion_version(operation: str, version: str): """ Compares the current k-diffusion version to a given reference with an operation. @@ -820,6 +831,7 @@ def is_k_diffusion_version(operation: str, version: str): return compare_versions(parse(_k_diffusion_version), operation, version) +@cache def is_optimum_quanto_version(operation: str, version: str): """ Compares the current Accelerate version to a given reference with an operation. @@ -835,6 +847,7 @@ def is_optimum_quanto_version(operation: str, version: str): return compare_versions(parse(_optimum_quanto_version), operation, version) +@cache def is_nvidia_modelopt_version(operation: str, version: str): """ Compares the current Nvidia ModelOpt version to a given reference with an operation. @@ -850,6 +863,7 @@ def is_nvidia_modelopt_version(operation: str, version: str): return compare_versions(parse(_nvidia_modelopt_version), operation, version) +@cache def is_xformers_version(operation: str, version: str): """ Compares the current xformers version to a given reference with an operation. @@ -865,6 +879,7 @@ def is_xformers_version(operation: str, version: str): return compare_versions(parse(_xformers_version), operation, version) +@cache def is_sageattention_version(operation: str, version: str): """ Compares the current sageattention version to a given reference with an operation. @@ -880,6 +895,7 @@ def is_sageattention_version(operation: str, version: str): return compare_versions(parse(_sageattention_version), operation, version) +@cache def is_flash_attn_version(operation: str, version: str): """ Compares the current flash-attention version to a given reference with an operation.