Skip to content

Commit cf4b97b

Browse files
authored
[perf] Cache version checks (huggingface#12399)
1 parent 7f3e9b8 commit cf4b97b

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

src/diffusers/utils/import_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import os
2222
import sys
2323
from collections import OrderedDict, defaultdict
24+
from functools import lru_cache as cache
2425
from itertools import chain
2526
from types import ModuleType
2627
from typing import Any, Tuple, Union
@@ -673,6 +674,7 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
673674

674675

675676
# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
677+
@cache
676678
def is_torch_version(operation: str, version: str):
677679
"""
678680
Compares the current PyTorch version to a given reference with an operation.
@@ -686,6 +688,7 @@ def is_torch_version(operation: str, version: str):
686688
return compare_versions(parse(_torch_version), operation, version)
687689

688690

691+
@cache
689692
def is_torch_xla_version(operation: str, version: str):
690693
"""
691694
Compares the current torch_xla version to a given reference with an operation.
@@ -701,6 +704,7 @@ def is_torch_xla_version(operation: str, version: str):
701704
return compare_versions(parse(_torch_xla_version), operation, version)
702705

703706

707+
@cache
704708
def is_transformers_version(operation: str, version: str):
705709
"""
706710
Compares the current Transformers version to a given reference with an operation.
@@ -716,6 +720,7 @@ def is_transformers_version(operation: str, version: str):
716720
return compare_versions(parse(_transformers_version), operation, version)
717721

718722

723+
@cache
719724
def is_hf_hub_version(operation: str, version: str):
720725
"""
721726
Compares the current Hugging Face Hub version to a given reference with an operation.
@@ -731,6 +736,7 @@ def is_hf_hub_version(operation: str, version: str):
731736
return compare_versions(parse(_hf_hub_version), operation, version)
732737

733738

739+
@cache
734740
def is_accelerate_version(operation: str, version: str):
735741
"""
736742
Compares the current Accelerate version to a given reference with an operation.
@@ -746,6 +752,7 @@ def is_accelerate_version(operation: str, version: str):
746752
return compare_versions(parse(_accelerate_version), operation, version)
747753

748754

755+
@cache
749756
def is_peft_version(operation: str, version: str):
750757
"""
751758
Compares the current PEFT version to a given reference with an operation.
@@ -761,6 +768,7 @@ def is_peft_version(operation: str, version: str):
761768
return compare_versions(parse(_peft_version), operation, version)
762769

763770

771+
@cache
764772
def is_bitsandbytes_version(operation: str, version: str):
765773
"""
766774
Args:
@@ -775,6 +783,7 @@ def is_bitsandbytes_version(operation: str, version: str):
775783
return compare_versions(parse(_bitsandbytes_version), operation, version)
776784

777785

786+
@cache
778787
def is_gguf_version(operation: str, version: str):
779788
"""
780789
Compares the current Accelerate version to a given reference with an operation.
@@ -790,6 +799,7 @@ def is_gguf_version(operation: str, version: str):
790799
return compare_versions(parse(_gguf_version), operation, version)
791800

792801

802+
@cache
793803
def is_torchao_version(operation: str, version: str):
794804
"""
795805
Compares the current torchao version to a given reference with an operation.
@@ -805,6 +815,7 @@ def is_torchao_version(operation: str, version: str):
805815
return compare_versions(parse(_torchao_version), operation, version)
806816

807817

818+
@cache
808819
def is_k_diffusion_version(operation: str, version: str):
809820
"""
810821
Compares the current k-diffusion version to a given reference with an operation.
@@ -820,6 +831,7 @@ def is_k_diffusion_version(operation: str, version: str):
820831
return compare_versions(parse(_k_diffusion_version), operation, version)
821832

822833

834+
@cache
823835
def is_optimum_quanto_version(operation: str, version: str):
824836
"""
825837
Compares the current Accelerate version to a given reference with an operation.
@@ -835,6 +847,7 @@ def is_optimum_quanto_version(operation: str, version: str):
835847
return compare_versions(parse(_optimum_quanto_version), operation, version)
836848

837849

850+
@cache
838851
def is_nvidia_modelopt_version(operation: str, version: str):
839852
"""
840853
Compares the current Nvidia ModelOpt version to a given reference with an operation.
@@ -850,6 +863,7 @@ def is_nvidia_modelopt_version(operation: str, version: str):
850863
return compare_versions(parse(_nvidia_modelopt_version), operation, version)
851864

852865

866+
@cache
853867
def is_xformers_version(operation: str, version: str):
854868
"""
855869
Compares the current xformers version to a given reference with an operation.
@@ -865,6 +879,7 @@ def is_xformers_version(operation: str, version: str):
865879
return compare_versions(parse(_xformers_version), operation, version)
866880

867881

882+
@cache
868883
def is_sageattention_version(operation: str, version: str):
869884
"""
870885
Compares the current sageattention version to a given reference with an operation.
@@ -880,6 +895,7 @@ def is_sageattention_version(operation: str, version: str):
880895
return compare_versions(parse(_sageattention_version), operation, version)
881896

882897

898+
@cache
883899
def is_flash_attn_version(operation: str, version: str):
884900
"""
885901
Compares the current flash-attention version to a given reference with an operation.

0 commit comments

Comments
 (0)