Skip to content

Commit 35d8186

Browse files
[Bad dependencies] Fix imports (#1382)
* fix imports * better error * up * finish
1 parent 1524122 commit 35d8186

File tree

4 files changed

+48
-4
lines changed

4 files changed

+48
-4
lines changed

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66
import PIL
77
from PIL import Image
88

9-
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_torch_available, is_transformers_available
9+
from ...utils import (
10+
BaseOutput,
11+
is_flax_available,
12+
is_onnx_available,
13+
is_torch_available,
14+
is_transformers_available,
15+
is_transformers_version,
16+
)
1017

1118

1219
@dataclass
@@ -30,12 +37,16 @@ class StableDiffusionPipelineOutput(BaseOutput):
3037
if is_transformers_available() and is_torch_available():
3138
from .pipeline_cycle_diffusion import CycleDiffusionPipeline
3239
from .pipeline_stable_diffusion import StableDiffusionPipeline
33-
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
3440
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
3541
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
3642
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
3743
from .safety_checker import StableDiffusionSafetyChecker
3844

45+
if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0"):
46+
from .pipeline_stable_diffusion_image_variation import StableDiffusionImageVariationPipeline
47+
else:
48+
from ...utils.dummy_torch_and_transformers_objects import StableDiffusionImageVariationPipeline
49+
3950
if is_transformers_available() and is_onnx_available():
4051
from .pipeline_onnx_stable_diffusion import OnnxStableDiffusionPipeline, StableDiffusionOnnxPipeline
4152
from .pipeline_onnx_stable_diffusion_img2img import OnnxStableDiffusionImg2ImgPipeline
Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1-
from ...utils import is_torch_available, is_transformers_available
1+
from ...utils import is_torch_available, is_transformers_available, is_transformers_version
22

33

4-
if is_transformers_available() and is_torch_available():
4+
if is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.25.0"):
55
from .modeling_text_unet import UNetFlatConditionModel
66
from .pipeline_versatile_diffusion import VersatileDiffusionPipeline
77
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
88
from .pipeline_versatile_diffusion_image_variation import VersatileDiffusionImageVariationPipeline
99
from .pipeline_versatile_diffusion_text_to_image import VersatileDiffusionTextToImagePipeline
10+
else:
11+
from ...utils.dummy_torch_and_transformers_objects import (
12+
VersatileDiffusionDualGuidedPipeline,
13+
VersatileDiffusionImageVariationPipeline,
14+
VersatileDiffusionPipeline,
15+
VersatileDiffusionTextToImagePipeline,
16+
)

src/diffusers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
is_torch_available,
3434
is_torch_version,
3535
is_transformers_available,
36+
is_transformers_version,
3637
is_unidecode_available,
3738
requires_backends,
3839
)

src/diffusers/utils/import_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,17 @@ def requires_backends(obj, backends):
303303
if failed:
304304
raise ImportError("".join(failed))
305305

306+
if name in [
307+
"VersatileDiffusionTextToImagePipeline",
308+
"VersatileDiffusionPipeline",
309+
"VersatileDiffusionDualGuidedPipeline",
310+
"StableDiffusionImageVariationPipeline",
311+
] and is_transformers_version("<", "4.25.0"):
312+
raise ImportError(
313+
f"You need to install `transformers` from 'main' in order to use {name}: \n```\n pip install"
314+
" git+https://github.com/huggingface/transformers \n```"
315+
)
316+
306317

307318
class DummyObject(type):
308319
"""
@@ -347,3 +358,17 @@ def is_torch_version(operation: str, version: str):
347358
A string version of PyTorch
348359
"""
349360
return compare_versions(parse(_torch_version), operation, version)
361+
362+
363+
def is_transformers_version(operation: str, version: str):
364+
"""
365+
Args:
366+
Compares the current Transformers version to a given reference with an operation.
367+
operation (`str`):
368+
A string representation of an operator, such as `">"` or `"<="`
369+
version (`str`):
370+
A string version of PyTorch
371+
"""
372+
if not _transformers_available:
373+
return False
374+
return compare_versions(parse(_transformers_version), operation, version)

0 commit comments

Comments
 (0)