 
 from .. import __version__
 from ..configuration_utils import ConfigMixin
+from ..models import AutoencoderKL
+from ..models.attention_processor import FusedAttnProcessor2_0
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
 from ..quantizers import PipelineQuantizationConfig
 from ..quantizers.bitsandbytes.utils import _check_bnb_status
@@ -2171,13 +2173,136 @@ def _maybe_raise_error_if_group_offload_active(
 
 
 class StableDiffusionMixin:
-    def __init__(self, *args, **kwargs):
-        deprecation_message = "`StableDiffusionMixin` from `diffusers.pipelines.pipeline_utils` is deprecated and this will be removed in a future version. Please use `StableDiffusionMixin` from `diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils`, instead."
-        deprecate("StableDiffusionMixin", "1.0.0", deprecation_message)
+    r"""
+    Helper for a DiffusionPipeline that has a VAE and a UNet (mainly latent diffusion models such as Stable Diffusion).
+    """
 
-        # To avoid circular imports and for being backwards-compatible.
-        from .stable_diffusion.pipeline_stable_diffusion_utils import (
-            StableDiffusionMixin as ActualStableDiffusionMixin,
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        depr_message = f"Calling `enable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_slicing()`."
+        deprecate(
+            "enable_vae_slicing",
+            "0.40.0",
+            depr_message,
         )
+        self.vae.enable_slicing()
 
-        ActualStableDiffusionMixin.__init__(self, *args, **kwargs)
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        depr_message = f"Calling `disable_vae_slicing()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_slicing()`."
+        deprecate(
+            "disable_vae_slicing",
+            "0.40.0",
+            depr_message,
+        )
+        self.vae.disable_slicing()
+
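# Editor's usage sketch (not part of the diff): the deprecation message above points to the
# VAE-level API, so slicing is now toggled directly on `pipe.vae`. The checkpoint id is only an
# illustrative example; any pipeline whose VAE supports slicing works the same way.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.vae.enable_slicing()   # decode the batch slice by slice to save memory
image = pipe("an astronaut riding a horse").images[0]
pipe.vae.disable_slicing()  # restore single-pass decoding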
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and for
+        processing larger images.
+        """
+        depr_message = f"Calling `enable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.enable_tiling()`."
+        deprecate(
+            "enable_vae_tiling",
+            "0.40.0",
+            depr_message,
+        )
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        depr_message = f"Calling `disable_vae_tiling()` on a `{self.__class__.__name__}` is deprecated and this method will be removed in a future version. Please use `pipe.vae.disable_tiling()`."
+        deprecate(
+            "disable_vae_tiling",
+            "0.40.0",
+            depr_message,
+        )
+        self.vae.disable_tiling()
+
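# Editor's usage sketch (not part of the diff): tiling follows the same pattern as slicing and
# is now toggled on the VAE itself, as the deprecation message recommends. Useful when decoding
# resolutions that would not fit in memory in one pass; the model id is an example only.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.vae.enable_tiling()    # encode/decode tile by tile for large images
image = pipe("a wide mountain panorama", height=768, width=1536).images[0]
pipe.vae.disable_tiling()   # back to whole-image decoding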
+    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+        r"""Enables the FreeU mechanism as in https://huggingface.co/papers/2309.11497.
+
+        The suffixes after the scaling factors represent the stages where they are being applied.
+
+        Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
+        that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+        Args:
+            s1 (`float`):
+                Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+            s2 (`float`):
+                Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+                mitigate the "oversmoothing effect" in the enhanced denoising process.
+            b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+            b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+        """
+        if not hasattr(self, "unet"):
+            raise ValueError("The pipeline must have `unet` for using FreeU.")
+        self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
+
+    def disable_freeu(self):
+        """Disables the FreeU mechanism if enabled."""
+        self.unet.disable_freeu()
+
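# Editor's usage sketch (not part of the diff): FreeU is enabled on the pipeline and forwarded to
# the UNet. The s1/s2/b1/b2 values below are the ones commonly cited for Stable Diffusion v1.5 in
# the FreeU repository linked in the docstring; treat them as a starting point, not a guarantee.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.5, b2=1.6)  # attenuate skip features, amplify backbone features
image = pipe("a cozy cabin in a snowy forest").images[0]
pipe.disable_freeu()  # remove the FreeU scaling again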
+    def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        > [!WARNING]
+        > This API is 🧪 experimental.
+
+        Args:
+            unet (`bool`, defaults to `True`): To apply fusion on the UNet.
+            vae (`bool`, defaults to `True`): To apply fusion on the VAE.
+        """
+        self.fusing_unet = False
+        self.fusing_vae = False
+
+        if unet:
+            self.fusing_unet = True
+            self.unet.fuse_qkv_projections()
+            self.unet.set_attn_processor(FusedAttnProcessor2_0())
+
+        if vae:
+            if not isinstance(self.vae, AutoencoderKL):
+                raise ValueError("`fuse_qkv_projections()` is only supported for the VAE of type `AutoencoderKL`.")
+
+            self.fusing_vae = True
+            self.vae.fuse_qkv_projections()
+            self.vae.set_attn_processor(FusedAttnProcessor2_0())
+
+    def unfuse_qkv_projections(self, unet: bool = True, vae: bool = True):
+        """Disable QKV projection fusion if enabled.
+
+        > [!WARNING]
+        > This API is 🧪 experimental.
+
+        Args:
+            unet (`bool`, defaults to `True`): To unfuse the QKV projections on the UNet.
+            vae (`bool`, defaults to `True`): To unfuse the QKV projections on the VAE.
+        """
+        if unet:
+            if not self.fusing_unet:
+                logger.warning("The UNet was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.unet.unfuse_qkv_projections()
+                self.fusing_unet = False
+
+        if vae:
+            if not self.fusing_vae:
+                logger.warning("The VAE was not initially fused for QKV projections. Doing nothing.")
+            else:
+                self.vae.unfuse_qkv_projections()
+                self.fusing_vae = False
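# Editor's usage sketch (not part of the diff): QKV fusion is experimental; fuse before running
# inference and unfuse afterwards to restore the original attention processors. Assumes the
# pipeline's VAE is an `AutoencoderKL`, as required by the `isinstance` check above.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.fuse_qkv_projections()    # fuse q/k/v for self-attention and k/v for cross-attention
image = pipe("a watercolor painting of a fox").images[0]
pipe.unfuse_qkv_projections()  # undo the fusion on both the UNet and the VAE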