Skip to content

Commit 7c2a58f

Browse files
Move accelerate to a soft-dependency (#1134)
* finish * finish * Update src/diffusers/modeling_utils.py * Update src/diffusers/pipeline_utils.py Co-authored-by: Anton Lozhkov <[email protected]> * more fixes * fix Co-authored-by: Anton Lozhkov <[email protected]>
1 parent bde4880 commit 7c2a58f

File tree

8 files changed

+82
-482
lines changed

8 files changed

+82
-482
lines changed

src/diffusers/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from .utils import (
2-
is_accelerate_available,
32
is_flax_available,
43
is_inflect_available,
54
is_onnx_available,
@@ -17,13 +16,6 @@
1716
from .utils import logging
1817

1918

20-
# This will create an extra dummy file "dummy_torch_and_accelerate_objects.py"
21-
# TODO: (patil-suraj, anton-l) maybe import everything under is_torch_and_accelerate_available
22-
if is_torch_available() and not is_accelerate_available():
23-
error_msg = "Please install the `accelerate` library to use Diffusers with PyTorch. You can do so by running `pip install diffusers[torch]`. Or if torch is already installed, you can run `pip install accelerate`." # noqa: E501
24-
raise ImportError(error_msg)
25-
26-
2719
if is_torch_available():
2820
from .modeling_utils import ModelMixin
2921
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel

src/diffusers/modeling_utils.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,20 @@
2121
import torch
2222
from torch import Tensor, device
2323

24-
import accelerate
25-
from accelerate.utils import set_module_tensor_to_device
26-
from accelerate.utils.versions import is_torch_version
2724
from huggingface_hub import hf_hub_download
2825
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
2926
from requests import HTTPError
3027

3128
from . import __version__
32-
from .utils import CONFIG_NAME, DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, WEIGHTS_NAME, logging
29+
from .utils import (
30+
CONFIG_NAME,
31+
DIFFUSERS_CACHE,
32+
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
33+
WEIGHTS_NAME,
34+
is_accelerate_available,
35+
is_torch_version,
36+
logging,
37+
)
3338

3439

3540
logger = logging.get_logger(__name__)
@@ -41,6 +46,12 @@
4146
_LOW_CPU_MEM_USAGE_DEFAULT = False
4247

4348

49+
if is_accelerate_available():
50+
import accelerate
51+
from accelerate.utils import set_module_tensor_to_device
52+
from accelerate.utils.versions import is_torch_version
53+
54+
4455
def get_parameter_device(parameter: torch.nn.Module):
4556
try:
4657
return next(parameter.parameters()).device
@@ -319,6 +330,21 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
319330
device_map = kwargs.pop("device_map", None)
320331
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
321332

333+
if low_cpu_mem_usage and not is_accelerate_available():
334+
low_cpu_mem_usage = False
335+
logger.warn(
336+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
337+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
338+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
339+
" install accelerate\n```\n."
340+
)
341+
342+
if device_map is not None and not is_accelerate_available():
343+
raise NotImplementedError(
344+
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
345+
" `device_map=None`. You can install accelerate with `pip install accelerate`."
346+
)
347+
322348
# Check if we can handle device_map and dispatching the weights
323349
if device_map is not None and not is_torch_version(">=", "1.9.0"):
324350
raise NotImplementedError(

src/diffusers/pipeline_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import diffusers
2727
import PIL
28-
from accelerate.utils.versions import is_torch_version
2928
from huggingface_hub import snapshot_download
3029
from packaging import version
3130
from PIL import Image
@@ -43,6 +42,8 @@
4342
WEIGHTS_NAME,
4443
BaseOutput,
4544
deprecate,
45+
is_accelerate_available,
46+
is_torch_version,
4647
is_transformers_available,
4748
logging,
4849
)
@@ -397,6 +398,15 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
397398
device_map = kwargs.pop("device_map", None)
398399
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
399400

401+
if low_cpu_mem_usage and not is_accelerate_available():
402+
low_cpu_mem_usage = False
403+
logger.warn(
404+
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
405+
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
406+
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
407+
" install accelerate\n```\n."
408+
)
409+
400410
if device_map is not None and not is_torch_version(">=", "1.9.0"):
401411
raise NotImplementedError(
402412
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
is_scipy_available,
3232
is_tf_available,
3333
is_torch_available,
34+
is_torch_version,
3435
is_transformers_available,
3536
is_unidecode_available,
3637
requires_backends,

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -272,21 +272,6 @@ def from_pretrained(cls, *args, **kwargs):
272272
requires_backends(cls, ["torch"])
273273

274274

275-
class VQDiffusionPipeline(metaclass=DummyObject):
276-
_backends = ["torch"]
277-
278-
def __init__(self, *args, **kwargs):
279-
requires_backends(self, ["torch"])
280-
281-
@classmethod
282-
def from_config(cls, *args, **kwargs):
283-
requires_backends(cls, ["torch"])
284-
285-
@classmethod
286-
def from_pretrained(cls, *args, **kwargs):
287-
requires_backends(cls, ["torch"])
288-
289-
290275
class DDIMScheduler(metaclass=DummyObject):
291276
_backends = ["torch"]
292277

0 commit comments

Comments
 (0)