2121import os
2222import sys
2323from collections import OrderedDict , defaultdict
24+ from functools import lru_cache as cache
2425from itertools import chain
2526from types import ModuleType
2627from typing import Any , Tuple , Union
@@ -673,6 +674,7 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
673674
674675
675676# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
677+ @cache
676678def is_torch_version (operation : str , version : str ):
677679 """
678680 Compares the current PyTorch version to a given reference with an operation.
@@ -686,6 +688,7 @@ def is_torch_version(operation: str, version: str):
686688 return compare_versions (parse (_torch_version ), operation , version )
687689
688690
691+ @cache
689692def is_torch_xla_version (operation : str , version : str ):
690693 """
691694 Compares the current torch_xla version to a given reference with an operation.
@@ -701,6 +704,7 @@ def is_torch_xla_version(operation: str, version: str):
701704 return compare_versions (parse (_torch_xla_version ), operation , version )
702705
703706
707+ @cache
704708def is_transformers_version (operation : str , version : str ):
705709 """
706710 Compares the current Transformers version to a given reference with an operation.
@@ -716,6 +720,7 @@ def is_transformers_version(operation: str, version: str):
716720 return compare_versions (parse (_transformers_version ), operation , version )
717721
718722
723+ @cache
719724def is_hf_hub_version (operation : str , version : str ):
720725 """
721726 Compares the current Hugging Face Hub version to a given reference with an operation.
@@ -731,6 +736,7 @@ def is_hf_hub_version(operation: str, version: str):
731736 return compare_versions (parse (_hf_hub_version ), operation , version )
732737
733738
739+ @cache
734740def is_accelerate_version (operation : str , version : str ):
735741 """
736742 Compares the current Accelerate version to a given reference with an operation.
@@ -746,6 +752,7 @@ def is_accelerate_version(operation: str, version: str):
746752 return compare_versions (parse (_accelerate_version ), operation , version )
747753
748754
755+ @cache
749756def is_peft_version (operation : str , version : str ):
750757 """
751758 Compares the current PEFT version to a given reference with an operation.
@@ -761,6 +768,7 @@ def is_peft_version(operation: str, version: str):
761768 return compare_versions (parse (_peft_version ), operation , version )
762769
763770
771+ @cache
764772def is_bitsandbytes_version (operation : str , version : str ):
765773 """
766774 Args:
@@ -775,6 +783,7 @@ def is_bitsandbytes_version(operation: str, version: str):
775783 return compare_versions (parse (_bitsandbytes_version ), operation , version )
776784
777785
786+ @cache
778787def is_gguf_version (operation : str , version : str ):
779788 """
780789 Compares the current Accelerate version to a given reference with an operation.
@@ -790,6 +799,7 @@ def is_gguf_version(operation: str, version: str):
790799 return compare_versions (parse (_gguf_version ), operation , version )
791800
792801
802+ @cache
793803def is_torchao_version (operation : str , version : str ):
794804 """
795805 Compares the current torchao version to a given reference with an operation.
@@ -805,6 +815,7 @@ def is_torchao_version(operation: str, version: str):
805815 return compare_versions (parse (_torchao_version ), operation , version )
806816
807817
818+ @cache
808819def is_k_diffusion_version (operation : str , version : str ):
809820 """
810821 Compares the current k-diffusion version to a given reference with an operation.
@@ -820,6 +831,7 @@ def is_k_diffusion_version(operation: str, version: str):
820831 return compare_versions (parse (_k_diffusion_version ), operation , version )
821832
822833
834+ @cache
823835def is_optimum_quanto_version (operation : str , version : str ):
824836 """
825837 Compares the current Accelerate version to a given reference with an operation.
@@ -835,6 +847,7 @@ def is_optimum_quanto_version(operation: str, version: str):
835847 return compare_versions (parse (_optimum_quanto_version ), operation , version )
836848
837849
850+ @cache
838851def is_nvidia_modelopt_version (operation : str , version : str ):
839852 """
840853 Compares the current Nvidia ModelOpt version to a given reference with an operation.
@@ -850,6 +863,7 @@ def is_nvidia_modelopt_version(operation: str, version: str):
850863 return compare_versions (parse (_nvidia_modelopt_version ), operation , version )
851864
852865
866+ @cache
853867def is_xformers_version (operation : str , version : str ):
854868 """
855869 Compares the current xformers version to a given reference with an operation.
@@ -865,6 +879,7 @@ def is_xformers_version(operation: str, version: str):
865879 return compare_versions (parse (_xformers_version ), operation , version )
866880
867881
882+ @cache
868883def is_sageattention_version (operation : str , version : str ):
869884 """
870885 Compares the current sageattention version to a given reference with an operation.
@@ -880,6 +895,7 @@ def is_sageattention_version(operation: str, version: str):
880895 return compare_versions (parse (_sageattention_version ), operation , version )
881896
882897
898+ @cache
883899def is_flash_attn_version (operation : str , version : str ):
884900 """
885901 Compares the current flash-attention version to a given reference with an operation.
0 commit comments