Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import os
import sys
from collections import OrderedDict, defaultdict
from functools import lru_cache as cache
from itertools import chain
from types import ModuleType
from typing import Any, Tuple, Union
Expand Down Expand Up @@ -673,6 +674,7 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re


# This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
@cache
def is_torch_version(operation: str, version: str):
"""
Compares the current PyTorch version to a given reference with an operation.
Expand All @@ -686,6 +688,7 @@ def is_torch_version(operation: str, version: str):
return compare_versions(parse(_torch_version), operation, version)


@cache
def is_torch_xla_version(operation: str, version: str):
"""
Compares the current torch_xla version to a given reference with an operation.
Expand All @@ -701,6 +704,7 @@ def is_torch_xla_version(operation: str, version: str):
return compare_versions(parse(_torch_xla_version), operation, version)


@cache
def is_transformers_version(operation: str, version: str):
"""
Compares the current Transformers version to a given reference with an operation.
Expand All @@ -716,6 +720,7 @@ def is_transformers_version(operation: str, version: str):
return compare_versions(parse(_transformers_version), operation, version)


@cache
def is_hf_hub_version(operation: str, version: str):
"""
Compares the current Hugging Face Hub version to a given reference with an operation.
Expand All @@ -731,6 +736,7 @@ def is_hf_hub_version(operation: str, version: str):
return compare_versions(parse(_hf_hub_version), operation, version)


@cache
def is_accelerate_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
Expand All @@ -746,6 +752,7 @@ def is_accelerate_version(operation: str, version: str):
return compare_versions(parse(_accelerate_version), operation, version)


@cache
def is_peft_version(operation: str, version: str):
"""
Compares the current PEFT version to a given reference with an operation.
Expand All @@ -761,6 +768,7 @@ def is_peft_version(operation: str, version: str):
return compare_versions(parse(_peft_version), operation, version)


@cache
def is_bitsandbytes_version(operation: str, version: str):
"""
Args:
Expand All @@ -775,6 +783,7 @@ def is_bitsandbytes_version(operation: str, version: str):
return compare_versions(parse(_bitsandbytes_version), operation, version)


@cache
def is_gguf_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
Expand All @@ -790,6 +799,7 @@ def is_gguf_version(operation: str, version: str):
return compare_versions(parse(_gguf_version), operation, version)


@cache
def is_torchao_version(operation: str, version: str):
"""
Compares the current torchao version to a given reference with an operation.
Expand All @@ -805,6 +815,7 @@ def is_torchao_version(operation: str, version: str):
return compare_versions(parse(_torchao_version), operation, version)


@cache
def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.
Expand All @@ -820,6 +831,7 @@ def is_k_diffusion_version(operation: str, version: str):
return compare_versions(parse(_k_diffusion_version), operation, version)


@cache
def is_optimum_quanto_version(operation: str, version: str):
"""
Compares the current Accelerate version to a given reference with an operation.
Expand All @@ -835,6 +847,7 @@ def is_optimum_quanto_version(operation: str, version: str):
return compare_versions(parse(_optimum_quanto_version), operation, version)


@cache
def is_nvidia_modelopt_version(operation: str, version: str):
"""
Compares the current Nvidia ModelOpt version to a given reference with an operation.
Expand All @@ -850,6 +863,7 @@ def is_nvidia_modelopt_version(operation: str, version: str):
return compare_versions(parse(_nvidia_modelopt_version), operation, version)


@cache
def is_xformers_version(operation: str, version: str):
"""
Compares the current xformers version to a given reference with an operation.
Expand All @@ -865,6 +879,7 @@ def is_xformers_version(operation: str, version: str):
return compare_versions(parse(_xformers_version), operation, version)


@cache
def is_sageattention_version(operation: str, version: str):
"""
Compares the current sageattention version to a given reference with an operation.
Expand All @@ -880,6 +895,7 @@ def is_sageattention_version(operation: str, version: str):
return compare_versions(parse(_sageattention_version), operation, version)


@cache
def is_flash_attn_version(operation: str, version: str):
"""
Compares the current flash-attention version to a given reference with an operation.
Expand Down
Loading