Skip to content

Commit 8094f66

Browse files
committed
up
1 parent 15e3a0f commit 8094f66

11 files changed

+163
-182
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 132 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@
4444

4545
from .. import __version__
4646
from ..configuration_utils import ConfigMixin
47+
from ..models import AutoencoderKL
48+
from ..models.attention_processor import FusedAttnProcessor2_0
4749
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
4850
from ..quantizers import PipelineQuantizationConfig
4951
from ..quantizers.bitsandbytes.utils import _check_bnb_status
@@ -2171,13 +2173,136 @@ def _maybe_raise_error_if_group_offload_active(
21712173

21722174

21732175
class StableDiffusionMixin:
2174-
def __init__(self, *args, **kwargs):
2175-
deprecation_message = "`StableDiffusionMixin` from `diffusers.pipelines.pipeline_utils` is deprecated and this will be removed in a future version. Please use `StableDiffusionMixin` from `diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils`, instead."
2176-
deprecate("StableDiffusionMixin", "1.0.0", deprecation_message)
2176+
r"""
2177+
Helper for DiffusionPipeline with vae and unet (mainly for LDMs such as Stable Diffusion).
2178+
"""
21772179

2178-
# To avoid circular imports and for being backwards-compatible.
2179-
from .stable_diffusion.pipeline_stable_diffusion_utils import (
2180-
StableDiffusionMixin as ActualStableDiffusionMixin,
2180+
def enable_vae_slicing(self):
2181+
r"""
2182+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
2183+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
2184+
"""
2185+
depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
2186+
deprecate(
2187+
"enable_vae_slicing",
2188+
"0.40.0",
2189+
depr_message,
21812190
)
2191+
self.vae.enable_slicing()
21822192

2183-
ActualStableDiffusionMixin.__init__(self, *args, **kwargs)
2193+
def disable_vae_slicing(self):
2194+
r"""
2195+
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
2196+
computing decoding in one step.
2197+
"""
2198+
depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
2199+
deprecate(
2200+
"disable_vae_slicing",
2201+
"0.40.0",
2202+
depr_message,
2203+
)
2204+
self.vae.disable_slicing()
2205+
2206+
def enable_vae_tiling(self):
2207+
r"""
2208+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
2209+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
2210+
processing larger images.
2211+
"""
2212+
depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
2213+
deprecate(
2214+
"enable_vae_tiling",
2215+
"0.40.0",
2216+
depr_message,
2217+
)
2218+
self.vae.enable_tiling()
2219+
2220+
def disable_vae_tiling(self):
2221+
r"""
2222+
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
2223+
computing decoding in one step.
2224+
"""
2225+
depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
2226+
deprecate(
2227+
"disable_vae_tiling",
2228+
"0.40.0",
2229+
depr_message,
2230+
)
2231+
self.vae.disable_tiling()
2232+
2233+
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
2234+
r"""Enables the FreeU mechanism as in https://huggingface.co/papers/2309.11497.
2235+
2236+
The suffixes after the scaling factors represent the stages where they are being applied.
2237+
2238+
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
2239+
that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
2240+
2241+
Args:
2242+
s1 (`float`):
2243+
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
2244+
mitigate "oversmoothing effect" in the enhanced denoising process.
2245+
s2 (`float`):
2246+
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
2247+
mitigate "oversmoothing effect" in the enhanced denoising process.
2248+
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
2249+
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
2250+
"""
2251+
if not hasattr(self, "unet"):
2252+
raise ValueError("The pipeline must have `unet` for using FreeU.")
2253+
self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
2254+
2255+
def disable_freeu(self):
2256+
"""Disables the FreeU mechanism if enabled."""
2257+
self.unet.disable_freeu()
2258+
2259+
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
2260+
"""
2261+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
2262+
are fused. For cross-attention modules, key and value projection matrices are fused.
2263+
2264+
> [!WARNING] > This API is 🧪 experimental.
2265+
2266+
Args:
2267+
unet (`bool`, defaults to `True`): To apply fusion on the UNet.
2268+
vae (`bool`, defaults to `True`): To apply fusion on the VAE.
2269+
"""
2270+
self.fusing_unet = False
2271+
self.fusing_vae = False
2272+
2273+
if unet:
2274+
self.fusing_unet = True
2275+
self.unet.fuse_qkv_projections()
2276+
self.unet.set_attn_processor(FusedAttnProcessor2_0())
2277+
2278+
if vae:
2279+
if not isinstance(self.vae, AutoencoderKL):
2280+
raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
2281+
2282+
self.fusing_vae = True
2283+
self.vae.fuse_qkv_projections()
2284+
self.vae.set_attn_processor(FusedAttnProcessor2_0())
2285+
2286+
def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
2287+
"""Disable QKV projection fusion if enabled.
2288+
2289+
> [!WARNING] > This API is 🧪 experimental.
2290+
2291+
Args:
2292+
unet (`bool`, defaults to `True`): To undo fusion on the UNet.
2293+
vae (`bool`, defaults to `True`): To undo fusion on the VAE.
2294+
2295+
"""
2296+
if unet:
2297+
if not self.fusing_unet:
2298+
logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
2299+
else:
2300+
self.unet.unfuse_qkv_projections()
2301+
self.fusing_unet = False
2302+
2303+
if vae:
2304+
if not self.fusing_vae:
2305+
logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
2306+
else:
2307+
self.vae.unfuse_qkv_projections()
2308+
self.fusing_vae = False

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
)
3535
from ...utils.torch_utils import randn_tensor
3636
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
37-
from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents
37+
from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents
3838

3939

4040
if is_torch_xla_available():
@@ -72,7 +72,7 @@ def preprocess(image):
7272

7373

7474
class StableDiffusionDepth2ImgPipeline(
75-
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
75+
DiffusionPipeline, SDMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
7676
):
7777
r"""
7878
Pipeline for text-guided depth-based image-to-image generation using Stable Diffusion.

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525
from ...schedulers import KarrasDiffusionSchedulers
2626
from ...utils import deprecate, is_torch_xla_available, logging
2727
from ...utils.torch_utils import randn_tensor
28-
from ..pipeline_utils import DiffusionPipeline
28+
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
2929
from . import StableDiffusionPipelineOutput
30-
from .pipeline_stable_diffusion_utils import StableDiffusionMixin
30+
from .pipeline_stable_diffusion_utils import SDMixin
3131
from .safety_checker import StableDiffusionSafetyChecker
3232

3333

@@ -41,7 +41,7 @@
4141
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4242

4343

44-
class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMixin):
44+
class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMixin, SDMixin):
4545
r"""
4646
Pipeline to generate image variations from an input image using Stable Diffusion.
4747

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434
replace_example_docstring,
3535
)
3636
from ...utils.torch_utils import randn_tensor
37-
from ..pipeline_utils import DiffusionPipeline
37+
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
3838
from . import StableDiffusionPipelineOutput
39-
from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents, retrieve_timesteps
39+
from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents, retrieve_timesteps
4040
from .safety_checker import StableDiffusionSafetyChecker
4141

4242

@@ -105,6 +105,7 @@ def preprocess(image):
105105
class StableDiffusionImg2ImgPipeline(
106106
DiffusionPipeline,
107107
StableDiffusionMixin,
108+
SDMixin,
108109
TextualInversionLoaderMixin,
109110
IPAdapterMixin,
110111
StableDiffusionLoraLoaderMixin,

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,11 @@
2525
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
2626
from ...models import AsymmetricAutoencoderKL, AutoencoderKL, ImageProjection, UNet2DConditionModel
2727
from ...schedulers import KarrasDiffusionSchedulers
28-
from ...utils import (
29-
deprecate,
30-
is_torch_xla_available,
31-
logging,
32-
)
28+
from ...utils import deprecate, is_torch_xla_available, logging
3329
from ...utils.torch_utils import randn_tensor
34-
from ..pipeline_utils import DiffusionPipeline
30+
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
3531
from . import StableDiffusionPipelineOutput
36-
from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents, retrieve_timesteps
32+
from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents, retrieve_timesteps
3733
from .safety_checker import StableDiffusionSafetyChecker
3834

3935

@@ -50,6 +46,7 @@
5046
class StableDiffusionInpaintPipeline(
5147
DiffusionPipeline,
5248
StableDiffusionMixin,
49+
SDMixin,
5350
TextualInversionLoaderMixin,
5451
IPAdapterMixin,
5552
StableDiffusionLoraLoaderMixin,

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@
2626
from ...schedulers import KarrasDiffusionSchedulers
2727
from ...utils import PIL_INTERPOLATION, deprecate, is_torch_xla_available, logging
2828
from ...utils.torch_utils import randn_tensor
29-
from ..pipeline_utils import DiffusionPipeline
29+
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
3030
from . import StableDiffusionPipelineOutput
31-
from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents
31+
from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents
3232
from .safety_checker import StableDiffusionSafetyChecker
3333

3434

@@ -69,6 +69,7 @@ def preprocess(image):
6969
class StableDiffusionInstructPix2PixPipeline(
7070
DiffusionPipeline,
7171
StableDiffusionMixin,
72+
SDMixin,
7273
TextualInversionLoaderMixin,
7374
StableDiffusionLoraLoaderMixin,
7475
IPAdapterMixin,

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
from ...schedulers import EulerDiscreteScheduler
2828
from ...utils import deprecate, is_torch_xla_available, logging
2929
from ...utils.torch_utils import randn_tensor
30-
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
31-
from .pipeline_stable_diffusion_utils import StableDiffusionMixin, retrieve_latents
30+
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput, StableDiffusionMixin
31+
from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents
3232

3333

3434
if is_torch_xla_available():
@@ -68,7 +68,7 @@ def preprocess(image):
6868
return image
6969

7070

71-
class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin, FromSingleFileMixin):
71+
class StableDiffusionLatentUpscalePipeline(DiffusionPipeline, StableDiffusionMixin, SDMixin, FromSingleFileMixin):
7272
r"""
7373
Pipeline for upscaling Stable Diffusion output image resolution by a factor of 2.
7474

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,11 @@
2525
from ...models import AutoencoderKL, UNet2DConditionModel
2626
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
2727
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
28-
from ...utils import (
29-
deprecate,
30-
is_torch_xla_available,
31-
logging,
32-
)
28+
from ...utils import deprecate, is_torch_xla_available, logging
3329
from ...utils.torch_utils import randn_tensor
34-
from ..pipeline_utils import DiffusionPipeline
30+
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
3531
from . import StableDiffusionPipelineOutput
36-
from .pipeline_stable_diffusion_utils import StableDiffusionMixin
32+
from .pipeline_stable_diffusion_utils import SDMixin
3733

3834

3935
if is_torch_xla_available():
@@ -75,6 +71,7 @@ def preprocess(image):
7571
class StableDiffusionUpscalePipeline(
7672
DiffusionPipeline,
7773
StableDiffusionMixin,
74+
SDMixin,
7875
TextualInversionLoaderMixin,
7976
StableDiffusionLoraLoaderMixin,
8077
FromSingleFileMixin,

0 commit comments

Comments
 (0)