From 2e1920ef5dc9ed875dad05fd3f46ed4bc9393e0a Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Tue, 14 Jan 2025 20:20:27 +0800
Subject: [PATCH 1/7] add pipeline_stable_diffusion_xl_attentive_eraser
---
 examples/community/README.md                  |   86 +
 ...ne_stable_diffusion_xl_attentive_eraser.py | 2249 +++++++++++++++++
 2 files changed, 2335 insertions(+)
 mode change 100755 => 100644 examples/community/README.md
 create mode 100644 examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
diff --git a/examples/community/README.md b/examples/community/README.md
old mode 100755
new mode 100644
index c7c40c46ef2d..077363f11a2a
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -77,6 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
 PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
 | HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
 | [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
+| Stable Diffusion XL Attentive Eraser Pipeline | [[AAAI2025] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models. | [Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline) | - | [Wenhao Sun](https://github.com/Anonym0u3) |
 
 To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
 
@@ -4634,6 +4635,91 @@ make_image_grid(image, rows=1, cols=len(image))
 # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
 ```
 
+### Stable Diffusion XL Attentive Eraser Pipeline
+
+**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. This pipeline uses **self-attention redirection guidance** to modify the model’s self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. For more details, please refer to the [paper](https://arxiv.org/abs/2412.12974) and the [official implementation](https://github.com/Anonym0u3/AttentiveEraser).
+
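+At its core, the method masks the self-attention logits so that queries can no longer attend to keys inside the removal mask, redirecting attention to the background. A minimal sketch of this masking step, simplified from the `AAS_XL.attn_batch` method added in this patch (`sim` is the pre-softmax attention score matrix, `mask_flatten` the flattened binary mask):
+
+```py
+# foreground keys (mask == 1) receive -inf, so the softmax redistributes
+# all attention mass onto background tokens
+sim_bg = sim + mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+attn = sim_bg.softmax(-1)
+```
+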
+#### Key Features
+
+- **Tuning-Free**: No additional training is required, making it easy to integrate and use.
+- **Flexible Mask Support**: Works with different types of masks for targeted object removal.
+- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.
+
+#### How to Use
+To use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:
+```py
+import torch
+from diffusers import DDIMScheduler, DiffusionPipeline
+from diffusers.utils import load_image
+import torch.nn.functional as F
+from torchvision.transforms.functional import to_tensor, gaussian_blur
+
+dtype = torch.float16
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+pipeline = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
+    scheduler=scheduler,
+    variant="fp16",
+    use_safetensors=True,
+    torch_dtype=dtype,
+).to(device)
+
+
+def preprocess_image(image_path, device):
+    image = to_tensor(load_image(image_path))
+    image = image.unsqueeze_(0).float() * 2 - 1  # [0,1] --> [-1,1]
+    if image.shape[1] != 3:  # expand grayscale to 3 channels
+        image = image.expand(-1, 3, -1, -1)
+    image = F.interpolate(image, (1024, 1024))
+    image = image.to(dtype).to(device)
+    return image
+
+def preprocess_mask(mask_path, device):
+    mask = to_tensor(load_image(mask_path, convert_method=lambda img: img.convert('L')))
+    mask = mask.unsqueeze_(0).float()  # 0 or 1
+    mask = F.interpolate(mask, (1024, 1024))
+    mask = gaussian_blur(mask, kernel_size=(77, 77))
+    mask[mask < 0.1] = 0
+    mask[mask >= 0.1] = 1
+    mask = mask.to(dtype).to(device)
+    return mask
+
+prompt = "" # Set prompt to null
+seed=123 
+generator = torch.Generator(device=device).manual_seed(seed)
+source_image_path = "./path-to-image.png"
+mask_path = "./path-to-mask.png"
+source_image = preprocess_image(source_image_path, device)
+mask = preprocess_mask(mask_path, device)
+
+image = pipeline(
+    prompt=prompt, 
+    image=source_image,
+    mask_image=mask,
+    height=1024,
+    width=1024,
+    AAS=True, # enable AAS
+    strength=0.8, # inpainting strength
+    rm_guidance_scale=9, # removal guidance scale
+    ss_steps=9, # similarity suppression steps
+    ss_scale=0.3, # similarity suppression scale
+    AAS_start_step=0, # AAS start step
+    AAS_start_layer=34, # AAS start layer
+    AAS_end_layer=70, # AAS end layer
+    num_inference_steps=50, # number of inference steps; AAS_end_step = int(strength * num_inference_steps)
+    generator=generator,
+    guidance_scale=1,
+).images[0]
+image.save('./removed_img.png')
+print("Object removal completed")
+```
+
 # Perturbed-Attention Guidance
 
 [Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
new file mode 100644
index 000000000000..cebd80c8bfa1
--- /dev/null
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -0,0 +1,2249 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from PIL import Image
+from tqdm import tqdm
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    is_invisible_watermark_available,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+
+if is_invisible_watermark_available():
+    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import StableDiffusionXLInpaintPipeline
+        >>> from diffusers.utils import load_image
+
+        >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+        ...     "stabilityai/stable-diffusion-xl-base-1.0",
+        ...     torch_dtype=torch.float16,
+        ...     variant="fp16",
+        ...     use_safetensors=True,
+        ... )
+        >>> pipe.to("cuda")
+
+        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+        >>> init_image = load_image(img_url).convert("RGB")
+        >>> mask_image = load_image(mask_url).convert("RGB")
+
+        >>> prompt = "A majestic tiger sitting on a bench"
+        >>> image = pipe(
+        ...     prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
+        ... ).images[0]
+        ```
+"""
+
+class AttentionBase:
+    def __init__(self):
+        self.cur_step = 0
+        self.num_att_layers = -1
+        self.cur_att_layer = 0
+
+    def after_step(self):
+        pass
+
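+    # each attention call advances cur_att_layer; once every registered layer has run
+    # (cur_att_layer == num_att_layers), the denoising-step counter cur_step advances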
+    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+        self.cur_att_layer += 1
+        if self.cur_att_layer == self.num_att_layers:
+            self.cur_att_layer = 0
+            self.cur_step += 1
+            # after step
+            self.after_step()
+        return out
+
+    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+        out = torch.einsum('b i j, b j d -> b i d', attn, v)
+        out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
+        return out
+
+    def reset(self):
+        self.cur_step = 0
+        self.cur_att_layer = 0
+
+class AAS_XL(AttentionBase):
+    MODEL_TYPE = {
+        "SD": 16,
+        "SDXL": 70
+    }
+    def __init__(self, start_step=4, end_step=50, start_layer=10, end_layer=16, layer_idx=None, step_idx=None, total_steps=50, mask=None, model_type="SD", ss_steps=9, ss_scale=1.0):
+        """
+        Args:
+            start_step: the step to start AAS
+            end_step: the step to end AAS
+            start_layer: the layer to start AAS
+            end_layer: the layer to end AAS
+            layer_idx: list of the layers to apply AAS
+            step_idx: list of the steps to apply AAS
+            total_steps: the total number of steps
+            mask: source mask with shape (1, 1, h, w)
+            model_type: the model type, SD or SDXL
+            ss_steps: the number of denoising steps during which similarity suppression is applied
+            ss_scale: the similarity suppression scale applied to foreground attention logits
+        """
+        super().__init__()
+        self.total_steps = total_steps
+        self.total_layers = self.MODEL_TYPE.get(model_type, 16)
+        self.start_step = start_step
+        self.end_step = end_step
+        self.start_layer = start_layer
+        self.end_layer = end_layer
+        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, end_layer))
+        self.step_idx = step_idx if step_idx is not None else list(range(start_step, end_step))
+        self.mask = mask  # mask with shape (1, 1 ,h, w)
+        self.ss_steps = ss_steps
+        self.ss_scale = ss_scale
+        print("AAS at denoising steps: ", self.step_idx)
+        print("AAS at U-Net layers: ", self.layer_idx)
+        print("start AAS")
+        self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
+        self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
+        self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
+        self.mask_128 = F.max_pool2d(mask, (1024 // 128, 1024 // 128)).round().squeeze().squeeze()
+
+    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, is_mask_attn, mask, **kwargs):
+        B = q.shape[0] // num_heads
+        if is_mask_attn:
+            mask_flatten = mask.flatten(0)
+            if self.cur_step <= self.ss_steps:
+                # background: foreground keys (mask == 1) are masked out, so queries attend only to the background
+                sim_bg = sim + mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+
+                # object: similarity suppression scales down the foreground logits before masking
+                sim_fg = self.ss_scale * sim
+                sim_fg += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+                sim = torch.cat([sim_fg, sim_bg], dim=0)
+            else:
+                sim += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+
+        attn = sim.softmax(-1)
+        if len(attn) == 2 * len(v):
+            v = torch.cat([v] * 2)
+        out = torch.einsum("h i j, h j d -> h i d", attn, v)
+        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
+        return out
+
+    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+        """
+        Attention forward function
+        """
+        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
+            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+        B = q.shape[0] // num_heads // 2
+        H = W = int(np.sqrt(q.shape[1]))
+        if H == 16:
+            mask = self.mask_16.to(sim.device)
+        elif H == 32:
+            mask = self.mask_32.to(sim.device)
+        elif H == 64:
+            mask = self.mask_64.to(sim.device)
+        else:
+            mask = self.mask_128.to(sim.device)
+
+        q_wo, q_w = q.chunk(2)
+        k_wo, k_w = k.chunk(2)
+        v_wo, v_w = v.chunk(2)
+        sim_wo, sim_w = sim.chunk(2)
+        attn_wo, attn_w = attn.chunk(2)
+
+        out_source = self.attn_batch(q_wo, k_wo, v_wo, sim_wo, attn_wo, is_cross, place_in_unet, num_heads, is_mask_attn=False, mask=None, **kwargs)
+        out_target = self.attn_batch(q_w, k_w, v_w, sim_w, attn_w, is_cross, place_in_unet, num_heads, is_mask_attn=True, mask=mask, **kwargs)
+
+        if self.mask is not None and out_target.shape[0] == 2:
+            # blend the suppressed foreground branch with the background branch using the mask
+            out_target_fg, out_target_bg = out_target.chunk(2, 0)
+            mask = mask.reshape(-1, 1)  # (hw, 1)
+            out_target = out_target_fg * mask + out_target_bg * (1 - mask)
+
+        out = torch.cat([out_source, out_target], dim=0)
+        return out
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+def mask_pil_to_torch(mask, height, width):
+    # preprocess mask
+    if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+        mask = [mask]
+
+    if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+        mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+        mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+        mask = mask.astype(np.float32) / 255.0
+    elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+        mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+    mask = torch.from_numpy(mask)
+    return mask
+
+
+def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+    """
+    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+    ``image`` and ``1`` for the ``mask``.
+
+    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+    Args:
+        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+    Raises:
+        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+            (or the other way around).
+
+    Returns:
+        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+            dimensions: ``batch x channels x height x width``.
+    """
+
+    # TODO(Yiyi) - need to clean this up later
+    deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
+    deprecate(
+        "prepare_mask_and_masked_image",
+        "0.30.0",
+        deprecation_message,
+    )
+    if image is None:
+        raise ValueError("`image` input cannot be undefined.")
+
+    if mask is None:
+        raise ValueError("`mask_image` input cannot be undefined.")
+
+    if isinstance(image, torch.Tensor):
+        if not isinstance(mask, torch.Tensor):
+            mask = mask_pil_to_torch(mask, height, width)
+
+        if image.ndim == 3:
+            image = image.unsqueeze(0)
+
+        # Batch and add channel dim for single mask
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0).unsqueeze(0)
+
+        # Batch single mask or add channel dim
+        if mask.ndim == 3:
+            # Single batched mask, no channel dim or single mask not batched but channel dim
+            if mask.shape[0] == 1:
+                mask = mask.unsqueeze(0)
+
+            # Batched masks no channel dim
+            else:
+                mask = mask.unsqueeze(1)
+
+        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+        # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+        # Check image is in [-1, 1]
+        # if image.min() < -1 or image.max() > 1:
+        #    raise ValueError("Image should be in [-1, 1] range")
+
+        # Check mask is in [0, 1]
+        if mask.min() < 0 or mask.max() > 1:
+            raise ValueError("Mask should be in [0, 1] range")
+
+        # Binarize mask
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+
+        # Image as float32
+        image = image.to(dtype=torch.float32)
+    elif isinstance(mask, torch.Tensor):
+        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+    else:
+        # preprocess image
+        if isinstance(image, (PIL.Image.Image, np.ndarray)):
+            image = [image]
+        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+            # resize all images w.r.t. the passed height and width
+            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+            image = [np.array(i.convert("RGB"))[None, :] for i in image]
+            image = np.concatenate(image, axis=0)
+        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+            image = np.concatenate([i[None, :] for i in image], axis=0)
+
+        image = image.transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+        mask = mask_pil_to_torch(mask, height, width)
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+
+    if image.shape[1] == 4:
+        # images are in latent space and thus can't be masked, so set masked_image to None;
+        # we assume that the checkpoint is not an inpainting checkpoint
+        # TODO(Yiyi) - need to clean this up later
+        masked_image = None
+    else:
+        masked_image = image * (mask < 0.5)
+
+    # n.b. ensure backwards compatibility as old function does not return image
+    if return_image:
+        return mask, masked_image, image
+
+    return mask, masked_image
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+                timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+                must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+class StableDiffusionXL_AE_Pipeline(
+    DiffusionPipeline,
+    StableDiffusionMixin,
+    TextualInversionLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    FromSingleFileMixin,
+    IPAdapterMixin,
+):
+    r"""
+    Pipeline for object removal using Stable Diffusion XL.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion XL uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+            Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+            specifically the
+            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+            variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`CLIPTokenizer`):
+            Second Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
+            Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config
+            of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+            `stabilityai/stable-diffusion-xl-base-1-0`.
+        add_watermarker (`bool`, *optional*):
+            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+            watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+            watermarker will be used.
+    """
+
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "image_encoder",
+        "feature_extractor",
+    ]
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+        "negative_pooled_prompt_embeds",
+        "add_neg_time_ids",
+        "mask",
+        "masked_image_latents",
+    ]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        text_encoder_2: CLIPTextModelWithProjection,
+        tokenizer: CLIPTokenizer,
+        tokenizer_2: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: KarrasDiffusionSchedulers,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
+        requires_aesthetics_score: bool = False,
+        force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
+    ):
+        super().__init__()
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
+            tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
+            unet=unet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+            scheduler=scheduler,
+        )
+        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.mask_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+        )
+
+        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+        if add_watermarker:
+            self.watermark = StableDiffusionXLWatermarker()
+        else:
+            self.watermark = None
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            repeat_dims = [1]
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                else:
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt: str,
+        prompt_2: Optional[str] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: Optional[str] = None,
+        negative_prompt_2: Optional[str] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+        """
+        device = device or self._execution_device
+
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None:
+                if not USE_PEFT_BACKEND:
+                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+                else:
+                    scale_lora_layers(self.text_encoder, lora_scale)
+
+            if self.text_encoder_2 is not None:
+                if not USE_PEFT_BACKEND:
+                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+                else:
+                    scale_lora_layers(self.text_encoder_2, lora_scale)
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # Define tokenizers and text encoders
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+
+        if prompt_embeds is None:
+            prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+            # textual inversion: process multi-vector tokens if necessary
+            prompt_embeds_list = []
+            prompts = [prompt, prompt_2]
+            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+                if isinstance(self, TextualInversionLoaderMixin):
+                    prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+                text_inputs = tokenizer(
+                    prompt,
+                    padding="max_length",
+                    max_length=tokenizer.model_max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                text_input_ids = text_inputs.input_ids
+                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                    text_input_ids, untruncated_ids
+                ):
+                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                    logger.warning(
+                        "The following part of your input was truncated because CLIP can only handle sequences up to"
+                        f" {tokenizer.model_max_length} tokens: {removed_text}"
+                    )
+
+                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+                # We are only ALWAYS interested in the pooled output of the final text encoder
+                pooled_prompt_embeds = prompt_embeds[0]
+                if clip_skip is None:
+                    prompt_embeds = prompt_embeds.hidden_states[-2]
+                else:
+                    # "2" because SDXL always indexes from the penultimate layer.
+                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+                prompt_embeds_list.append(prompt_embeds)
+
+            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+        # get unconditional embeddings for classifier free guidance
+        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+        elif do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
+            uncond_tokens: List[str]
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = [negative_prompt, negative_prompt_2]
+
+            negative_prompt_embeds_list = []
+            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+                if isinstance(self, TextualInversionLoaderMixin):
+                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+                max_length = prompt_embeds.shape[1]
+                uncond_input = tokenizer(
+                    negative_prompt,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                negative_prompt_embeds = text_encoder(
+                    uncond_input.input_ids.to(device),
+                    output_hidden_states=True,
+                )
+                # We are only ALWAYS interested in the pooled output of the final text encoder
+                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+                negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+        if self.text_encoder_2 is not None:
+            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+        else:
+            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            if self.text_encoder_2 is not None:
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+            else:
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+            bs_embed * num_images_per_prompt, -1
+        )
+        if do_classifier_free_guidance:
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+                bs_embed * num_images_per_prompt, -1
+            )
+
+        if self.text_encoder is not None:
+            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
+        if self.text_encoder_2 is not None:
+            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def check_inputs(
+        self,
+        prompt,
+        prompt_2,
+        image,
+        mask_image,
+        height,
+        width,
+        strength,
+        callback_steps,
+        output_type,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        padding_mask_crop=None,
+    ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+        if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+        image=None,
+        timestep=None,
+        is_strength_max=True,
+        add_noise=True,
+        return_noise=False,
+        return_image_latents=False,
+    ):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if (image is None or timestep is None) and not is_strength_max:
+            raise ValueError(
+                "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+                "However, either the image or the noise timestep has not been provided."
+            )
+
+        if image.shape[1] == 4:
+            image_latents = image.to(device=device, dtype=dtype)
+            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+        elif return_image_latents or (latents is None and not is_strength_max):
+            image = image.to(device=device, dtype=dtype)
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+        if latents is None and add_noise:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            # if strength is 1, initialise the latents to pure noise, else to image + noise
+            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+            # if pure noise, scale the initial latents by the scheduler's init sigma
+            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+        elif add_noise:
+            noise = latents.to(device)
+            latents = noise * self.scheduler.init_noise_sigma
+        else:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = image_latents.to(device)
+
+        outputs = (latents,)
+
+        if return_noise:
+            outputs += (noise,)
+
+        if return_image_latents:
+            outputs += (image_latents,)
+
+        return outputs
+
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        dtype = image.dtype
+        if self.vae.config.force_upcast:
+            image = image.float()
+            self.vae.to(dtype=torch.float32)
+
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        if self.vae.config.force_upcast:
+            self.vae.to(dtype)
+
+        image_latents = image_latents.to(dtype)
+        image_latents = self.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
+    def prepare_mask_latents(
+        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+    ):
+        # resize the mask to the latents shape as we concatenate the mask to the latents;
+        # we do that before converting to dtype to avoid breaking in case we're using
+        # cpu_offload and half precision.
+        # max-pooling with a window of 8 (the SDXL VAE downsampling factor) followed by
+        # rounding keeps the mask binary and marks every latent pixel touched by the
+        # masked region, unlike the plain interpolation used in the standard inpaint pipeline
+        mask = torch.nn.functional.max_pool2d(mask, (8, 8)).round()
+        mask = mask.to(device=device, dtype=dtype)
+
+        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+        if mask.shape[0] < batch_size:
+            if batch_size % mask.shape[0] != 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    " of masks that you pass is divisible by the total requested batch size."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+        if masked_image is not None and masked_image.shape[1] == 4:
+            masked_image_latents = masked_image
+        else:
+            masked_image_latents = None
+
+        if masked_image is not None:
+            if masked_image_latents is None:
+                masked_image = masked_image.to(device=device, dtype=dtype)
+                masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+            if masked_image_latents.shape[0] < batch_size:
+                if batch_size % masked_image_latents.shape[0] != 0:
+                    raise ValueError(
+                        "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                        f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                        " Make sure the number of images that you pass is divisible by the total requested batch size."
+                    )
+                masked_image_latents = masked_image_latents.repeat(
+                    batch_size // masked_image_latents.shape[0], 1, 1, 1
+                )
+
+            masked_image_latents = (
+                torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+            )
+
+            # aligning device to prevent device errors when concating it with the latent model input
+            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+        return mask, masked_image_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+        # get the original timestep using init_timestep
+        if denoising_start is None:
+            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+            t_start = max(num_inference_steps - init_timestep, 0)
+        else:
+            t_start = 0
+
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+        # Strength is irrelevant if we directly request a timestep to start at;
+        # that is, strength is determined by the denoising_start instead.
+        if denoising_start is not None:
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_start * self.scheduler.config.num_train_timesteps)
+                )
+            )
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+            if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+                # if the scheduler is a 2nd order scheduler we might have to do +1
+                # because `num_inference_steps` might be even given that every timestep
+                # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+                # mean that we cut the timesteps in the middle of the denoising step
+                # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+                num_inference_steps = num_inference_steps + 1
+
+            # because t_n+1 >= t_n, we slice the timesteps starting from the end
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps
+
+        return timesteps, num_inference_steps - t_start
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+    def _get_add_time_ids(
+        self,
+        original_size,
+        crops_coords_top_left,
+        target_size,
+        aesthetic_score,
+        negative_aesthetic_score,
+        negative_original_size,
+        negative_crops_coords_top_left,
+        negative_target_size,
+        dtype,
+        text_encoder_projection_dim=None,
+    ):
+        if self.config.requires_aesthetics_score:
+            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+            add_neg_time_ids = list(
+                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+            )
+        else:
+            add_time_ids = list(original_size + crops_coords_top_left + target_size)
+            add_neg_time_ids = list(negative_original_size + negative_crops_coords_top_left + negative_target_size)
+
+        passed_add_embed_dim = (
+            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+        )
+        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+        if (
+            expected_add_embed_dim > passed_add_embed_dim
+            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+        ):
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+            )
+        elif (
+            expected_add_embed_dim < passed_add_embed_dim
+            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+        ):
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+            )
+        elif expected_add_embed_dim != passed_add_embed_dim:
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+            )
+
+        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+        return add_time_ids, add_neg_time_ids
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                Guidance scale values used to generate the embedding vectors
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+                data type of the generated embeddings
+
+        Returns:
+            `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+
+    @property
+    def do_self_attention_redirection_guidance(self):
+        # SARG is active when the removal guidance scale is above 1 and the AAS
+        # attention editor is enabled
+        return self._rm_guidance_scale > 1 and self._AAS
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        # CFG is disabled when SARG is used; experiments showed little difference in
+        # the result whether CFG was enabled or not
+        return (
+            self._guidance_scale > 1
+            and self.unet.config.time_cond_proj_dim is None
+            and not self.do_self_attention_redirection_guidance
+        )
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def denoising_end(self):
+        return self._denoising_end
+
+    @property
+    def denoising_start(self):
+        return self._denoising_start
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    def image2latent(self, image: torch.Tensor, generator: torch.Generator):
+        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        if isinstance(image, PIL.Image.Image):
+            # convert a PIL image to a tensor normalized to [-1, 1]
+            image = np.array(image)
+            image = torch.from_numpy(image).float() / 127.5 - 1
+            image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
+        latents = self._encode_vae_image(image, generator)
+        return latents
+    
+    def next_step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: int,
+        x: torch.FloatTensor,
+        eta=0.,
+        verbose=False
+    ):
+        """
+        Inverse sampling for DDIM Inversion
+        """
+        if verbose:
+            print("timestep: ", timestep)
+        next_step = timestep
+        timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
+        beta_prod_t = 1 - alpha_prod_t
+        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
+        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
+        return x_next, pred_x0
+    
+    @torch.no_grad()
+    def invert(
+        self,
+        image: torch.Tensor,
+        prompt,
+        num_inference_steps=50,
+        eta=0.0,
+        original_size: Tuple[int, int] = None,
+        target_size: Tuple[int, int] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+        aesthetic_score: float = 6.0,
+        negative_aesthetic_score: float = 2.5,
+        return_intermediates=False,
+        **kwds):
+        """
+        invert a real image into noise map with determinisc DDIM inversion
+        """
+        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        batch_size = image.shape[0]
+        if isinstance(prompt, list):
+            if batch_size == 1:
+                image = image.expand(len(prompt), -1, -1, -1)
+        elif isinstance(prompt, str):
+            if batch_size > 1:
+                prompt = [prompt] * batch_size
+
+        # Define tokenizers and text encoders
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+
+        prompt_2 = prompt
+        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+        # textual inversion: process multi-vector tokens if necessary
+        prompt_embeds_list = []
+        prompts = [prompt, prompt_2]
+        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+            text_inputs = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            prompt_embeds = text_encoder(text_input_ids.to(DEVICE), output_hidden_states=True)
+
+            # We are only interested in the pooled output of the final text encoder
+            pooled_prompt_embeds = prompt_embeds[0]
+            prompt_embeds = prompt_embeds.hidden_states[-2]
+            prompt_embeds_list.append(prompt_embeds)
+        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=DEVICE)
+
+        # define initial latents
+        latents = self.image2latent(image, generator=None)
+
+        start_latents = latents
+        height, width = latents.shape[-2:]
+        height = height * self.vae_scale_factor
+        width = width * self.vae_scale_factor
+
+        original_size = (height, width)
+        target_size = (height, width)
+        negative_original_size = original_size
+        negative_target_size = target_size
+
+        add_text_embeds = pooled_prompt_embeds
+        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            aesthetic_score,
+            negative_aesthetic_score,
+            negative_original_size,
+            negative_crops_coords_top_left,
+            negative_target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+
+        add_time_ids = add_time_ids.repeat(batch_size, 1).to(DEVICE)
+
+        # iterative sampling
+        self.scheduler.set_timesteps(num_inference_steps)
+        latents_list = [latents]
+        pred_x0_list = []
+        for i, t in enumerate(reversed(self.scheduler.timesteps)):
+            model_inputs = latents
+
+            # predict the noise
+            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+            noise_pred = self.unet(
+                model_inputs, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs
+            ).sample
+
+            # step along the inversion trajectory: x_t -> x_t+1
+            latents, pred_x0 = self.next_step(noise_pred, t, latents)
+            """             
+            if t >= 1 and t < 41:
+                latents, pred_x0 = self.next_step_degrade(noise_pred, t, latents, mask)
+            else:
+                latents, pred_x0 = self.next_step(noise_pred, t, latents) """
+            
+            latents_list.append(latents)
+            pred_x0_list.append(pred_x0)
+
+        if return_intermediates:
+            # return the intermediate latents recorded during inversion
+            return latents, latents_list, pred_x0_list
+        return latents, start_latents
+    
+    def opt(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: int,
+        x: torch.FloatTensor,
+    ):
+        """
+        predict the sampe the next step in the denoise process.
+        """
+        ref_noise = model_output[:1,:,:,:].expand(model_output.shape)
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+        beta_prod_t = 1 - alpha_prod_t
+        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+        x_opt = alpha_prod_t**0.5 * pred_x0 + (1 - alpha_prod_t)**0.5 * ref_noise
+        return x_opt, pred_x0
+    
+    def register_attention_editor_diffusers(self, unet, editor: AttentionBase):
+        """
+        Register an attention editor to the diffusers pipeline, adapted from [Prompt-to-Prompt]
+        """
+        def ca_forward(self, place_in_unet):
+            def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
+                """
+                The attention is similar to the original implementation of the LDM CrossAttention class,
+                except for some modifications that route the computation through the attention editor
+                """
+                if encoder_hidden_states is not None:
+                    context = encoder_hidden_states
+                if attention_mask is not None:
+                    mask = attention_mask
+
+                to_out = self.to_out
+                if isinstance(to_out, nn.ModuleList):
+                    to_out = to_out[0]
+
+                h = self.heads
+                q = self.to_q(x)
+                is_cross = context is not None
+                context = context if is_cross else x
+                k = self.to_k(context)
+                v = self.to_v(context)
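+                # fold the head dimension into the batch dimension for the attention matmuls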
+                q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+                sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+                if mask is not None:
+                    mask = rearrange(mask, 'b ... -> b (...)')
+                    max_neg_value = -torch.finfo(sim.dtype).max
+                    # broadcast the mask over heads: (b, j) -> (b*h, 1, j)
+                    mask = repeat(mask, 'b j -> (b h) () j', h=h)
+                    sim.masked_fill_(~mask, max_neg_value)
+
+                attn = sim.softmax(dim=-1)
+                # the only difference from standard attention: the output is computed
+                # by the registered editor
+                out = editor(
+                    q, k, v, sim, attn, is_cross, place_in_unet,
+                    self.heads, scale=self.scale)
+
+                return to_out(out)
+
+            return forward
+
+        def register_editor(net, count, place_in_unet):
+            for name, subnet in net.named_children():
+                if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
+                    net.forward = ca_forward(net, place_in_unet)
+                    return count + 1
+                elif hasattr(net, 'children'):
+                    count = register_editor(subnet, count, place_in_unet)
+            return count
+
+        cross_att_count = 0
+        for net_name, net in unet.named_children():
+            if "down" in net_name:
+                cross_att_count += register_editor(net, 0, "down")
+            elif "mid" in net_name:
+                cross_att_count += register_editor(net, 0, "mid")
+            elif "up" in net_name:
+                cross_att_count += register_editor(net, 0, "up")
+        editor.num_att_layers = cross_att_count
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        image: PipelineImageInput = None,
+        mask_image: PipelineImageInput = None,
+        masked_image_latents: torch.FloatTensor = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
+        strength: float = 0.9999,
+        AAS: bool = True,  # AE parameter
+        rm_guidance_scale: float = 7.0, # AE parameter
+        ss_steps: int = 9,  # AE parameter
+        ss_scale: float = 0.3, # AE parameter
+        AAS_start_step: int = 0, # AE parameter
+        AAS_start_layer: int = 34, # AE parameter
+        AAS_end_layer: int = 70, # AE parameter
+        num_inference_steps: int = 50,
+        timesteps: List[int] = None,
+        denoising_start: Optional[float] = None,
+        denoising_end: Optional[float] = None,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
+        original_size: Tuple[int, int] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        target_size: Tuple[int, int] = None,
+        negative_original_size: Optional[Tuple[int, int]] = None,
+        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+        negative_target_size: Optional[Tuple[int, int]] = None,
+        aesthetic_score: float = 6.0,
+        negative_aesthetic_score: float = 2.5,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
+            image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+                be masked out with `mask_image` and repainted according to `prompt`.
+            mask_image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+                instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+                Anything below 512 pixels won't work well for
+                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+                and checkpoints that are not specifically fine-tuned on low resolutions.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+                Anything below 512 pixels won't work well for
+                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+                and checkpoints that are not specifically fine-tuned on low resolutions.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied
+                to image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+                with the same aspect ratio as the image that contains all masked areas, and then expand that area
+                based on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area
+                before resizing to the original image size for inpainting. This is useful when the masked area is
+                small while the image is large and contains information irrelevant to inpainting, such as background.
+            strength (`float`, *optional*, defaults to 0.9999):
+                Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+                between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+                `strength`. The number of denoising steps depends on the amount of noise initially added. When
+                `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+                iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+                portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
+                integer, the value of `strength` will be ignored.
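+            AAS (`bool`, *optional*, defaults to `True`):
+                Whether to register the AAS attention editor (`AAS_XL`) on the UNet. Required for
+                self-attention redirection guidance (SARG); see the official Attentive Eraser
+                implementation for details.
+            rm_guidance_scale (`float`, *optional*, defaults to 7.0):
+                Guidance scale for SARG. Analogous to classifier-free guidance, the final noise prediction
+                is extrapolated from the branch without redirected self-attention toward the branch with it.
+                SARG is only active when `rm_guidance_scale > 1` and `AAS` is enabled.
+            ss_steps (`int`, *optional*, defaults to 9):
+                Number of steps for the similarity suppression (`ss`) applied by the attention editor.
+            ss_scale (`float`, *optional*, defaults to 0.3):
+                Scale of the similarity suppression applied by the attention editor.
+            AAS_start_step (`int`, *optional*, defaults to 0):
+                Denoising step at which the attention editor starts modifying self-attention.
+            AAS_start_layer (`int`, *optional*, defaults to 34):
+                Index of the first attention layer modified by the attention editor.
+            AAS_end_layer (`int`, *optional*, defaults to 70):
+                End index (exclusive) of the attention layers modified by the attention editor.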
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            denoising_start (`float`, *optional*):
+                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+                strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+                is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+            denoising_end (`float`, *optional*):
+                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+                completed before it is intentionally prematurely terminated. As a result, the returned sample will
+                still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
+                denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
+                final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
+                forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list with length equal to the number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+                explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a target image resolution. It should be the
+                same as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            aesthetic_score (`float`, *optional*, defaults to 6.0):
+                Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+                Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+                Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+                simulate an aesthetic score of the generated image by influencing the negative text condition.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+            )
+
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 1. Check inputs
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            image,
+            mask_image,
+            height,
+            width,
+            strength,
+            callback_steps,
+            output_type,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
+            callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+        self._denoising_end = denoising_end
+        self._denoising_start = denoising_start
+        self._interrupt = False
+
+        ########### AE parameters
+        self._num_timesteps = num_inference_steps
+        self._rm_guidance_scale = rm_guidance_scale
+        self._AAS = AAS
+        self._ss_steps = ss_steps
+        self._ss_scale = ss_scale
+        self._AAS_start_step = AAS_start_step
+        self._AAS_start_layer = AAS_start_layer
+        self._AAS_end_layer = AAS_end_layer
+        ###########
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+            clip_skip=self.clip_skip,
+        )
+
+        # 4. set timesteps
+        def denoising_value_valid(dnv):
+            return isinstance(dnv, float) and 0 < dnv < 1
+
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps,
+            strength,
+            device,
+            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+        )
+        # check that number of inference steps is not < 1 - as this doesn't make sense
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
+        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+        is_strength_max = strength == 1.0
+
+        # 5. Preprocess mask and image
+        if padding_mask_crop is not None:
+            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
+        original_image = image
+        init_image = self.image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        )
+        init_image = init_image.to(dtype=torch.float32)
+
+        mask = self.mask_processor.preprocess(
+            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )
+
+        if masked_image_latents is not None:
+            masked_image = masked_image_latents
+        elif init_image.shape[1] == 4:
+            # if the image is already in latent space, we can't mask it
+            masked_image = None
+        else:
+            masked_image = init_image * (mask < 0.5)
+
+        # 6. Prepare latent variables
+        num_channels_latents = self.vae.config.latent_channels
+        num_channels_unet = self.unet.config.in_channels
+        return_image_latents = num_channels_unet == 4
+
+        add_noise = self.denoising_start is None
+        latents_outputs = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+            image=init_image,
+            timestep=latent_timestep,
+            is_strength_max=is_strength_max,
+            add_noise=add_noise,
+            return_noise=True,
+            return_image_latents=return_image_latents,
+        )
+
+        if return_image_latents:
+            latents, noise, image_latents = latents_outputs
+        else:
+            latents, noise = latents_outputs
+
+        # 7. Prepare mask latent variables
+        mask, masked_image_latents = self.prepare_mask_latents(
+            mask,
+            masked_image,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            self.do_classifier_free_guidance,
+        )
+
+        # 8. Check that sizes of mask, masked image and latents match
+        if num_channels_unet == 9:
+            # default case for runwayml/stable-diffusion-inpainting
+            num_channels_mask = mask.shape[1]
+            num_channels_masked_image = masked_image_latents.shape[1]
+            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+                raise ValueError(
+                    f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+                    f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+                    f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                    " `pipeline.unet` or your `mask_image` or `image` input."
+                )
+        elif num_channels_unet != 4:
+            raise ValueError(
+                f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+            )
+        # 8.1 Prepare extra step kwargs.
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 9. Recover the original height and width from the latents
+        height, width = latents.shape[-2:]
+        height = height * self.vae_scale_factor
+        width = width * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        # 10. Prepare added time ids & embeddings
+        if negative_original_size is None:
+            negative_original_size = original_size
+        if negative_target_size is None:
+            negative_target_size = target_size
+
+        add_text_embeds = pooled_prompt_embeds
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            aesthetic_score,
+            negative_aesthetic_score,
+            negative_original_size,
+            negative_crops_coords_top_left,
+            negative_target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+        ########### AE: duplicate embeddings so that the UNet evaluates the SARG pair
+        # in a single forward pass (both branches share the same conditioning; the
+        # registered attention editor distinguishes them)
+        if self.do_self_attention_redirection_guidance:
+            prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(2, 1)
+            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+        ############
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device)
+
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+        # apply AAS to modify the attention modules
+        if self.do_self_attention_redirection_guidance:
+            self._AAS_end_step = int(strength * self._num_timesteps)
+            layer_idx = list(range(self._AAS_start_layer, self._AAS_end_layer))
+            editor = AAS_XL(
+                self._AAS_start_step,
+                self._AAS_end_step,
+                self._AAS_start_layer,
+                self._AAS_end_layer,
+                layer_idx=layer_idx,
+                mask=mask_image,
+                model_type="SDXL",
+                ss_steps=self._ss_steps,
+                ss_scale=self._ss_scale,
+            )
+            self.register_attention_editor_diffusers(self.unet, editor)
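+            # after registration, every attention layer in the UNet routes its
+            # computation through `editor`, which applies the AAS modifications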
+
+
+        # 11. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        if (
+            self.denoising_end is not None
+            and self.denoising_start is not None
+            and denoising_value_valid(self.denoising_end)
+            and denoising_value_valid(self.denoising_start)
+            and self.denoising_start >= self.denoising_end
+        ):
+            raise ValueError(
+                f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+                + f" {self.denoising_end} when using type float."
+            )
+        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
+        # 11.1 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
+        self._num_timesteps = len(timesteps)
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                # removal guidance: SARG likewise doubles the batch (one branch
+                # without and one with attention suppression). CFG is disabled
+                # while SARG is active, and experiments showed little difference
+                # whether CFG is also applied.
+                latent_model_input = (
+                    torch.cat([latents] * 2)
+                    if self.do_self_attention_redirection_guidance
+                    else latent_model_input
+                )
+
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # concat latents, mask, masked_image_latents in the channel dimension
+                if num_channels_unet == 9:
+                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+                # predict the noise residual
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # perform SARG
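+                # Split into the prediction without (first half) and with
+                # (second half) attention suppression, then extrapolate toward
+                # the suppressed branch, analogous to CFG:
+                #   noise_pred = eps_wo + rm_guidance_scale * (eps_w - eps_wo)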
+                if self.do_self_attention_redirection_guidance:
+                    noise_pred_wo, noise_pred_w = noise_pred.chunk(2)
+                    delta = noise_pred_w - noise_pred_wo
+                    noise_pred = noise_pred_wo + self._rm_guidance_scale * delta
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
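+                # For the 4-channel (non-inpainting) UNet, pin the unmasked
+                # background to the original image: re-noise the image latents
+                # to the next timestep and blend them with the denoised
+                # latents through the mask.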
+                if num_channels_unet == 4:
+                    init_latents_proper = image_latents
+                    if self.do_classifier_free_guidance:
+                        init_mask, _ = mask.chunk(2)
+                    else:
+                        init_mask = mask
+
+                    if i < len(timesteps) - 1:
+                        noise_timestep = timesteps[i + 1]
+                        init_latents_proper = self.scheduler.add_noise(
+                            init_latents_proper, noise, torch.tensor([noise_timestep])
+                        )
+
+                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
+                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+                    mask = callback_outputs.pop("mask", mask)
+                    masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        if not output_type == "latent":
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
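+            # decode only the last latent in the batch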
+            latents = latents[-1:]
+
+            if needs_upcasting:
+                self.upcast_vae()
+                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            image = self.vae.decode(latents, return_dict=False)[0]
+
+            # cast back to fp16 if needed
+            if needs_upcasting:
+                self.vae.to(dtype=torch.float16)
+        else:
+            return StableDiffusionXLPipelineOutput(images=latents)
+
+        # apply watermark if available
+        if self.watermark is not None:
+            image = self.watermark.apply_watermark(image)
+
+        image = self.image_processor.postprocess(image, output_type=output_type)
+
+        if padding_mask_crop is not None:
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionXLPipelineOutput(images=image)
From a2f805d6c9a6881c025e5a9657c68c4717f7aac7 Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Tue, 14 Jan 2025 23:22:17 +0800
Subject: [PATCH 2/7] add
 pipeline_stable_diffusion_xl_attentive_eraser_make_style
---
 ...ne_stable_diffusion_xl_attentive_eraser.py | 1014 ++++++++++++-----
 1 file changed, 718 insertions(+), 296 deletions(-)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
index cebd80c8bfa1..764a8c1defcb 100644
--- a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -14,52 +14,41 @@
 
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from tqdm import tqdm
+
 import numpy as np
 import PIL.Image
-from PIL import Image
 import torch
-from transformers import (
-    CLIPImageProcessor,
-    CLIPTextModel,
-    CLIPTextModelWithProjection,
-    CLIPTokenizer,
-    CLIPVisionModelWithProjection,
-)
+from PIL import Image
+from transformers import (CLIPImageProcessor, CLIPTextModel,
+                          CLIPTextModelWithProjection, CLIPTokenizer,
+                          CLIPVisionModelWithProjection)
 
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
-from diffusers.loaders import (
-    FromSingleFileMixin,
-    IPAdapterMixin,
-    StableDiffusionXLLoraLoaderMixin,
-    TextualInversionLoaderMixin,
-)
-from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
+from diffusers.loaders import (FromSingleFileMixin, IPAdapterMixin,
+                               StableDiffusionXLLoraLoaderMixin,
+                               TextualInversionLoaderMixin)
+from diffusers.models import (AutoencoderKL, ImageProjection,
+                              UNet2DConditionModel)
+from diffusers.models.attention_processor import (AttnProcessor2_0,
+                                                  LoRAAttnProcessor2_0,
+                                                  LoRAXFormersAttnProcessor,
+                                                  XFormersAttnProcessor)
 from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import (DiffusionPipeline,
+                                                StableDiffusionMixin)
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import (
+    StableDiffusionXLPipelineOutput,
+)
 from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import (
-    USE_PEFT_BACKEND,
-    deprecate,
-    is_invisible_watermark_available,
-    is_torch_xla_available,
-    logging,
-    replace_example_docstring,
-    scale_lora_layers,
-    unscale_lora_layers,
-)
+from diffusers.utils import (USE_PEFT_BACKEND, deprecate,
+                             is_invisible_watermark_available,
+                             is_torch_xla_available, logging,
+                             replace_example_docstring, scale_lora_layers,
+                             unscale_lora_layers)
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
-from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
-
 
 if is_invisible_watermark_available():
-    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+    from diffusers.pipelines.stable_diffusion_xl.watermark import (
+        StableDiffusionXLWatermarker,
+    )
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -68,10 +57,10 @@
 else:
     XLA_AVAILABLE = False
 
-from einops import rearrange, repeat
+
 import torch.nn as nn
 import torch.nn.functional as F
-import os
+from einops import rearrange, repeat
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -104,6 +93,7 @@
         ```
 """
 
+
 class AttentionBase:
     def __init__(self):
         self.cur_step = 0
@@ -113,8 +103,12 @@ def __init__(self):
     def after_step(self):
         pass
 
-    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
-        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+    def __call__(
+        self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs
+    ):
+        out = self.forward(
+            q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs
+        )
         self.cur_att_layer += 1
         if self.cur_att_layer == self.num_att_layers:
             self.cur_att_layer = 0
@@ -123,21 +117,34 @@ def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwa
             self.after_step()
         return out
 
-    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
-        out = torch.einsum('b i j, b j d -> b i d', attn, v)
-        out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
+    def forward(self, q, k, v, sim, attn, is_cross,
+                place_in_unet, num_heads, **kwargs):
+        out = torch.einsum("b i j, b j d -> b i d", attn, v)
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=num_heads)
         return out
 
     def reset(self):
         self.cur_step = 0
         self.cur_att_layer = 0
 
+
 class AAS_XL(AttentionBase):
-    MODEL_TYPE = {
-        "SD": 16,
-        "SDXL": 70
-    }
-    def __init__(self,  start_step=4, end_step= 50, start_layer=10, end_layer=16,layer_idx=None, step_idx=None, total_steps=50,  mask=None, model_type="SD",ss_steps=9,ss_scale=1.0):
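+    # attention-layer count per supported UNet variant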
+    MODEL_TYPE = {"SD": 16, "SDXL": 70}
+
+    def __init__(
+        self,
+        start_step=4,
+        end_step=50,
+        start_layer=10,
+        end_layer=16,
+        layer_idx=None,
+        step_idx=None,
+        total_steps=50,
+        mask=None,
+        model_type="SD",
+        ss_steps=9,
+        ss_scale=1.0,
+    ):
         """
         Args:
             start_step: the step to start AAS
@@ -155,49 +162,92 @@ def __init__(self,  start_step=4, end_step= 50, start_layer=10, end_layer=16,lay
         self.end_step = end_step
         self.start_layer = start_layer
         self.end_layer = end_layer
-        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, end_layer))
-        self.step_idx = step_idx if step_idx is not None else list(range(start_step, end_step))
+        self.layer_idx = (
+            layer_idx if layer_idx is not None else list(range(start_layer, end_layer))
+        )
+        self.step_idx = (
+            step_idx if step_idx is not None else list(range(start_step, end_step))
+        )
         self.mask = mask  # mask with shape (1, 1, h, w)
         self.ss_steps = ss_steps
         self.ss_scale = ss_scale
         print("AAS at denoising steps: ", self.step_idx)
         print("AAS at U-Net layers: ", self.layer_idx)
         print("start AAS")
-        self.mask_16 = F.max_pool2d(mask,(1024//16,1024//16)).round().squeeze().squeeze()
-        self.mask_32 = F.max_pool2d(mask,(1024//32,1024//32)).round().squeeze().squeeze()
-        self.mask_64 = F.max_pool2d(mask,(1024//64,1024//64)).round().squeeze().squeeze()
-        self.mask_128 = F.max_pool2d(mask,(1024//128,1024//128)).round().squeeze().squeeze()
-    
-    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,is_mask_attn, mask, **kwargs):
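+        # Pre-pool the removal mask to every token-grid size the SDXL
+        # self-attention layers can use for 1024x1024 inputs (16/32/64/128).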
+        self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
+        self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
+        self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
+        self.mask_128 = F.max_pool2d(mask, (1024 // 128, 1024 // 128)).round().squeeze().squeeze()
+
+    def attn_batch(
+        self,
+        q,
+        k,
+        v,
+        sim,
+        attn,
+        is_cross,
+        place_in_unet,
+        num_heads,
+        is_mask_attn,
+        mask,
+        **kwargs,
+    ):
         B = q.shape[0] // num_heads
         if is_mask_attn:
             mask_flatten = mask.flatten(0)
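+            # Adding `finfo.min` at positions where the mask is 1 removes the
+            # object region from the softmax, so every query attends only to
+            # background keys. For the first `ss_steps` steps a similarity-
+            # shifted foreground branch (scaled by `ss_scale`) is stacked on
+            # top of the plain background branch.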
-            if self.cur_step <= self.ss_steps:                                                                                                                                                                                                                                           
+            if self.cur_step <= self.ss_steps:
                 # background
-                sim_bg = sim + mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min) 
+                sim_bg = sim + mask_flatten.masked_fill(
+                    mask_flatten == 1, torch.finfo(sim.dtype).min
+                )
 
-                #object
-                sim_fg = self.ss_scale*sim
-                sim_fg += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+                # object
+                sim_fg = self.ss_scale * sim
+                sim_fg += mask_flatten.masked_fill(
+                    mask_flatten == 1, torch.finfo(sim.dtype).min
+                )
                 sim = torch.cat([sim_fg, sim_bg], dim=0)
             else:
-                sim += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+                sim += mask_flatten.masked_fill(
+                    mask_flatten == 1, torch.finfo(sim.dtype).min
+                )
 
         attn = sim.softmax(-1)
         if len(attn) == 2 * len(v):
             v = torch.cat([v] * 2)
         out = torch.einsum("h i j, h j d -> h i d", attn, v)
-        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
+        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
         return out
 
-    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+    def forward(self, q, k, v, sim, attn, is_cross,
+                place_in_unet, num_heads, **kwargs):
         """
         Attention forward function
         """
         if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
             return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
-        B = q.shape[0] // num_heads // 2
-        H = W = int(np.sqrt(q.shape[1]))
+        # q holds H*W tokens per head on a square grid; recover the spatial
+        # size to select the matching pre-pooled mask.
+        H = int(np.sqrt(q.shape[1]))
         if H == 16:
             mask = self.mask_16.to(sim.device)
         elif H == 32:
@@ -207,15 +257,38 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar
         else:
             mask = self.mask_128.to(sim.device)
 
-
         q_wo, q_w = q.chunk(2)
         k_wo, k_w = k.chunk(2)
         v_wo, v_w = v.chunk(2)
         sim_wo, sim_w = sim.chunk(2)
         attn_wo, attn_w = attn.chunk(2)
 
-        out_source = self.attn_batch(q_wo, k_wo, v_wo, sim_wo, attn_wo, is_cross, place_in_unet, num_heads,is_mask_attn=False,mask=None,**kwargs)
-        out_target = self.attn_batch(q_w, k_w, v_w, sim_w, attn_w, is_cross, place_in_unet, num_heads, is_mask_attn=True, mask = mask,**kwargs)
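+        # First half of the batch runs plain self-attention (the "without"
+        # branch); the second half runs mask-suppressed attention (the "with"
+        # branch).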
+        out_source = self.attn_batch(
+            q_wo,
+            k_wo,
+            v_wo,
+            sim_wo,
+            attn_wo,
+            is_cross,
+            place_in_unet,
+            num_heads,
+            is_mask_attn=False,
+            mask=None,
+            **kwargs,
+        )
+        out_target = self.attn_batch(
+            q_w,
+            k_w,
+            v_w,
+            sim_w,
+            attn_w,
+            is_cross,
+            place_in_unet,
+            num_heads,
+            is_mask_attn=True,
+            mask=mask,
+            **kwargs,
+        )
 
         if self.mask is not None:
             if out_target.shape[0] == 2:
@@ -224,11 +297,13 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar
                 out_target = out_target_fg * mask + out_target_bg * (1 - mask)
             else:
                 out_target = out_target
-        
+
         out = torch.cat([out_source, out_target], dim=0)
         return out
-    
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+
+
+# Copied from
+# diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     """
     Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -259,7 +334,9 @@ def mask_pil_to_torch(mask, height, width):
     return mask
 
 
-def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+def prepare_mask_and_masked_image(
+    image, mask, height, width, return_image: bool = False
+):
     """
     Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
     converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
@@ -314,7 +391,8 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
 
         # Batch single mask or add channel dim
         if mask.ndim == 3:
-            # Single batched mask, no channel dim or single mask not batched but channel dim
+            # Single batched mask, no channel dim or single mask not batched
+            # but channel dim
             if mask.shape[0] == 1:
                 mask = mask.unsqueeze(0)
 
@@ -322,9 +400,13 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
             else:
                 mask = mask.unsqueeze(1)
 
-        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+        assert (
+            image.ndim == 4 and mask.ndim == 4
+        ), "Image and Mask must have 4 dimensions"
         # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
-        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+        assert (
+            image.shape[0] == mask.shape[0]
+        ), "Image and Mask must have the same batch size"
 
         # Check image is in [-1, 1]
         # if image.min() < -1 or image.max() > 1:
@@ -341,14 +423,18 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
         # Image as float32
         image = image.to(dtype=torch.float32)
     elif isinstance(mask, torch.Tensor):
-        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+        raise TypeError(
+            f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
+        )
     else:
         # preprocess image
         if isinstance(image, (PIL.Image.Image, np.ndarray)):
             image = [image]
         if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
            # resize all images w.r.t. the passed height and width
-            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+            image = [
+                i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image
+            ]
             image = [np.array(i.convert("RGB"))[None, :] for i in image]
             image = np.concatenate(image, axis=0)
         elif isinstance(image, list) and isinstance(image[0], np.ndarray):
@@ -377,9 +463,12 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     return mask, masked_image
 
 
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+# Copied from
+# diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
-    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+    encoder_output: torch.Tensor,
+    generator: Optional[torch.Generator] = None,
+    sample_mode: str = "sample",
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
@@ -388,10 +477,12 @@ def retrieve_latents(
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
-        raise AttributeError("Could not access latents of provided encoder_output")
+        raise AttributeError(
+            "Could not access latents of provided encoder_output")
 
 
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+# Copied from
+# diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
     num_inference_steps: Optional[int] = None,
@@ -421,7 +512,9 @@ def retrieve_timesteps(
         second element is the number of inference steps.
     """
     if timesteps is not None:
-        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        accepts_timesteps = "timesteps" in set(
+            inspect.signature(scheduler.set_timesteps).parameters.keys()
+        )
         if not accepts_timesteps:
             raise ValueError(
                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -542,55 +635,86 @@ def __init__(
             feature_extractor=feature_extractor,
             scheduler=scheduler,
         )
-        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
-        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(
+            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt
+        )
+        self.register_to_config(
+            requires_aesthetics_score=requires_aesthetics_score)
+        self.vae_scale_factor = 2 ** (
+            len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor)
         self.mask_processor = VaeImageProcessor(
-            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+            vae_scale_factor=self.vae_scale_factor,
+            do_normalize=False,
+            do_binarize=True,
+            do_convert_grayscale=True,
         )
 
-        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+        add_watermarker = (
+            add_watermarker
+            if add_watermarker is not None
+            else is_invisible_watermark_available()
+        )
 
         if add_watermarker:
             self.watermark = StableDiffusionXLWatermarker()
         else:
             self.watermark = None
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+    # Copied from
+    # diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(
+        self, image, device, num_images_per_prompt, output_hidden_states=None
+    ):
         dtype = next(self.image_encoder.parameters()).dtype
 
         if not isinstance(image, torch.Tensor):
-            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+            image = self.feature_extractor(
+                image, return_tensors="pt").pixel_values
 
         image = image.to(device=device, dtype=dtype)
         if output_hidden_states:
-            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
-            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            image_enc_hidden_states = self.image_encoder(
+                image, output_hidden_states=True
+            ).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
             uncond_image_enc_hidden_states = self.image_encoder(
                 torch.zeros_like(image), output_hidden_states=True
             ).hidden_states[-2]
-            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
-                num_images_per_prompt, dim=0
+            uncond_image_enc_hidden_states = (
+                uncond_image_enc_hidden_states.repeat_interleave(
+                    num_images_per_prompt, dim=0
+                )
             )
             return image_enc_hidden_states, uncond_image_enc_hidden_states
         else:
             image_embeds = self.image_encoder(image).image_embeds
-            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            image_embeds = image_embeds.repeat_interleave(
+                num_images_per_prompt, dim=0)
             uncond_image_embeds = torch.zeros_like(image_embeds)
 
             return image_embeds, uncond_image_embeds
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    # Copied from
+    # diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+        self,
+        ip_adapter_image,
+        ip_adapter_image_embeds,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
     ):
         if ip_adapter_image_embeds is None:
             if not isinstance(ip_adapter_image, list):
                 ip_adapter_image = [ip_adapter_image]
 
-            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+            if len(ip_adapter_image) != len(
+                self.unet.encoder_hid_proj.image_projection_layers
+            ):
                 raise ValueError(
                     f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                 )
@@ -634,7 +758,8 @@ def prepare_ip_adapter_image_embeds(
 
         return image_embeds
 
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+    # Copied from
+    # diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: str,
@@ -697,19 +822,23 @@ def encode_prompt(
 
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+        if lora_scale is not None and isinstance(
+            self, StableDiffusionXLLoraLoaderMixin
+        ):
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
             if self.text_encoder is not None:
                 if not USE_PEFT_BACKEND:
-                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+                    adjust_lora_scale_text_encoder(
+                        self.text_encoder, lora_scale)
                 else:
                     scale_lora_layers(self.text_encoder, lora_scale)
 
             if self.text_encoder_2 is not None:
                 if not USE_PEFT_BACKEND:
-                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+                    adjust_lora_scale_text_encoder(
+                        self.text_encoder_2, lora_scale)
                 else:
                     scale_lora_layers(self.text_encoder_2, lora_scale)
 
@@ -721,9 +850,15 @@ def encode_prompt(
             batch_size = prompt_embeds.shape[0]
 
         # Define tokenizers and text encoders
-        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        tokenizers = (
+            [self.tokenizer, self.tokenizer_2]
+            if self.tokenizer is not None
+            else [self.tokenizer_2]
+        )
         text_encoders = (
-            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+            [self.text_encoder, self.text_encoder_2]
+            if self.text_encoder is not None
+            else [self.text_encoder_2]
         )
 
         if prompt_embeds is None:
@@ -733,7 +868,9 @@ def encode_prompt(
             # textual inversion: process multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
-            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+            for prompt, tokenizer, text_encoder in zip(
+                prompts, tokenizers, text_encoders
+            ):
                 if isinstance(self, TextualInversionLoaderMixin):
                     prompt = self.maybe_convert_prompt(prompt, tokenizer)
 
@@ -746,26 +883,34 @@ def encode_prompt(
                 )
 
                 text_input_ids = text_inputs.input_ids
-                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-                    text_input_ids, untruncated_ids
-                ):
-                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                untruncated_ids = tokenizer(
+                    prompt, padding="longest", return_tensors="pt"
+                ).input_ids
+
+                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                    text_input_ids, untruncated_ids
+                ):
+                    removed_text = tokenizer.batch_decode(
+                        untruncated_ids[:, tokenizer.model_max_length - 1: -1]
+                    )
                     logger.warning(
                         "The following part of your input was truncated because CLIP can only handle sequences up to"
                         f" {tokenizer.model_max_length} tokens: {removed_text}"
                     )
 
-                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+                prompt_embeds = text_encoder(
+                    text_input_ids.to(device), output_hidden_states=True
+                )
 
-                # We are only ALWAYS interested in the pooled output of the final text encoder
+                # We are only ALWAYS interested in the pooled output of the
+                # final text encoder
                 pooled_prompt_embeds = prompt_embeds[0]
                 if clip_skip is None:
                     prompt_embeds = prompt_embeds.hidden_states[-2]
                 else:
                     # "2" because SDXL always indexes from the penultimate layer.
-                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
 
                 prompt_embeds_list.append(prompt_embeds)
 
@@ -819,70 +964,99 @@ def encode_prompt(
                     uncond_input.input_ids.to(device),
                     output_hidden_states=True,
                 )
-                # We are only ALWAYS interested in the pooled output of the final text encoder
+                # We are only ALWAYS interested in the pooled output of the
+                # final text encoder
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
-            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+            negative_prompt_embeds = torch.concat(
+                negative_prompt_embeds_list, dim=-1)
 
         if self.text_encoder_2 is not None:
-            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+            prompt_embeds = prompt_embeds.to(
+                dtype=self.text_encoder_2.dtype, device=device
+            )
         else:
-            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+            prompt_embeds = prompt_embeds.to(
+                dtype=self.unet.dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        # duplicate text embeddings for each generation per prompt, using mps
+        # friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        prompt_embeds = prompt_embeds.view(
+            bs_embed * num_images_per_prompt, seq_len, -1
+        )
 
         if do_classifier_free_guidance:
-            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            # duplicate unconditional embeddings for each generation per
+            # prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
             if self.text_encoder_2 is not None:
-                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+                negative_prompt_embeds = negative_prompt_embeds.to(
+                    dtype=self.text_encoder_2.dtype, device=device
+                )
             else:
-                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+                negative_prompt_embeds = negative_prompt_embeds.to(
+                    dtype=self.unet.dtype, device=device
+                )
 
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(
+                1, num_images_per_prompt, 1
+            )
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, -1
+            )
 
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-            bs_embed * num_images_per_prompt, -1
-        )
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(
+            1, num_images_per_prompt
+        ).view(bs_embed * num_images_per_prompt, -1)
         if do_classifier_free_guidance:
-            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-                bs_embed * num_images_per_prompt, -1
-            )
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
+                1, num_images_per_prompt
+            ).view(bs_embed * num_images_per_prompt, -1)
 
         if self.text_encoder is not None:
-            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            if isinstance(
+                    self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)
 
         if self.text_encoder_2 is not None:
-            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            if isinstance(
+                    self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
-        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+        return (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        )
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    # Copied from
+    # diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]
 
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_eta = "eta" in set(
+            inspect.signature(self.scheduler.step).parameters.keys()
+        )
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
         # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_generator = "generator" in set(
+            inspect.signature(self.scheduler.step).parameters.keys()
+        )
         if accepts_generator:
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
@@ -908,19 +1082,26 @@ def check_inputs(
         padding_mask_crop=None,
     ):
         if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+            raise ValueError(
+                f"The value of strength should be in [0.0, 1.0] but is {strength}"
+            )
 
         if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+            raise ValueError(
+                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
+            )
 
-        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+        if callback_steps is not None and (
+            not isinstance(callback_steps, int) or callback_steps <= 0
+        ):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
-            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+            k in self._callback_tensor_inputs
+            for k in callback_on_step_end_tensor_inputs
         ):
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -940,10 +1121,18 @@ def check_inputs(
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
-            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        elif prompt is not None and (
+            not isinstance(prompt, str) and not isinstance(prompt, list)
+        ):
+            raise ValueError(
+                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+            )
+        elif prompt_2 is not None and (
+            not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
+        ):
+            raise ValueError(
+                f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
+            )
 
         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
@@ -966,7 +1155,8 @@ def check_inputs(
         if padding_mask_crop is not None:
             if not isinstance(image, PIL.Image.Image):
                 raise ValueError(
-                    f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                    f"The image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(image)}."
                 )
             if not isinstance(mask_image, PIL.Image.Image):
                 raise ValueError(
@@ -974,7 +1164,10 @@ def check_inputs(
                     f" {type(mask_image)}."
                 )
             if output_type != "pil":
-                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+                raise ValueError(
+                    f"The output type should be PIL when inpainting mask crop, but is"
+                    f" {output_type}."
+                )
 
         if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
             raise ValueError(
@@ -1008,7 +1201,12 @@ def prepare_latents(
         return_noise=False,
         return_image_latents=False,
     ):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        shape = (
+            batch_size,
+            num_channels_latents,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor,
+        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -1023,23 +1221,46 @@ def prepare_latents(
 
         if image.shape[1] == 4:
             image_latents = image.to(device=device, dtype=dtype)
-            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+            image_latents = image_latents.repeat(
+                batch_size // image_latents.shape[0], 1, 1, 1
+            )
         elif return_image_latents or (latents is None and not is_strength_max):
             image = image.to(device=device, dtype=dtype)
-            image_latents = self._encode_vae_image(image=image, generator=generator)
-            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+            image_latents = self._encode_vae_image(
+                image=image, generator=generator)
+            image_latents = image_latents.repeat(
+                batch_size // image_latents.shape[0], 1, 1, 1
+            )
 
         if latents is None and add_noise:
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            # if strength is 1. then initialise the latents to noise, else initial to image + noise
-            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
-            # if pure noise then scale the initial latents by the  Scheduler's init sigma
-            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+            noise = randn_tensor(
+                shape,
+                generator=generator,
+                device=device,
+                dtype=dtype)
+            # if strength is 1. then initialise the latents to noise, else
+            # initialise them to image + noise
+            latents = (
+                noise
+                if is_strength_max
+                else self.scheduler.add_noise(image_latents, noise, timestep)
+            )
+            # if pure noise then scale the initial latents by the scheduler's
+            # init sigma
+            latents = (
+                latents * self.scheduler.init_noise_sigma
+                if is_strength_max
+                else latents
+            )
         elif add_noise:
             noise = latents.to(device)
             latents = noise * self.scheduler.init_noise_sigma
         else:
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            noise = randn_tensor(
+                shape,
+                generator=generator,
+                device=device,
+                dtype=dtype)
             latents = image_latents.to(device)
 
         outputs = (latents,)
@@ -1052,7 +1273,8 @@ def prepare_latents(
 
         return outputs
 
-    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+    def _encode_vae_image(self, image: torch.Tensor,
+                          generator: torch.Generator):
         dtype = image.dtype
         if self.vae.config.force_upcast:
             image = image.float()
@@ -1060,12 +1282,16 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
 
         if isinstance(generator, list):
             image_latents = [
-                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                retrieve_latents(
+                    self.vae.encode(image[i: i + 1]), generator=generator[i]
+                )
                 for i in range(image.shape[0])
             ]
             image_latents = torch.cat(image_latents, dim=0)
         else:
-            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+            image_latents = retrieve_latents(
+                self.vae.encode(image), generator=generator
+            )
 
         if self.vae.config.force_upcast:
             self.vae.to(dtype)
@@ -1076,18 +1302,28 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         return image_latents
 
     def prepare_mask_latents(
-        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+        self,
+        mask,
+        masked_image,
+        batch_size,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        do_classifier_free_guidance,
     ):
         # resize the mask to latents shape as we concatenate the mask to the latents
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
         # and half precision
-        #mask = torch.nn.functional.interpolate(
+        # mask = torch.nn.functional.interpolate(
         #    mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
-        #)
-        mask = torch.nn.functional.max_pool2d(mask, (8,8)).round()
+        # )
+        mask = torch.nn.functional.max_pool2d(mask, (8, 8)).round()
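+        # max-pooling keeps a latent cell masked if any pixel inside it is
+        # masked (interpolation would thin out small strokes); round() then
+        # re-binarizes the mask.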
         mask = mask.to(device=device, dtype=dtype)
 
-        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+        # duplicate mask and masked_image_latents for each generation per
+        # prompt, using mps friendly method
         if mask.shape[0] < batch_size:
             if not batch_size % mask.shape[0] == 0:
                 raise ValueError(
@@ -1107,7 +1343,9 @@ def prepare_mask_latents(
         if masked_image is not None:
             if masked_image_latents is None:
                 masked_image = masked_image.to(device=device, dtype=dtype)
-                masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+                masked_image_latents = self._encode_vae_image(
+                    masked_image, generator=generator
+                )
 
             if masked_image_latents.shape[0] < batch_size:
                 if not batch_size % masked_image_latents.shape[0] == 0:
@@ -1121,16 +1359,23 @@ def prepare_mask_latents(
                 )
 
             masked_image_latents = (
-                torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+                torch.cat([masked_image_latents] * 2)
+                if do_classifier_free_guidance
+                else masked_image_latents
             )
 
-            # aligning device to prevent device errors when concating it with the latent model input
-            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+            # aligning device to prevent device errors when concating it with
+            # the latent model input
+            masked_image_latents = masked_image_latents.to(
+                device=device, dtype=dtype)
 
         return mask, masked_image_latents
 
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
-    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+    # Copied from
+    # diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+    def get_timesteps(
+        self, num_inference_steps, strength, device, denoising_start=None
+    ):
         # get the original timestep using init_timestep
         if denoising_start is None:
             init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
@@ -1138,7 +1383,7 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
         else:
             t_start = 0
 
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
 
         # Strength is irrelevant if we directly request a timestep to start at;
         # that is, strength is determined by the denoising_start instead.
@@ -1157,16 +1402,19 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
                 # (except the highest one) is duplicated. If `num_inference_steps` is even it would
                 # mean that we cut the timesteps in the middle of the denoising step
                # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
-                # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
+                # we ensure that the denoising process always ends after the
+                # 2nd derivative step of the scheduler
                 num_inference_steps = num_inference_steps + 1
 
-            # because t_n+1 >= t_n, we slice the timesteps starting from the end
+            # because t_n+1 >= t_n, we slice the timesteps starting from the
+            # end
             timesteps = timesteps[-num_inference_steps:]
             return timesteps, num_inference_steps
 
         return timesteps, num_inference_steps - t_start
 
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+    # Copied from
+    # diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
     def _get_add_time_ids(
         self,
         original_size,
@@ -1218,7 +1466,8 @@ def _get_add_time_ids(
 
         return add_time_ids, add_neg_time_ids
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+    # Copied from
+    # diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
     def upcast_vae(self):
         dtype = self.vae.dtype
         self.vae.to(dtype=torch.float32)
@@ -1238,8 +1487,10 @@ def upcast_vae(self):
             self.vae.decoder.conv_in.to(dtype)
             self.vae.decoder.mid_block.to(dtype)
 
-    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
-    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+    # Copied from
+    # diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(
+            self, w, embedding_dim=512, dtype=torch.float32):
         """
         See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
 
@@ -1279,18 +1530,21 @@ def guidance_rescale(self):
     def clip_skip(self):
         return self._clip_skip
 
-
     @property
-    def do_self_attention_redirection_guidance(self): #SARG
+    def do_self_attention_redirection_guidance(self):  # SARG
         return self._rm_guidance_scale > 1 and self._AAS
-    
+
     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None and self.do_self_attention_redirection_guidance==False #CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not
-    
+        # CFG is disabled when SARG is used; experiments showed little
+        # difference in the result whether or not CFG is applied.
+        return (
+            self._guidance_scale > 1
+            and self.unet.config.time_cond_proj_dim is None
+            and not self.do_self_attention_redirection_guidance
+        )
+
     @property
     def cross_attention_kwargs(self):
         return self._cross_attention_kwargs
@@ -1313,25 +1567,28 @@ def interrupt(self):
 
     @torch.no_grad()
     def image2latent(self, image: torch.Tensor, generator: torch.Generator):
-        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        DEVICE = (
+            torch.device("cuda")
+            if torch.cuda.is_available()
+            else torch.device("cpu")
+        )
         if isinstance(image, Image.Image):
             image = np.array(image)
             image = torch.from_numpy(image).float() / 127.5 - 1
             image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
         # input image density range [-1, 1]
-        #latents = self.vae.encode(image)['latent_dist'].mean
+        # latents = self.vae.encode(image)['latent_dist'].mean
         latents = self._encode_vae_image(image, generator)
-        #latents = retrieve_latents(self.vae.encode(image))
-        #latents = latents * self.vae.config.scaling_factor
+        # latents = retrieve_latents(self.vae.encode(image))
+        # latents = latents * self.vae.config.scaling_factor
         return latents
-    
+
     def next_step(
         self,
         model_output: torch.FloatTensor,
         timestep: int,
         x: torch.FloatTensor,
-        eta=0.,
-        verbose=False
+        eta=0.0,
+        verbose=False,
     ):
         """
         Inverse sampling for DDIM Inversion
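         Maps x_t one step forward in time (t -> t+1): pred_x0 is recovered
         from the model output and re-noised with the next timestep's
         cumulative alpha, i.e. the DDIM update applied in reverse.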
@@ -1339,15 +1596,24 @@ def next_step(
         if verbose:
             print("timestep: ", timestep)
         next_step = timestep
-        timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
-        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+        timestep = min(
+            timestep
+            - self.scheduler.config.num_train_timesteps
+            // self.scheduler.num_inference_steps,
+            999,
+        )
+        alpha_prod_t = (
+            self.scheduler.alphas_cumprod[timestep]
+            if timestep >= 0
+            else self.scheduler.final_alpha_cumprod
+        )
         alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
         beta_prod_t = 1 - alpha_prod_t
         pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
-        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
+        pred_dir = (1 - alpha_prod_t_next) ** 0.5 * model_output
         x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
         return x_next, pred_x0
-    
+
     @torch.no_grad()
     def invert(
         self,
@@ -1362,11 +1628,15 @@ def invert(
         aesthetic_score: float = 6.0,
         negative_aesthetic_score: float = 2.5,
         return_intermediates=False,
-        **kwds):
+        **kwds,
+    ):
         """
         Invert a real image into a noise map with deterministic DDIM inversion.
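         Returns the inverted latents plus the starting latents, or all
         intermediate latents and pred_x0s when return_intermediates is True.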
         """
-        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        DEVICE = (
+            torch.device("cuda")
+            if torch.cuda.is_available()
+            else torch.device("cpu")
+        )
         batch_size = image.shape[0]
         if isinstance(prompt, list):
             if batch_size == 1:
@@ -1376,9 +1646,15 @@ def invert(
                 prompt = [prompt] * batch_size
 
         # Define tokenizers and text encoders
-        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        tokenizers = (
+            [self.tokenizer, self.tokenizer_2]
+            if self.tokenizer is not None
+            else [self.tokenizer_2]
+        )
         text_encoders = (
-            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+            [self.text_encoder, self.text_encoder_2]
+            if self.text_encoder is not None
+            else [self.text_encoder_2]
         )
 
         prompt_2 = prompt
@@ -1387,7 +1663,8 @@ def invert(
         # textual inversion: process multi-vector tokens if necessary
         prompt_embeds_list = []
         prompts = [prompt, prompt_2]
-        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+        for prompt, tokenizer, text_encoder in zip(
+                prompts, tokenizers, text_encoders):
             if isinstance(self, TextualInversionLoaderMixin):
                 prompt = self.maybe_convert_prompt(prompt, tokenizer)
 
@@ -1400,20 +1677,27 @@ def invert(
             )
 
             text_input_ids = text_inputs.input_ids
-            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-                text_input_ids, untruncated_ids
-            ):
-                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+            untruncated_ids = tokenizer(
+                prompt, padding="longest", return_tensors="pt"
+            ).input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[
+                -1
+            ] and not torch.equal(text_input_ids, untruncated_ids):
+                removed_text = tokenizer.batch_decode(
+                    untruncated_ids[:, tokenizer.model_max_length - 1: -1]
+                )
                 logger.warning(
                     "The following part of your input was truncated because CLIP can only handle sequences up to"
                     f" {tokenizer.model_max_length} tokens: {removed_text}"
                 )
 
-            prompt_embeds = text_encoder(text_input_ids.to(DEVICE), output_hidden_states=True)
+            prompt_embeds = text_encoder(
+                text_input_ids.to(DEVICE), output_hidden_states=True
+            )
 
-            # We are only ALWAYS interested in the pooled output of the final text encoder
+            # We are only ALWAYS interested in the pooled output of the final
+            # text encoder
             pooled_prompt_embeds = prompt_embeds[0]
             prompt_embeds = prompt_embeds.hidden_states[-2]
             prompt_embeds_list.append(prompt_embeds)
@@ -1421,7 +1705,7 @@ def invert(
         prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=DEVICE)
 
         # define initial latents
-        latents = self.image2latent(image,generator=None)
+        latents = self.image2latent(image, generator=None)
 
         start_latents = latents
         height, width = latents.shape[-2:]
@@ -1454,32 +1738,41 @@ def invert(
         self.scheduler.set_timesteps(num_inference_steps)
         latents_list = [latents]
         pred_x0_list = []
-        #for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
+        # for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps),
+        # desc="DDIM Inversion")):
         for i, t in enumerate(reversed(self.scheduler.timesteps)):
             model_inputs = latents
 
             # predict the noise
-            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
-            noise_pred = self.unet(model_inputs, t, encoder_hidden_states=prompt_embeds,added_cond_kwargs=added_cond_kwargs).sample
+            added_cond_kwargs = {
+                "text_embeds": add_text_embeds,
+                "time_ids": add_time_ids,
+            }
+            noise_pred = self.unet(
+                model_inputs,
+                t,
+                encoder_hidden_states=prompt_embeds,
+                added_cond_kwargs=added_cond_kwargs,
+            ).sample
 
             # compute the previous noise sample x_t-1 -> x_t
             latents, pred_x0 = self.next_step(noise_pred, t, latents)
-            """             
+            """
             if t >= 1 and t < 41:
                 latents, pred_x0 = self.next_step_degrade(noise_pred, t, latents, mask)
             else:
                 latents, pred_x0 = self.next_step(noise_pred, t, latents) """
-            
+
             latents_list.append(latents)
             pred_x0_list.append(pred_x0)
 
         if return_intermediates:
             # return the intermediate latents during inversion
-            #pred_x0_list = [self.latent2image(img, return_type="np") for img in pred_x0_list]
-            #latents_list = [self.latent2image(img, return_type="np") for img in latents_list]
+            # pred_x0_list = [self.latent2image(img, return_type="np") for img in pred_x0_list]
+            # latents_list = [self.latent2image(img, return_type="np") for img in latents_list]
             return latents, latents_list, pred_x0_list
         return latents, start_latents
-    
+
     def opt(
         self,
         model_output: torch.FloatTensor,
@@ -1489,19 +1782,27 @@ def opt(
         """
         Predict the sample at the next step of the denoising process.
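         pred_x0 is kept from the current prediction, while the noise direction
         is replaced by the reference noise (the first batch element's
         prediction broadcast over the batch).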
         """
-        ref_noise = model_output[:1,:,:,:].expand(model_output.shape)
+        ref_noise = model_output[:1, :, :, :].expand(model_output.shape)
         alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
         beta_prod_t = 1 - alpha_prod_t
         pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
-        x_opt = alpha_prod_t**0.5 * pred_x0 + (1 - alpha_prod_t)**0.5 * ref_noise
+        x_opt = (
+            alpha_prod_t**0.5 * pred_x0
+            + (1 - alpha_prod_t) ** 0.5 * ref_noise
+        )
         return x_opt, pred_x0
-    
+
     def regiter_attention_editor_diffusers(self, unet, editor: AttentionBase):
         """
         Register an attention editor to the Diffusers pipeline; adapted from [Prompt-to-Prompt]
         """
+
         def ca_forward(self, place_in_unet):
-            def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
+            def forward(
+                x,
+                encoder_hidden_states=None,
+                attention_mask=None,
+                context=None,
+                mask=None,
+            ):
                 """
                 The attention is similar to the original implementation of the LDM
                 CrossAttention class, except for some modifications to the attention scores
@@ -1523,22 +1824,33 @@ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, ma
                 context = context if is_cross else x
                 k = self.to_k(context)
                 v = self.to_v(context)
-                q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+                q, k, v = map(
+                    lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h),
+                    (q, k, v),
+                )
 
-                sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
+                sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
 
                 if mask is not None:
-                    mask = rearrange(mask, 'b ... -> b (...)')
+                    mask = rearrange(mask, "b ... -> b (...)")
                     max_neg_value = -torch.finfo(sim.dtype).max
-                    mask = repeat(mask, 'b j -> (b h) () j', h=h)
+                    mask = repeat(mask, "b j -> (b h) () j", h=h)
                     mask = mask[:, None, :].repeat(h, 1, 1)
                     sim.masked_fill_(~mask, max_neg_value)
 
                 attn = sim.softmax(dim=-1)
                 # the only difference
                 out = editor(
-                    q, k, v, sim, attn, is_cross, place_in_unet,
-                    self.heads, scale=self.scale)
+                    q,
+                    k,
+                    v,
+                    sim,
+                    attn,
+                    is_cross,
+                    place_in_unet,
+                    self.heads,
+                    scale=self.scale,
+                )
 
                 return to_out(out)
 
@@ -1546,10 +1858,10 @@ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, ma
 
         def register_editor(net, count, place_in_unet):
             for name, subnet in net.named_children():
-                if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
+                if net.__class__.__name__ == "Attention":  # spatial Transformer layer
                     net.forward = ca_forward(net, place_in_unet)
                     return count + 1
-                elif hasattr(net, 'children'):
+                elif hasattr(net, "children"):
                     count = register_editor(subnet, count, place_in_unet)
             return count
 
@@ -1577,12 +1889,12 @@ def __call__(
         padding_mask_crop: Optional[int] = None,
         strength: float = 0.9999,
         AAS: bool = True,  # AE parameter
-        rm_guidance_scale: float = 7.0, # AE parameter
+        rm_guidance_scale: float = 7.0,  # AE parameter
         ss_steps: int = 9,  # AE parameter
-        ss_scale: float = 0.3, # AE parameter
-        AAS_start_step: int = 0, # AE parameter
-        AAS_start_layer: int = 34, # AE parameter
-        AAS_end_layer: int = 70, # AE parameter
+        ss_scale: float = 0.3,  # AE parameter
+        AAS_start_step: int = 0,  # AE parameter
+        AAS_start_layer: int = 34,  # AE parameter
+        AAS_end_layer: int = 70,  # AE parameter
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
         denoising_start: Optional[float] = None,
@@ -1592,7 +1904,8 @@ def __call__(
         negative_prompt_2: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        generator: Optional[Union[torch.Generator,
+                                  List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -1613,7 +1926,8 @@ def __call__(
         aesthetic_score: float = 6.0,
         negative_aesthetic_score: float = 2.5,
         clip_skip: Optional[int] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end: Optional[Callable[[
+            int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         **kwargs,
     ):
@@ -1843,7 +2157,7 @@ def __call__(
         self._denoising_start = denoising_start
         self._interrupt = False
 
-        ########### AE parameters
+        # AE parameters
         self._num_timesteps = num_inference_steps
         self._rm_guidance_scale = rm_guidance_scale
         self._AAS = AAS
@@ -1866,7 +2180,9 @@ def __call__(
 
         # 3. Encode input prompt
         text_encoder_lora_scale = (
-            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+            self.cross_attention_kwargs.get("scale", None)
+            if self.cross_attention_kwargs is not None
+            else None
         )
 
         (
@@ -1894,27 +2210,39 @@ def __call__(
         def denoising_value_valid(dnv):
             return isinstance(dnv, float) and 0 < dnv < 1
 
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, timesteps
+        )
         timesteps, num_inference_steps = self.get_timesteps(
             num_inference_steps,
             strength,
             device,
-            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+            denoising_start=(
+                self.denoising_start
+                if denoising_value_valid(self.denoising_start)
+                else None
+            ),
         )
-        # check that number of inference steps is not < 1 - as this doesn't make sense
+        # check that number of inference steps is not < 1 - as this doesn't
+        # make sense
         if num_inference_steps < 1:
             raise ValueError(
                 f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                 f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
             )
-        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
-        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+        # at which timestep to set the initial noise (n.b. 50% if strength is
+        # 0.5)
+        latent_timestep = timesteps[:1].repeat(
+            batch_size * num_images_per_prompt)
+        # create a boolean to check if the strength is set to 1. if so then
+        # initialise the latents with pure noise
         is_strength_max = strength == 1.0
 
         # 5. Preprocess mask and image
         if padding_mask_crop is not None:
-            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            crops_coords = self.mask_processor.get_crop_region(
+                mask_image, width, height, pad=padding_mask_crop
+            )
             resize_mode = "fill"
         else:
             crops_coords = None
@@ -1922,12 +2250,20 @@ def denoising_value_valid(dnv):
 
         original_image = image
         init_image = self.image_processor.preprocess(
-            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+            image,
+            height=height,
+            width=width,
+            crops_coords=crops_coords,
+            resize_mode=resize_mode,
         )
         init_image = init_image.to(dtype=torch.float32)
 
         mask = self.mask_processor.preprocess(
-            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+            mask_image,
+            height=height,
+            width=width,
+            resize_mode=resize_mode,
+            crops_coords=crops_coords,
         )
 
         if masked_image_latents is not None:
@@ -1984,7 +2320,10 @@ def denoising_value_valid(dnv):
             # default case for runwayml/stable-diffusion-inpainting
             num_channels_mask = mask.shape[1]
             num_channels_masked_image = masked_image_latents.shape[1]
-            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+            if (
+                num_channels_latents + num_channels_mask + num_channels_masked_image
+                != self.unet.config.in_channels
+            ):
                 raise ValueError(
                     f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                     f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
@@ -1999,7 +2338,8 @@ def denoising_value_valid(dnv):
         # 8.1 Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be
+        # moved out of the pipeline
         height, width = latents.shape[-2:]
         height = height * self.vae_scale_factor
         width = width * self.vae_scale_factor
@@ -2031,19 +2371,25 @@ def denoising_value_valid(dnv):
             dtype=prompt_embeds.dtype,
             text_encoder_projection_dim=text_encoder_projection_dim,
         )
-        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+        add_time_ids = add_time_ids.repeat(
+            batch_size * num_images_per_prompt, 1)
 
         if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
-            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+            prompt_embeds = torch.cat(
+                [negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat(
+                [negative_pooled_prompt_embeds, add_text_embeds], dim=0
+            )
+            add_neg_time_ids = add_neg_time_ids.repeat(
+                batch_size * num_images_per_prompt, 1
+            )
             add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
 
-
         ###########
         if self.do_self_attention_redirection_guidance:
             prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)
-            add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0)
+            add_text_embeds = torch.cat(
+                [add_text_embeds, add_text_embeds], dim=0)
             add_neg_time_ids = add_neg_time_ids.repeat(2, 1)
             add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
         ############
@@ -2064,13 +2410,24 @@ def denoising_value_valid(dnv):
         # apply AAS to modify the attention module
         if self.do_self_attention_redirection_guidance:
             self._AAS_end_step = int(strength * self._num_timesteps)
-            layer_idx=list(range(self._AAS_start_layer, self._AAS_end_layer))
-            editor = AAS_XL(self._AAS_start_step, self._AAS_end_step, self._AAS_start_layer, self._AAS_end_layer, layer_idx= layer_idx, mask=mask_image,model_type="SDXL",ss_steps=self._ss_steps,ss_scale=self._ss_scale)
+            layer_idx = list(range(self._AAS_start_layer, self._AAS_end_layer))
+            editor = AAS_XL(
+                self._AAS_start_step,
+                self._AAS_end_step,
+                self._AAS_start_layer,
+                self._AAS_end_layer,
+                layer_idx=layer_idx,
+                mask=mask_image,
+                model_type="SDXL",
+                ss_steps=self._ss_steps,
+                ss_scale=self._ss_scale,
+            )
             self.regiter_attention_editor_diffusers(self.unet, editor)
 
-
         # 11. Denoising loop
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        num_warmup_steps = max(
+            len(timesteps) - num_inference_steps * self.scheduler.order, 0
+        )
 
         if (
             self.denoising_end is not None
@@ -2083,20 +2440,26 @@ def denoising_value_valid(dnv):
                 f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
                 + f" {self.denoising_end} when using type float."
             )
-        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+        elif self.denoising_end is not None and denoising_value_valid(
+            self.denoising_end
+        ):
             discrete_timestep_cutoff = int(
                 round(
                     self.scheduler.config.num_train_timesteps
                     - (self.denoising_end * self.scheduler.config.num_train_timesteps)
                 )
             )
-            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            num_inference_steps = len(
+                list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
+            )
             timesteps = timesteps[:num_inference_steps]
 
         # 11.1 Optionally get Guidance Scale Embedding
         timestep_cond = None
         if self.unet.config.time_cond_proj_dim is not None:
-            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
+                batch_size * num_images_per_prompt
+            )
             timestep_cond = self.get_guidance_scale_embedding(
                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
             ).to(device=device, dtype=latents.dtype)
@@ -2109,21 +2472,37 @@ def denoising_value_valid(dnv):
                     continue
 
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = (
+                    torch.cat([latents] * 2)
+                    if self.do_classifier_free_guidance
+                    else latents
+                )
 
-                #removal guidance
-                latent_model_input = torch.cat([latents]*2) if self.do_self_attention_redirection_guidance else latents #CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not
-                #latent_model_input_rm = torch.cat([latents]*2) if self.do_self_attention_redirection_guidance else latents
-                
-                # concat latents, mask, masked_image_latents in the channel dimension
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-                #latent_model_input = self.scheduler.scale_model_input(latent_model_input_rm, t)
+                # removal guidance: duplicate the latents when SARG is active.
+                # CFG is disabled when SARG is used, and experiments showed
+                # little difference whether or not CFG is applied.
+                if self.do_self_attention_redirection_guidance:
+                    latent_model_input = torch.cat([latents] * 2)
+                # latent_model_input_rm = torch.cat([latents]*2) if self.do_self_attention_redirection_guidance else latents
+
+                # concat latents, mask, masked_image_latents in the channel
+                # dimension
+                latent_model_input = self.scheduler.scale_model_input(
+                    latent_model_input, t
+                )
+                # latent_model_input = self.scheduler.scale_model_input(latent_model_input_rm, t)
 
                 if num_channels_unet == 9:
-                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+                    latent_model_input = torch.cat(
+                        [latent_model_input, mask, masked_image_latents], dim=1
+                    )
 
                 # predict the noise residual
-                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                added_cond_kwargs = {
+                    "text_embeds": add_text_embeds,
+                    "time_ids": add_time_ids,
+                }
                 if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                     added_cond_kwargs["image_embeds"] = image_embeds
                 noise_pred = self.unet(
@@ -2145,16 +2524,22 @@ def denoising_value_valid(dnv):
                 # perform guidance
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (
+                        noise_pred_text - noise_pred_uncond
+                    )
 
                 if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
-
-
+                    noise_pred = rescale_noise_cfg(
+                        noise_pred,
+                        noise_pred_text,
+                        guidance_rescale=self.guidance_rescale,
+                    )
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
+                )[0]
 
                 if num_channels_unet == 4:
                     init_latents_proper = image_latents
@@ -2166,28 +2551,42 @@ def denoising_value_valid(dnv):
                     if i < len(timesteps) - 1:
                         noise_timestep = timesteps[i + 1]
                         init_latents_proper = self.scheduler.add_noise(
-                            init_latents_proper, noise, torch.tensor([noise_timestep])
+                            init_latents_proper, noise, torch.tensor(
+                                [noise_timestep])
                         )
 
-                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
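+                    # outside the mask keep the re-noised original latents;
+                    # inside the mask keep the freshly denoised latents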
+                    latents = (
+                        1 - init_mask
+                    ) * init_latents_proper + init_mask * latents
 
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
                     for k in callback_on_step_end_tensor_inputs:
                         callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                    callback_outputs = callback_on_step_end(
+                        self, i, t, callback_kwargs)
 
                     latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+                    prompt_embeds = callback_outputs.pop(
+                        "prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop(
+                        "negative_prompt_embeds", negative_prompt_embeds
+                    )
+                    add_text_embeds = callback_outputs.pop(
+                        "add_text_embeds", add_text_embeds
+                    )
                     negative_pooled_prompt_embeds = callback_outputs.pop(
                         "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                     )
-                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
-                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+                    add_time_ids = callback_outputs.pop(
+                        "add_time_ids", add_time_ids)
+                    add_neg_time_ids = callback_outputs.pop(
+                        "add_neg_time_ids", add_neg_time_ids
+                    )
                     mask = callback_outputs.pop("mask", mask)
-                    masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+                    masked_image_latents = callback_outputs.pop(
+                        "masked_image_latents", masked_image_latents
+                    )
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -2201,25 +2600,42 @@ def denoising_value_valid(dnv):
 
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
-            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+            needs_upcasting = (
+                self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+            )
             latents = latents[-1:]
 
             if needs_upcasting:
                 self.upcast_vae()
-                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+                latents = latents.to(
+                    next(iter(self.vae.post_quant_conv.parameters())).dtype
+                )
 
             # unscale/denormalize the latents
             # denormalize with the mean and std if available and not None
-            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
-            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            has_latents_mean = (
+                hasattr(self.vae.config, "latents_mean")
+                and self.vae.config.latents_mean is not None
+            )
+            has_latents_std = (
+                hasattr(self.vae.config, "latents_std")
+                and self.vae.config.latents_std is not None
+            )
             if has_latents_mean and has_latents_std:
                 latents_mean = (
-                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                    torch.tensor(self.vae.config.latents_mean)
+                    .view(1, 4, 1, 1)
+                    .to(latents.device, latents.dtype)
                 )
                 latents_std = (
-                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                    torch.tensor(self.vae.config.latents_std)
+                    .view(1, 4, 1, 1)
+                    .to(latents.device, latents.dtype)
+                )
+                latents = (
+                    latents * latents_std / self.vae.config.scaling_factor
+                    + latents_mean
                 )
-                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
             else:
                 latents = latents / self.vae.config.scaling_factor
 
@@ -2235,10 +2651,16 @@ def denoising_value_valid(dnv):
         if self.watermark is not None:
             image = self.watermark.apply_watermark(image)
 
-        image = self.image_processor.postprocess(image, output_type=output_type)
+        image = self.image_processor.postprocess(
+            image, output_type=output_type)
 
         if padding_mask_crop is not None:
-            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+            image = [
+                self.image_processor.apply_overlay(
+                    mask_image, original_image, i, crops_coords
+                )
+                for i in image
+            ]
 
         # Offload all models
         self.maybe_free_model_hooks()
From 717e2c711963e587678736b9f572317b90ae235d Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Wed, 15 Jan 2025 17:04:41 +0800
Subject: [PATCH 3/7] make style and add example output
---
 examples/community/README.md                  |   92 +-
 ...ne_stable_diffusion_xl_attentive_eraser.py | 2285 +++++++++++++++++
 2 files changed, 2375 insertions(+), 2 deletions(-)
 mode change 100755 => 100644 examples/community/README.md
 create mode 100644 examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
diff --git a/examples/community/README.md b/examples/community/README.md
old mode 100755
new mode 100644
index c7c40c46ef2d..02777636d2d3
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -77,6 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
 PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
 | HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
 | [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
+| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3)|
 
 To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
 
@@ -4585,8 +4586,8 @@ image = pipe(
 ```
 
 |  |  |  |
-| ------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- |
-| Gradient                                                                                   | Input                                                                                   | Output                                                                                   |
+| -------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ |
+| Gradient                                                                                     | Input                                                                                     | Output                                                                                     |
 
 A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.
 
@@ -4634,6 +4635,93 @@ make_image_grid(image, rows=1, cols=len(image))
 # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
 ```
 
+### Stable Diffusion XL Attentive Eraser Pipeline
+
+
+**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. This pipeline uses **self-attention redirection guidance** to modify the model’s self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. If you are interested in more detailed information and have any questions, please refer to the [paper](https://arxiv.org/abs/2412.12974) and [official implementation](https://github.com/Anonym0u3/AttentiveEraser).
+
+#### Key Features
+
+- **Tuning-Free**: No additional training is required, making it easy to integrate and use.
+- **Flexible Mask Support**: Works with different types of masks for targeted object removal.
+- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.
+
+#### How to Use
+To use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:
+```py
+import torch
+from diffusers import DDIMScheduler, DiffusionPipeline
+from diffusers.utils import load_image
+import torch.nn.functional as F
+from torchvision.transforms.functional import to_tensor, gaussian_blur
+
+dtype = torch.float16
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 
+
+scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+pipeline = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
+    scheduler=scheduler,
+    variant="fp16",
+    use_safetensors=True,
+    torch_dtype=dtype,
+).to(device)
+
+
+def preprocess_image(image_path, device):
+    image = to_tensor((load_image(image_path)))
+    image = image.unsqueeze_(0).float() * 2 - 1  # [0,1] --> [-1,1]
+    if image.shape[1] != 3:
+        image = image.expand(-1, 3, -1, -1)
+    image = F.interpolate(image, (1024, 1024))
+    image = image.to(dtype).to(device)
+    return image
+
+def preprocess_mask(mask_path, device):
+    mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
+    mask = mask.unsqueeze_(0).float()  # 0 or 1
+    mask = F.interpolate(mask, (1024, 1024))
+    mask = gaussian_blur(mask, kernel_size=(77, 77))
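+    # blurring and re-binarizing at a low threshold slightly dilates the mask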
+    mask[mask < 0.1] = 0
+    mask[mask >= 0.1] = 1
+    mask = mask.to(dtype).to(device)
+    return mask
+
+prompt = "" # Set prompt to null
+seed=123 
+generator = torch.Generator(device=device).manual_seed(seed)
+source_image_path = "./path-to-image.png"
+mask_path = "./path-to-mask.png"
+source_image = preprocess_image(source_image_path, device)
+mask = preprocess_mask(mask_path, device)
+
+image = pipeline(
+    prompt=prompt, 
+    image=source_image,
+    mask_image=mask,
+    height=1024,
+    width=1024,
+    AAS=True,  # enable AAS
+    strength=0.8,  # inpainting strength
+    rm_guidance_scale=9,  # removal guidance scale
+    ss_steps=9,  # similarity suppression steps
+    ss_scale=0.3,  # similarity suppression scale
+    AAS_start_step=0,  # AAS start step
+    AAS_start_layer=34,  # AAS start layer
+    AAS_end_layer=70,  # AAS end layer
+    num_inference_steps=50,  # number of inference steps; AAS_end_step = int(strength * num_inference_steps)
+    generator=generator,
+    guidance_scale=1,
+).images[0]
+image.save('./removed_img.png')
+print("Object removal completed")
+```
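+
+Since the pipeline also accepts coarse masks such as bounding boxes, the mask can be built directly as a tensor instead of being loaded from disk. The following is a minimal sketch (the box coordinates are illustrative placeholders, and `dtype`/`device` are reused from the snippet above):
+
+```py
+def make_box_mask(x0, y0, x1, y1, size=1024):
+    # 1 marks the region to erase, 0 the region to keep
+    box_mask = torch.zeros(1, 1, size, size, dtype=dtype, device=device)
+    box_mask[:, :, y0:y1, x0:x1] = 1.0
+    return box_mask
+
+mask = make_box_mask(300, 400, 700, 800)
+```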
+
+
+
 # Perturbed-Attention Guidance
 
 [Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
new file mode 100644
index 000000000000..cebd80c8bfa1
--- /dev/null
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -0,0 +1,2249 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from tqdm import tqdm
+import numpy as np
+import PIL.Image
+from PIL import Image
+import torch
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    is_invisible_watermark_available,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+
+
+if is_invisible_watermark_available():
+    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+from einops import rearrange, repeat
+import torch.nn as nn
+import torch.nn.functional as F
+import os
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import StableDiffusionXLInpaintPipeline
+        >>> from diffusers.utils import load_image
+
+        >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+        ...     "stabilityai/stable-diffusion-xl-base-1.0",
+        ...     torch_dtype=torch.float16,
+        ...     variant="fp16",
+        ...     use_safetensors=True,
+        ... )
+        >>> pipe.to("cuda")
+
+        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+        >>> init_image = load_image(img_url).convert("RGB")
+        >>> mask_image = load_image(mask_url).convert("RGB")
+
+        >>> prompt = "A majestic tiger sitting on a bench"
+        >>> image = pipe(
+        ...     prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
+        ... ).images[0]
+        ```
+"""
+
+class AttentionBase:
+    def __init__(self):
+        self.cur_step = 0
+        self.num_att_layers = -1
+        self.cur_att_layer = 0
+
+    def after_step(self):
+        pass
+
+    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+        self.cur_att_layer += 1
+        if self.cur_att_layer == self.num_att_layers:
+            self.cur_att_layer = 0
+            self.cur_step += 1
+            # after step
+            self.after_step()
+        return out
+
+    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+        out = torch.einsum('b i j, b j d -> b i d', attn, v)
+        out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
+        return out
+
+    def reset(self):
+        self.cur_step = 0
+        self.cur_att_layer = 0
+
+class AAS_XL(AttentionBase):
+    MODEL_TYPE = {
+        "SD": 16,
+        "SDXL": 70
+    }
+    def __init__(self, start_step=4, end_step=50, start_layer=10, end_layer=16, layer_idx=None, step_idx=None, total_steps=50, mask=None, model_type="SD", ss_steps=9, ss_scale=1.0):
+        """
+        Args:
+            start_step: the step to start AAS
+            start_layer: the layer to start AAS
+            layer_idx: list of the layers to apply AAS
+            step_idx: list the steps to apply AAS
+            total_steps: the total number of steps
+            mask: source mask with shape (h, w)
+            model_type: the model type, SD or SDXL
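+            end_step: the step to end AAS
+            end_layer: the layer to end AAS
+            ss_steps: the number of initial denoising steps to apply similarity suppression
+            ss_scale: the scale applied to the in-mask similarity while suppression is active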
+        """
+        super().__init__()
+        self.total_steps = total_steps
+        self.total_layers = self.MODEL_TYPE.get(model_type, 16)
+        self.start_step = start_step
+        self.end_step = end_step
+        self.start_layer = start_layer
+        self.end_layer = end_layer
+        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, end_layer))
+        self.step_idx = step_idx if step_idx is not None else list(range(start_step, end_step))
+        self.mask = mask  # mask with shape (1, 1 ,h, w)
+        self.ss_steps = ss_steps
+        self.ss_scale = ss_scale
+        print("AAS at denoising steps: ", self.step_idx)
+        print("AAS at U-Net layers: ", self.layer_idx)
+        print("start AAS")
+        self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
+        self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
+        self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
+        self.mask_128 = F.max_pool2d(mask, (1024 // 128, 1024 // 128)).round().squeeze().squeeze()
+
+    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, is_mask_attn, mask, **kwargs):
+        B = q.shape[0] // num_heads
+        if is_mask_attn:
+            mask_flatten = mask.flatten(0)
+            if self.cur_step <= self.ss_steps:
+                # background branch: in-mask keys are filled with -inf, so
+                # queries attend only to background content
+                sim_bg = sim + mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+
+                # object branch: the similarity is scaled down before the same
+                # in-mask suppression is applied
+                sim_fg = self.ss_scale * sim
+                sim_fg += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+                sim = torch.cat([sim_fg, sim_bg], dim=0)
+            else:
+                sim += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+
+        attn = sim.softmax(-1)
+        if len(attn) == 2 * len(v):
+            v = torch.cat([v] * 2)
+        out = torch.einsum("h i j, h j d -> h i d", attn, v)
+        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
+        return out
+
+    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+        """
+        Attention forward function
+        """
+        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
+            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+        B = q.shape[0] // num_heads // 2
+        H = W = int(np.sqrt(q.shape[1]))
+        if H == 16:
+            mask = self.mask_16.to(sim.device)
+        elif H == 32:
+            mask = self.mask_32.to(sim.device)
+        elif H == 64:
+            mask = self.mask_64.to(sim.device)
+        else:
+            mask = self.mask_128.to(sim.device)
+
+        q_wo, q_w = q.chunk(2)
+        k_wo, k_w = k.chunk(2)
+        v_wo, v_w = v.chunk(2)
+        sim_wo, sim_w = sim.chunk(2)
+        attn_wo, attn_w = attn.chunk(2)
+
+        out_source = self.attn_batch(q_wo, k_wo, v_wo, sim_wo, attn_wo, is_cross, place_in_unet, num_heads, is_mask_attn=False, mask=None, **kwargs)
+        out_target = self.attn_batch(q_w, k_w, v_w, sim_w, attn_w, is_cross, place_in_unet, num_heads, is_mask_attn=True, mask=mask, **kwargs)
+
+        if self.mask is not None and out_target.shape[0] == 2:
+            # use the foreground branch inside the mask and the background
+            # branch outside it
+            out_target_fg, out_target_bg = out_target.chunk(2, 0)
+            mask = mask.reshape(-1, 1)  # (hw, 1)
+            out_target = out_target_fg * mask + out_target_bg * (1 - mask)
+
+        out = torch.cat([out_source, out_target], dim=0)
+        return out
+    
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+def mask_pil_to_torch(mask, height, width):
+    # preprocess mask
+    if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+        mask = [mask]
+
+    if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+        mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+        mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+        mask = mask.astype(np.float32) / 255.0
+    elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+        mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+    mask = torch.from_numpy(mask)
+    return mask
+
+
+def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+    """
+    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+    ``image`` and ``1`` for the ``mask``.
+
+    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+    Args:
+        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+    Raises:
+        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+            (or the other way around).
+
+    Returns:
+        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+            dimensions: ``batch x channels x height x width``.
+    """
+
+    # TODO(Yiyi) - need to clean this up later
+    deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
+    deprecate(
+        "prepare_mask_and_masked_image",
+        "0.30.0",
+        deprecation_message,
+    )
+    if image is None:
+        raise ValueError("`image` input cannot be undefined.")
+
+    if mask is None:
+        raise ValueError("`mask_image` input cannot be undefined.")
+
+    if isinstance(image, torch.Tensor):
+        if not isinstance(mask, torch.Tensor):
+            mask = mask_pil_to_torch(mask, height, width)
+
+        if image.ndim == 3:
+            image = image.unsqueeze(0)
+
+        # Batch and add channel dim for single mask
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0).unsqueeze(0)
+
+        # Batch single mask or add channel dim
+        if mask.ndim == 3:
+            # Single batched mask, no channel dim or single mask not batched but channel dim
+            if mask.shape[0] == 1:
+                mask = mask.unsqueeze(0)
+
+            # Batched masks no channel dim
+            else:
+                mask = mask.unsqueeze(1)
+
+        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+        # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+        # Check image is in [-1, 1]
+        # if image.min() < -1 or image.max() > 1:
+        #    raise ValueError("Image should be in [-1, 1] range")
+
+        # Check mask is in [0, 1]
+        if mask.min() < 0 or mask.max() > 1:
+            raise ValueError("Mask should be in [0, 1] range")
+
+        # Binarize mask
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+
+        # Image as float32
+        image = image.to(dtype=torch.float32)
+    elif isinstance(mask, torch.Tensor):
+        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+    else:
+        # preprocess image
+        if isinstance(image, (PIL.Image.Image, np.ndarray)):
+            image = [image]
+        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+            # resize all images w.r.t. the passed height and width
+            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+            image = [np.array(i.convert("RGB"))[None, :] for i in image]
+            image = np.concatenate(image, axis=0)
+        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+            image = np.concatenate([i[None, :] for i in image], axis=0)
+
+        image = image.transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+        mask = mask_pil_to_torch(mask, height, width)
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+
+    if image.shape[1] == 4:
+        # images are already in latent space and thus can't be masked; set
+        # masked_image to None. We assume the checkpoint is not an inpainting
+        # checkpoint. TODO(Yiyi) - need to clean this up later
+        masked_image = None
+    else:
+        masked_image = image * (mask < 0.5)
+
+    # n.b. ensure backwards compatibility as old function does not return image
+    if return_image:
+        return mask, masked_image, image
+
+    return mask, masked_image
+
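+# a minimal usage sketch (hypothetical 1024x1024 PIL inputs `init_image` / `mask_image`):
+#   mask, masked_image = prepare_mask_and_masked_image(init_image, mask_image, 1024, 1024)
+# yields a binarized (1, 1, 1024, 1024) mask and the [-1, 1]-normalized image with
+# the masked region zeroed out
+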
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+            must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+class StableDiffusionXL_AE_Pipeline(
+    DiffusionPipeline,
+    StableDiffusionMixin,
+    TextualInversionLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    FromSingleFileMixin,
+    IPAdapterMixin,
+):
+    r"""
+    Pipeline for object removal using Stable Diffusion XL.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion XL uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+            Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+            specifically the
+            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+            variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`CLIPTokenizer`):
+            Second Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        requires_aesthetics_score (`bool`, *optional*, defaults to `False`):
+            Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the
+            config of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
+            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+            `stabilityai/stable-diffusion-xl-base-1-0`.
+        add_watermarker (`bool`, *optional*):
+            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+            watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+            watermarker will be used.
+    """
+
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "image_encoder",
+        "feature_extractor",
+    ]
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+        "negative_pooled_prompt_embeds",
+        "add_neg_time_ids",
+        "mask",
+        "masked_image_latents",
+    ]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        text_encoder_2: CLIPTextModelWithProjection,
+        tokenizer: CLIPTokenizer,
+        tokenizer_2: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: KarrasDiffusionSchedulers,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
+        requires_aesthetics_score: bool = False,
+        force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
+    ):
+        super().__init__()
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
+            tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
+            unet=unet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+            scheduler=scheduler,
+        )
+        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.mask_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+        )
+
+        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+        if add_watermarker:
+            self.watermark = StableDiffusionXLWatermarker()
+        else:
+            self.watermark = None
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            repeat_dims = [1]
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                else:
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt: str,
+        prompt_2: Optional[str] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: Optional[str] = None,
+        negative_prompt_2: Optional[str] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+        """
+        device = device or self._execution_device
+
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None:
+                if not USE_PEFT_BACKEND:
+                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+                else:
+                    scale_lora_layers(self.text_encoder, lora_scale)
+
+            if self.text_encoder_2 is not None:
+                if not USE_PEFT_BACKEND:
+                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+                else:
+                    scale_lora_layers(self.text_encoder_2, lora_scale)
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # Define tokenizers and text encoders
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+
+        if prompt_embeds is None:
+            prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+            # textual inversion: process multi-vector tokens if necessary
+            prompt_embeds_list = []
+            prompts = [prompt, prompt_2]
+            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+                if isinstance(self, TextualInversionLoaderMixin):
+                    prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+                text_inputs = tokenizer(
+                    prompt,
+                    padding="max_length",
+                    max_length=tokenizer.model_max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                text_input_ids = text_inputs.input_ids
+                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                    text_input_ids, untruncated_ids
+                ):
+                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                    logger.warning(
+                        "The following part of your input was truncated because CLIP can only handle sequences up to"
+                        f" {tokenizer.model_max_length} tokens: {removed_text}"
+                    )
+
+                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+                # We are only ALWAYS interested in the pooled output of the final text encoder
+                pooled_prompt_embeds = prompt_embeds[0]
+                if clip_skip is None:
+                    prompt_embeds = prompt_embeds.hidden_states[-2]
+                else:
+                    # "2" because SDXL always indexes from the penultimate layer.
+                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+                prompt_embeds_list.append(prompt_embeds)
+
+            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+        # get unconditional embeddings for classifier free guidance
+        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+        elif do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
+            uncond_tokens: List[str]
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = [negative_prompt, negative_prompt_2]
+
+            negative_prompt_embeds_list = []
+            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+                if isinstance(self, TextualInversionLoaderMixin):
+                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+                max_length = prompt_embeds.shape[1]
+                uncond_input = tokenizer(
+                    negative_prompt,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                negative_prompt_embeds = text_encoder(
+                    uncond_input.input_ids.to(device),
+                    output_hidden_states=True,
+                )
+                # We are only ALWAYS interested in the pooled output of the final text encoder
+                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+                negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+        if self.text_encoder_2 is not None:
+            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+        else:
+            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            if self.text_encoder_2 is not None:
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+            else:
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+            bs_embed * num_images_per_prompt, -1
+        )
+        if do_classifier_free_guidance:
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+                bs_embed * num_images_per_prompt, -1
+            )
+
+        if self.text_encoder is not None:
+            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
+        if self.text_encoder_2 is not None:
+            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def check_inputs(
+        self,
+        prompt,
+        prompt_2,
+        image,
+        mask_image,
+        height,
+        width,
+        strength,
+        callback_steps,
+        output_type,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        padding_mask_crop=None,
+    ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+        if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+
+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+        image=None,
+        timestep=None,
+        is_strength_max=True,
+        add_noise=True,
+        return_noise=False,
+        return_image_latents=False,
+    ):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if (image is None or timestep is None) and not is_strength_max:
+            raise ValueError(
+                "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+                "However, either the image or the noise timestep has not been provided."
+            )
+
+        if image.shape[1] == 4:
+            image_latents = image.to(device=device, dtype=dtype)
+            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+        elif return_image_latents or (latents is None and not is_strength_max):
+            image = image.to(device=device, dtype=dtype)
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+        if latents is None and add_noise:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            # if strength is 1, initialize the latents to pure noise; otherwise to image + noise
+            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+            # if pure noise, scale the initial latents by the scheduler's init sigma
+            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+        elif add_noise:
+            noise = latents.to(device)
+            latents = noise * self.scheduler.init_noise_sigma
+        else:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = image_latents.to(device)
+
+        outputs = (latents,)
+
+        if return_noise:
+            outputs += (noise,)
+
+        if return_image_latents:
+            outputs += (image_latents,)
+
+        return outputs
+
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        dtype = image.dtype
+        if self.vae.config.force_upcast:
+            image = image.float()
+            self.vae.to(dtype=torch.float32)
+
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        if self.vae.config.force_upcast:
+            self.vae.to(dtype)
+
+        image_latents = image_latents.to(dtype)
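+        # scale into the UNet's latent space; for the SDXL VAE, config.scaling_factor is 0.13025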
+        image_latents = self.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
+    def prepare_mask_latents(
+        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+    ):
+        # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
+        # downsample with max pooling instead of interpolation so that any masked
+        # pixel inside an 8x8 patch marks the whole latent location as masked
+        # (assumes a VAE scale factor of 8)
+        mask = torch.nn.functional.max_pool2d(mask, (8, 8)).round()
+        mask = mask.to(device=device, dtype=dtype)
+
+        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    " of masks that you pass is divisible by the total requested batch size."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+        if masked_image is not None and masked_image.shape[1] == 4:
+            masked_image_latents = masked_image
+        else:
+            masked_image_latents = None
+
+        if masked_image is not None:
+            if masked_image_latents is None:
+                masked_image = masked_image.to(device=device, dtype=dtype)
+                masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+            if masked_image_latents.shape[0] < batch_size:
+                if not batch_size % masked_image_latents.shape[0] == 0:
+                    raise ValueError(
+                        "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                        f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                        " Make sure the number of images that you pass is divisible by the total requested batch size."
+                    )
+                masked_image_latents = masked_image_latents.repeat(
+                    batch_size // masked_image_latents.shape[0], 1, 1, 1
+                )
+
+            masked_image_latents = (
+                torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+            )
+
+            # aligning device to prevent device errors when concating it with the latent model input
+            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+        return mask, masked_image_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+        # get the original timestep using init_timestep
+        if denoising_start is None:
+            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+            t_start = max(num_inference_steps - init_timestep, 0)
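+            # e.g. num_inference_steps=50, strength=0.8 -> init_timestep=40, t_start=10,
+            # so only the last 40 of the 50 scheduled steps are actually run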
+        else:
+            t_start = 0
+
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+        # Strength is irrelevant if we directly request a timestep to start at;
+        # that is, strength is determined by the denoising_start instead.
+        if denoising_start is not None:
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_start * self.scheduler.config.num_train_timesteps)
+                )
+            )
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+            if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+                # if the scheduler is a 2nd order scheduler we might have to do +1
+                # because `num_inference_steps` might be even given that every timestep
+                # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+                # mean that we cut the timesteps in the middle of the denoising step
+                # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+                num_inference_steps = num_inference_steps + 1
+
+            # because t_n+1 >= t_n, we slice the timesteps starting from the end
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps
+
+        return timesteps, num_inference_steps - t_start
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+    def _get_add_time_ids(
+        self,
+        original_size,
+        crops_coords_top_left,
+        target_size,
+        aesthetic_score,
+        negative_aesthetic_score,
+        negative_original_size,
+        negative_crops_coords_top_left,
+        negative_target_size,
+        dtype,
+        text_encoder_projection_dim=None,
+    ):
+        if self.config.requires_aesthetics_score:
+            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+            add_neg_time_ids = list(
+                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+            )
+        else:
+            add_time_ids = list(original_size + crops_coords_top_left + target_size)
+            add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
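+        # e.g. a 1024x1024 generation with no crop and requires_aesthetics_score=False
+        # gives add_time_ids == [1024, 1024, 0, 0, 1024, 1024]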
+
+        passed_add_embed_dim = (
+            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+        )
+        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+        if (
+            expected_add_embed_dim > passed_add_embed_dim
+            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+        ):
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+            )
+        elif (
+            expected_add_embed_dim < passed_add_embed_dim
+            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+        ):
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+            )
+        elif expected_add_embed_dim != passed_add_embed_dim:
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+            )
+
+        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+        return add_time_ids, add_neg_time_ids
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                Guidance scale values at which to generate embedding vectors
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    @property
+    def do_self_attention_redirection_guidance(self):
+        # SARG: self-attention redirection guidance is active when the removal
+        # guidance scale exceeds 1 and AAS (attention activation and suppression) is enabled
+        return self._rm_guidance_scale > 1 and self._AAS
+
+    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        # CFG is disabled while SARG is in use; experiments showed little difference
+        # in the result whether or not CFG was applied in that case
+        return (
+            self._guidance_scale > 1
+            and self.unet.config.time_cond_proj_dim is None
+            and not self.do_self_attention_redirection_guidance
+        )
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def denoising_end(self):
+        return self._denoising_end
+
+    @property
+    def denoising_start(self):
+        return self._denoising_start
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    def image2latent(self, image: torch.Tensor, generator: torch.Generator):
+        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        if isinstance(image, PIL.Image.Image):
+            image = np.array(image)
+            image = torch.from_numpy(image).float() / 127.5 - 1
+            image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
+        # input image density range [-1, 1]
+        latents = self._encode_vae_image(image, generator)
+        return latents
+
+    def next_step(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: int,
+        x: torch.FloatTensor,
+        eta=0.,
+        verbose=False
+    ):
+        """
+        Inverse sampling for DDIM Inversion
+        """
+        if verbose:
+            print("timestep: ", timestep)
+        next_step = timestep
+        timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
+        beta_prod_t = 1 - alpha_prod_t
+        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
+        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
+        return x_next, pred_x0
+
+    @torch.no_grad()
+    def invert(
+        self,
+        image: torch.Tensor,
+        prompt,
+        num_inference_steps=50,
+        eta=0.0,
+        original_size: Tuple[int, int] = None,
+        target_size: Tuple[int, int] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+        aesthetic_score: float = 6.0,
+        negative_aesthetic_score: float = 2.5,
+        return_intermediates=False,
+        **kwds):
+        """
+        invert a real image into noise map with determinisc DDIM inversion
+        """
+        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        batch_size = image.shape[0]
+        if isinstance(prompt, list):
+            if batch_size == 1:
+                image = image.expand(len(prompt), -1, -1, -1)
+        elif isinstance(prompt, str):
+            if batch_size > 1:
+                prompt = [prompt] * batch_size
+
+        # Define tokenizers and text encoders
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+
+        prompt_2 = prompt
+        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+        # textual inversion: process multi-vector tokens if necessary
+        prompt_embeds_list = []
+        prompts = [prompt, prompt_2]
+        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+            text_inputs = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            prompt_embeds = text_encoder(text_input_ids.to(DEVICE), output_hidden_states=True)
+
+            # We are only ever interested in the pooled output of the final text encoder
+            pooled_prompt_embeds = prompt_embeds[0]
+            prompt_embeds = prompt_embeds.hidden_states[-2]
+            prompt_embeds_list.append(prompt_embeds)
+        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=DEVICE)
+
+        # define initial latents
+        latents = self.image2latent(image, generator=None)
+
+        start_latents = latents
+        height, width = latents.shape[-2:]
+        height = height * self.vae_scale_factor
+        width = width * self.vae_scale_factor
+
+        original_size = (height, width)
+        target_size = (height, width)
+        negative_original_size = original_size
+        negative_target_size = target_size
+
+        add_text_embeds = pooled_prompt_embeds
+        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            aesthetic_score,
+            negative_aesthetic_score,
+            negative_original_size,
+            negative_crops_coords_top_left,
+            negative_target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+
+        add_time_ids = add_time_ids.repeat(batch_size, 1).to(DEVICE)
+
+        # iterative inversion sampling
+        self.scheduler.set_timesteps(num_inference_steps)
+        latents_list = [latents]
+        pred_x0_list = []
+        for i, t in enumerate(reversed(self.scheduler.timesteps)):
+            model_inputs = latents
+
+            # predict the noise
+            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+            noise_pred = self.unet(
+                model_inputs, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs
+            ).sample
+
+            # invert one step: x_{t-1} -> x_t (adding noise)
+            latents, pred_x0 = self.next_step(noise_pred, t, latents)
+            """             
+            if t >= 1 and t < 41:
+                latents, pred_x0 = self.next_step_degrade(noise_pred, t, latents, mask)
+            else:
+                latents, pred_x0 = self.next_step(noise_pred, t, latents) """
+            
+            latents_list.append(latents)
+            pred_x0_list.append(pred_x0)
+
+        if return_intermediates:
+            # return the intermediate latents during inversion
+            return latents, latents_list, pred_x0_list
+        return latents, start_latents
+
+    def opt(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: int,
+        x: torch.FloatTensor,
+    ):
+        """
+        predict the sampe the next step in the denoise process.
+        """
+        ref_noise = model_output[:1, :, :, :].expand(model_output.shape)
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+        beta_prod_t = 1 - alpha_prod_t
+        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+        x_opt = alpha_prod_t**0.5 * pred_x0 + (1 - alpha_prod_t)**0.5 * ref_noise
+        return x_opt, pred_x0
+
+    def regiter_attention_editor_diffusers(self, unet, editor: AttentionBase):
+        """
+        Register an attention editor with the Diffusers pipeline; adapted from [Prompt-to-Prompt]
+        """
+        def ca_forward(self, place_in_unet):
+            def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
+                """
+                Mirrors the original LDM CrossAttention forward, with the attention
+                computation routed through the registered editor
+                """
+                if encoder_hidden_states is not None:
+                    context = encoder_hidden_states
+                if attention_mask is not None:
+                    mask = attention_mask
+
+                if isinstance(self.to_out, nn.ModuleList):
+                    to_out = self.to_out[0]
+                else:
+                    to_out = self.to_out
+
+                h = self.heads
+                q = self.to_q(x)
+                is_cross = context is not None
+                context = context if is_cross else x
+                k = self.to_k(context)
+                v = self.to_v(context)
+                q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+                sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+                if mask is not None:
+                    mask = rearrange(mask, 'b ... -> b (...)')
+                    max_neg_value = -torch.finfo(sim.dtype).max
+                    mask = repeat(mask, 'b j -> (b h) () j', h=h)
+                    sim.masked_fill_(~mask, max_neg_value)
+
+                attn = sim.softmax(dim=-1)
+                # the only difference: delegate the attention computation to the editor
+                out = editor(
+                    q, k, v, sim, attn, is_cross, place_in_unet,
+                    self.heads, scale=self.scale)
+
+                return to_out(out)
+
+            return forward
+
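+        # recursively walk the given module tree; when an `Attention` module is
+        # reached, swap in the editor-aware forward defined above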
+        def register_editor(net, count, place_in_unet):
+            for name, subnet in net.named_children():
+                if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
+                    net.forward = ca_forward(net, place_in_unet)
+                    return count + 1
+                elif hasattr(net, 'children'):
+                    count = register_editor(subnet, count, place_in_unet)
+            return count
+
+        cross_att_count = 0
+        for net_name, net in unet.named_children():
+            if "down" in net_name:
+                cross_att_count += register_editor(net, 0, "down")
+            elif "mid" in net_name:
+                cross_att_count += register_editor(net, 0, "mid")
+            elif "up" in net_name:
+                cross_att_count += register_editor(net, 0, "up")
+        editor.num_att_layers = cross_att_count
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        image: PipelineImageInput = None,
+        mask_image: PipelineImageInput = None,
+        masked_image_latents: torch.FloatTensor = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
+        strength: float = 0.9999,
+        AAS: bool = True,  # AE parameter
+        rm_guidance_scale: float = 7.0, # AE parameter
+        ss_steps: int = 9,  # AE parameter
+        ss_scale: float = 0.3, # AE parameter
+        AAS_start_step: int = 0, # AE parameter
+        AAS_start_layer: int = 34, # AE parameter
+        AAS_end_layer: int = 70, # AE parameter
+        num_inference_steps: int = 50,
+        timesteps: List[int] = None,
+        denoising_start: Optional[float] = None,
+        denoising_end: Optional[float] = None,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
+        original_size: Tuple[int, int] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        target_size: Tuple[int, int] = None,
+        negative_original_size: Optional[Tuple[int, int]] = None,
+        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+        negative_target_size: Optional[Tuple[int, int]] = None,
+        aesthetic_score: float = 6.0,
+        negative_aesthetic_score: float = 2.5,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
+            image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+                be masked out with `mask_image` and repainted according to `prompt`.
+            mask_image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+                instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+                Anything below 512 pixels won't work well for
+                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+                and checkpoints that are not specifically fine-tuned on low resolutions.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+                Anything below 512 pixels won't work well for
+                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+                and checkpoints that are not specifically fine-tuned on low resolutions.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+                image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+                with the same aspect ratio as the image that contains all masked areas, and then expand that area based
+                on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+                resizing to the original image size for inpainting. This is useful when the masked area is small while
+                the image is large and contains information irrelevant to inpainting, such as background.
+            strength (`float`, *optional*, defaults to 0.9999):
+                Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+                between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+                `strength`. The number of denoising steps depends on the amount of noise initially added. When
+                `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+                iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+                portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
+                integer, the value of `strength` will be ignored.
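+            AAS (`bool`, *optional*, defaults to `True`):
+                Whether to apply the Attentive Eraser attention modification (AAS) via self-attention redirection
+                guidance during denoising.
+            rm_guidance_scale (`float`, *optional*, defaults to 7.0):
+                Guidance scale for self-attention redirection guidance (SARG). Higher values push the prediction
+                further toward the attention-redirected branch.
+            ss_steps (`int`, *optional*, defaults to 9):
+                Number of initial denoising steps during which the foreground similarity logits are scaled by
+                `ss_scale`.
+            ss_scale (`float`, *optional*, defaults to 0.3):
+                Scale applied to the foreground similarity logits during the first `ss_steps` steps.
+            AAS_start_step (`int`, *optional*, defaults to 0):
+                Denoising step at which AAS starts.
+            AAS_start_layer (`int`, *optional*, defaults to 34):
+                First U-Net attention layer to which AAS is applied.
+            AAS_end_layer (`int`, *optional*, defaults to 70):
+                Last U-Net attention layer (exclusive) to which AAS is applied.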
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            denoising_start (`float`, *optional*):
+                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+                strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+                is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+            denoising_end (`float`, *optional*):
+                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+                completed before it is intentionally prematurely terminated. As a result, the returned sample will
+                still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
+                denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
+                final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
+                forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
+                Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
+                if `do_classifier_free_guidance` is set to `True`.
+                If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+                explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a target image resolution. It should be the
+                same as `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            aesthetic_score (`float`, *optional*, defaults to 6.0):
+                Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+                Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+                Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+                simulate an aesthetic score of the generated image by influencing the negative text condition.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+            )
+
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 1. Check inputs
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            image,
+            mask_image,
+            height,
+            width,
+            strength,
+            callback_steps,
+            output_type,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
+            callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+        self._denoising_end = denoising_end
+        self._denoising_start = denoising_start
+        self._interrupt = False
+
+        ########### AE parameters
+        self._num_timesteps = num_inference_steps
+        self._rm_guidance_scale = rm_guidance_scale
+        self._AAS = AAS
+        self._ss_steps = ss_steps
+        self._ss_scale = ss_scale
+        self._AAS_start_step = AAS_start_step
+        self._AAS_start_layer = AAS_start_layer
+        self._AAS_end_layer = AAS_end_layer
+        ###########
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+            clip_skip=self.clip_skip,
+        )
+
+        # 4. set timesteps
+        def denoising_value_valid(dnv):
+            return isinstance(dnv, float) and 0 < dnv < 1
+
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps,
+            strength,
+            device,
+            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+        )
+        # check that number of inference steps is not < 1 - as this doesn't make sense
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
+        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+        is_strength_max = strength == 1.0
+
+        # 5. Preprocess mask and image
+        if padding_mask_crop is not None:
+            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
+        original_image = image
+        init_image = self.image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        )
+        init_image = init_image.to(dtype=torch.float32)
+
+        mask = self.mask_processor.preprocess(
+            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )
+
+        if masked_image_latents is not None:
+            masked_image = masked_image_latents
+        elif init_image.shape[1] == 4:
+            # if images are in latent space, we can't mask it
+            masked_image = None
+        else:
+            masked_image = init_image * (mask < 0.5)
+
+        # 6. Prepare latent variables
+        num_channels_latents = self.vae.config.latent_channels
+        num_channels_unet = self.unet.config.in_channels
+        return_image_latents = num_channels_unet == 4
+
+        add_noise = True if self.denoising_start is None else False
+        latents_outputs = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+            image=init_image,
+            timestep=latent_timestep,
+            is_strength_max=is_strength_max,
+            add_noise=add_noise,
+            return_noise=True,
+            return_image_latents=return_image_latents,
+        )
+
+        if return_image_latents:
+            latents, noise, image_latents = latents_outputs
+        else:
+            latents, noise = latents_outputs
+
+        # 7. Prepare mask latent variables
+        mask, masked_image_latents = self.prepare_mask_latents(
+            mask,
+            masked_image,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            self.do_classifier_free_guidance,
+        )
+
+        # 8. Check that sizes of mask, masked image and latents match
+        if num_channels_unet == 9:
+            # default case for runwayml/stable-diffusion-inpainting
+            num_channels_mask = mask.shape[1]
+            num_channels_masked_image = masked_image_latents.shape[1]
+            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+                raise ValueError(
+                    f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+                    f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+                    f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                    " `pipeline.unet` or your `mask_image` or `image` input."
+                )
+        elif num_channels_unet != 4:
+            raise ValueError(
+                f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+            )
+        # 8.1 Prepare extra step kwargs.
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        height, width = latents.shape[-2:]
+        height = height * self.vae_scale_factor
+        width = width * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        # 10. Prepare added time ids & embeddings
+        if negative_original_size is None:
+            negative_original_size = original_size
+        if negative_target_size is None:
+            negative_target_size = target_size
+
+        add_text_embeds = pooled_prompt_embeds
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            aesthetic_score,
+            negative_aesthetic_score,
+            negative_original_size,
+            negative_crops_coords_top_left,
+            negative_target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+
+        ###########
+        if self.do_self_attention_redirection_guidance:
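+            # SARG denoises two branches in a single batch (without / with attention
+            # redirection), so all conditioning inputs are duplicated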
+            prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(2, 1)
+            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+        ############
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device)
+
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+        # apply AAS to modify the attention module
+        if self.do_self_attention_redirection_guidance:
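+            # AAS is applied from `AAS_start_step` up to the last denoising step
+            # implied by `strength`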
+            self._AAS_end_step = int(strength * self._num_timesteps)
+            layer_idx = list(range(self._AAS_start_layer, self._AAS_end_layer))
+            editor = AAS_XL(
+                self._AAS_start_step,
+                self._AAS_end_step,
+                self._AAS_start_layer,
+                self._AAS_end_layer,
+                layer_idx=layer_idx,
+                mask=mask_image,
+                model_type="SDXL",
+                ss_steps=self._ss_steps,
+                ss_scale=self._ss_scale,
+            )
+            self.regiter_attention_editor_diffusers(self.unet, editor)
+
+
+        # 11. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        if (
+            self.denoising_end is not None
+            and self.denoising_start is not None
+            and denoising_value_valid(self.denoising_end)
+            and denoising_value_valid(self.denoising_start)
+            and self.denoising_start >= self.denoising_end
+        ):
+            raise ValueError(
+                f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+                + f" {self.denoising_end} when using type float."
+            )
+        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
+        # 11.1 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
+        self._num_timesteps = len(timesteps)
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                # removal guidance: SARG denoises its own two-branch batch. CFG is
+                # disabled when SARG is used; experiments showed little difference
+                # in effect whether CFG is used or not.
+                if self.do_self_attention_redirection_guidance:
+                    latent_model_input = torch.cat([latents] * 2)
+
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # concat latents, mask, masked_image_latents in the channel dimension
+                if num_channels_unet == 9:
+                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+                # predict the noise residual
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # perform SARG
+                if self.do_self_attention_redirection_guidance:
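+                    # extrapolate from the un-redirected prediction (wo) toward the
+                    # attention-redirected one (w): e = e_wo + s_rm * (e_w - e_wo)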
+                    noise_pred_wo, noise_pred_w = noise_pred.chunk(2)
+                    delta = noise_pred_w - noise_pred_wo
+                    noise_pred = noise_pred_wo + self._rm_guidance_scale * delta
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if num_channels_unet == 4:
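+                    # re-noise the original image latents to the current timestep and
+                    # paste them back outside the mask, so only the masked region is
+                    # regenerated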
+                    init_latents_proper = image_latents
+                    if self.do_classifier_free_guidance:
+                        init_mask, _ = mask.chunk(2)
+                    else:
+                        init_mask = mask
+
+                    if i < len(timesteps) - 1:
+                        noise_timestep = timesteps[i + 1]
+                        init_latents_proper = self.scheduler.add_noise(
+                            init_latents_proper, noise, torch.tensor([noise_timestep])
+                        )
+
+                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
+                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+                    mask = callback_outputs.pop("mask", mask)
+                    masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        if not output_type == "latent":
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
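+            # decode only the last latent sample in the batch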
+            latents = latents[-1:]
+
+            if needs_upcasting:
+                self.upcast_vae()
+                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            image = self.vae.decode(latents, return_dict=False)[0]
+
+            # cast back to fp16 if needed
+            if needs_upcasting:
+                self.vae.to(dtype=torch.float16)
+        else:
+            return StableDiffusionXLPipelineOutput(images=latents)
+
+        # apply watermark if available
+        if self.watermark is not None:
+            image = self.watermark.apply_watermark(image)
+
+        image = self.image_processor.postprocess(image, output_type=output_type)
+
+        if padding_mask_crop is not None:
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionXLPipelineOutput(images=image)
From a2f805d6c9a6881c025e5a9657c68c4717f7aac7 Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Tue, 14 Jan 2025 23:22:17 +0800
Subject: [PATCH 2/7] add
 pipeline_stable_diffusion_xl_attentive_eraser_make_style
---
 ...ne_stable_diffusion_xl_attentive_eraser.py | 1014 ++++++++++++-----
 1 file changed, 718 insertions(+), 296 deletions(-)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
index cebd80c8bfa1..764a8c1defcb 100644
--- a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -14,52 +14,41 @@
 
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-from tqdm import tqdm
+
 import numpy as np
 import PIL.Image
-from PIL import Image
 import torch
-from transformers import (
-    CLIPImageProcessor,
-    CLIPTextModel,
-    CLIPTextModelWithProjection,
-    CLIPTokenizer,
-    CLIPVisionModelWithProjection,
-)
+from PIL import Image
+from transformers import (CLIPImageProcessor, CLIPTextModel,
+                          CLIPTextModelWithProjection, CLIPTokenizer,
+                          CLIPVisionModelWithProjection)
 
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
-from diffusers.loaders import (
-    FromSingleFileMixin,
-    IPAdapterMixin,
-    StableDiffusionXLLoraLoaderMixin,
-    TextualInversionLoaderMixin,
-)
-from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
+from diffusers.loaders import (FromSingleFileMixin, IPAdapterMixin,
+                               StableDiffusionXLLoraLoaderMixin,
+                               TextualInversionLoaderMixin)
+from diffusers.models import (AutoencoderKL, ImageProjection,
+                              UNet2DConditionModel)
+from diffusers.models.attention_processor import (AttnProcessor2_0,
+                                                  LoRAAttnProcessor2_0,
+                                                  LoRAXFormersAttnProcessor,
+                                                  XFormersAttnProcessor)
 from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import (DiffusionPipeline,
+                                                StableDiffusionMixin)
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import \
+    StableDiffusionXLPipelineOutput
 from diffusers.schedulers import KarrasDiffusionSchedulers
-from diffusers.utils import (
-    USE_PEFT_BACKEND,
-    deprecate,
-    is_invisible_watermark_available,
-    is_torch_xla_available,
-    logging,
-    replace_example_docstring,
-    scale_lora_layers,
-    unscale_lora_layers,
-)
+from diffusers.utils import (USE_PEFT_BACKEND, deprecate,
+                             is_invisible_watermark_available,
+                             is_torch_xla_available, logging,
+                             replace_example_docstring, scale_lora_layers,
+                             unscale_lora_layers)
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
-from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
-
 
 if is_invisible_watermark_available():
-    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+    from diffusers.pipelines.stable_diffusion_xl.watermark import \
+        StableDiffusionXLWatermarker
 
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
@@ -68,10 +57,10 @@
 else:
     XLA_AVAILABLE = False
 
-from einops import rearrange, repeat
+
 import torch.nn as nn
 import torch.nn.functional as F
-import os
+from einops import rearrange, repeat
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -104,6 +93,7 @@
         ```
 """
 
+
 class AttentionBase:
     def __init__(self):
         self.cur_step = 0
@@ -113,8 +103,12 @@ def __init__(self):
     def after_step(self):
         pass
 
-    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
-        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+    def __call__(
+        self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs
+    ):
+        out = self.forward(
+            q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs
+        )
         self.cur_att_layer += 1
         if self.cur_att_layer == self.num_att_layers:
             self.cur_att_layer = 0
@@ -123,21 +117,34 @@ def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwa
             self.after_step()
         return out
 
-    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
-        out = torch.einsum('b i j, b j d -> b i d', attn, v)
-        out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
+    def forward(self, q, k, v, sim, attn, is_cross,
+                place_in_unet, num_heads, **kwargs):
+        out = torch.einsum("b i j, b j d -> b i d", attn, v)
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=num_heads)
         return out
 
     def reset(self):
         self.cur_step = 0
         self.cur_att_layer = 0
 
+
 class AAS_XL(AttentionBase):
-    MODEL_TYPE = {
-        "SD": 16,
-        "SDXL": 70
-    }
-    def __init__(self,  start_step=4, end_step= 50, start_layer=10, end_layer=16,layer_idx=None, step_idx=None, total_steps=50,  mask=None, model_type="SD",ss_steps=9,ss_scale=1.0):
+    MODEL_TYPE = {"SD": 16, "SDXL": 70}
+
+    def __init__(
+        self,
+        start_step=4,
+        end_step=50,
+        start_layer=10,
+        end_layer=16,
+        layer_idx=None,
+        step_idx=None,
+        total_steps=50,
+        mask=None,
+        model_type="SD",
+        ss_steps=9,
+        ss_scale=1.0,
+    ):
         """
         Args:
             start_step: the step to start AAS
@@ -155,49 +162,92 @@ def __init__(self,  start_step=4, end_step= 50, start_layer=10, end_layer=16,lay
         self.end_step = end_step
         self.start_layer = start_layer
         self.end_layer = end_layer
-        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, end_layer))
-        self.step_idx = step_idx if step_idx is not None else list(range(start_step, end_step))
+        self.layer_idx = (
+            layer_idx if layer_idx is not None else list(
+                range(start_layer, end_layer))
+        )
+        self.step_idx = (
+            step_idx if step_idx is not None else list(
+                range(start_step, end_step))
+        )
         self.mask = mask  # mask with shape (1, 1 ,h, w)
         self.ss_steps = ss_steps
         self.ss_scale = ss_scale
         print("AAS at denoising steps: ", self.step_idx)
         print("AAS at U-Net layers: ", self.layer_idx)
         print("start AAS")
-        self.mask_16 = F.max_pool2d(mask,(1024//16,1024//16)).round().squeeze().squeeze()
-        self.mask_32 = F.max_pool2d(mask,(1024//32,1024//32)).round().squeeze().squeeze()
-        self.mask_64 = F.max_pool2d(mask,(1024//64,1024//64)).round().squeeze().squeeze()
-        self.mask_128 = F.max_pool2d(mask,(1024//128,1024//128)).round().squeeze().squeeze()
-    
-    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads,is_mask_attn, mask, **kwargs):
+        self.mask_16 = (
+            F.max_pool2d(mask, (1024 // 16, 1024 // 16)
+                         ).round().squeeze().squeeze()
+        )
+        self.mask_32 = (
+            F.max_pool2d(mask, (1024 // 32, 1024 // 32)
+                         ).round().squeeze().squeeze()
+        )
+        self.mask_64 = (
+            F.max_pool2d(mask, (1024 // 64, 1024 // 64)
+                         ).round().squeeze().squeeze()
+        )
+        self.mask_128 = (
+            F.max_pool2d(mask, (1024 // 128, 1024 // 128)
+                         ).round().squeeze().squeeze()
+        )
+
+    def attn_batch(
+        self,
+        q,
+        k,
+        v,
+        sim,
+        attn,
+        is_cross,
+        place_in_unet,
+        num_heads,
+        is_mask_attn,
+        mask,
+        **kwargs,
+    ):
         B = q.shape[0] // num_heads
         if is_mask_attn:
             mask_flatten = mask.flatten(0)
-            if self.cur_step <= self.ss_steps:                                                                                                                                                                                                                                           
+            if self.cur_step <= self.ss_steps:
                 # background
-                sim_bg = sim + mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min) 
+                sim_bg = sim + mask_flatten.masked_fill(
+                    mask_flatten == 1, torch.finfo(sim.dtype).min
+                )
 
-                #object
-                sim_fg = self.ss_scale*sim
-                sim_fg += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+                # object
+                sim_fg = self.ss_scale * sim
+                sim_fg += mask_flatten.masked_fill(
+                    mask_flatten == 1, torch.finfo(sim.dtype).min
+                )
                 sim = torch.cat([sim_fg, sim_bg], dim=0)
             else:
-                sim += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+                sim += mask_flatten.masked_fill(
+                    mask_flatten == 1, torch.finfo(sim.dtype).min
+                )
 
         attn = sim.softmax(-1)
         if len(attn) == 2 * len(v):
             v = torch.cat([v] * 2)
         out = torch.einsum("h i j, h j d -> h i d", attn, v)
-        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
+        out = rearrange(
+            out,
+            "(h1 h) (b n) d -> (h1 b) n (h d)",
+            b=B,
+            h=num_heads)
         return out
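+    # During the first `ss_steps` steps, attn_batch computes two softmaxes per
+    # head batch: a foreground branch whose logits are scaled by `ss_scale`
+    # and a background branch; both add the dtype minimum at object-key
+    # positions so queries cannot attend to the region being erased. This
+    # doubling is why `v` is duplicated when `len(attn) == 2 * len(v)`.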
 
-    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+    def forward(self, q, k, v, sim, attn, is_cross,
+                place_in_unet, num_heads, **kwargs):
         """
         Attention forward function
         """
         if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
             return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
-        B = q.shape[0] // num_heads // 2
-        H = W = int(np.sqrt(q.shape[1]))
+        # B = q.shape[0] // num_heads // 2
+        # H = W = int(np.sqrt(q.shape[1]))
+        H = int(np.sqrt(q.shape[1]))
         if H == 16:
             mask = self.mask_16.to(sim.device)
         elif H == 32:
@@ -207,15 +257,38 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar
         else:
             mask = self.mask_128.to(sim.device)
 
-
         q_wo, q_w = q.chunk(2)
         k_wo, k_w = k.chunk(2)
         v_wo, v_w = v.chunk(2)
         sim_wo, sim_w = sim.chunk(2)
         attn_wo, attn_w = attn.chunk(2)
 
-        out_source = self.attn_batch(q_wo, k_wo, v_wo, sim_wo, attn_wo, is_cross, place_in_unet, num_heads,is_mask_attn=False,mask=None,**kwargs)
-        out_target = self.attn_batch(q_w, k_w, v_w, sim_w, attn_w, is_cross, place_in_unet, num_heads, is_mask_attn=True, mask = mask,**kwargs)
+        out_source = self.attn_batch(
+            q_wo,
+            k_wo,
+            v_wo,
+            sim_wo,
+            attn_wo,
+            is_cross,
+            place_in_unet,
+            num_heads,
+            is_mask_attn=False,
+            mask=None,
+            **kwargs,
+        )
+        out_target = self.attn_batch(
+            q_w,
+            k_w,
+            v_w,
+            sim_w,
+            attn_w,
+            is_cross,
+            place_in_unet,
+            num_heads,
+            is_mask_attn=True,
+            mask=mask,
+            **kwargs,
+        )
 
         if self.mask is not None:
             if out_target.shape[0] == 2:
@@ -224,11 +297,13 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar
                 out_target = out_target_fg * mask + out_target_bg * (1 - mask)
             else:
                 out_target = out_target
-        
+
         out = torch.cat([out_source, out_target], dim=0)
         return out
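+
+# Minimal, self-contained sketch (illustration only, not used by the
+# pipeline): how the dtype-min attention bias applied in `attn_batch`
+# suppresses attention toward masked (object) keys. All shapes here are
+# hypothetical.
+def _aas_bias_demo():
+    import torch
+
+    sim = torch.randn(8, 64, 64)  # (heads, queries, keys) attention logits
+    mask_flatten = torch.zeros(64)
+    mask_flatten[10:20] = 1  # 1 marks the object positions to be erased
+    # Adding the dtype minimum where mask == 1 drives those logits to
+    # effectively -inf, so the softmax assigns them ~zero weight.
+    bias = mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+    attn = (sim + bias).softmax(-1)
+    assert attn[..., 10:20].max() < 1e-6
+    return attn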
-    
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+
+
+# Copied from
+# diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     """
     Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -259,7 +334,9 @@ def mask_pil_to_torch(mask, height, width):
     return mask
 
 
-def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+def prepare_mask_and_masked_image(
+    image, mask, height, width, return_image: bool = False
+):
     """
     Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
     converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
@@ -314,7 +391,8 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
 
         # Batch single mask or add channel dim
         if mask.ndim == 3:
-            # Single batched mask, no channel dim or single mask not batched but channel dim
+            # Single batched mask, no channel dim or single mask not batched
+            # but channel dim
             if mask.shape[0] == 1:
                 mask = mask.unsqueeze(0)
 
@@ -322,9 +400,13 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
             else:
                 mask = mask.unsqueeze(1)
 
-        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+        assert (
+            image.ndim == 4 and mask.ndim == 4
+        ), "Image and Mask must have 4 dimensions"
         # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
-        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+        assert (
+            image.shape[0] == mask.shape[0]
+        ), "Image and Mask must have the same batch size"
 
         # Check image is in [-1, 1]
         # if image.min() < -1 or image.max() > 1:
@@ -341,14 +423,18 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
         # Image as float32
         image = image.to(dtype=torch.float32)
     elif isinstance(mask, torch.Tensor):
-        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+        raise TypeError(
+            f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
+        )
     else:
         # preprocess image
         if isinstance(image, (PIL.Image.Image, np.ndarray)):
             image = [image]
         if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
             # resize all images w.r.t. the passed height and width
-            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+            image = [
+                i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image
+            ]
             image = [np.array(i.convert("RGB"))[None, :] for i in image]
             image = np.concatenate(image, axis=0)
         elif isinstance(image, list) and isinstance(image[0], np.ndarray):
@@ -377,9 +463,12 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     return mask, masked_image
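+
+# Usage sketch (illustrative): for a 1024x1024 PIL image and PIL mask,
+#   mask, masked_image = prepare_mask_and_masked_image(init_image, mask_image, 1024, 1024)
+# returns batched ``torch.Tensor`` inputs shaped as described in the
+# docstring above.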
 
 
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+# Copied from
+# diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
-    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+    encoder_output: torch.Tensor,
+    generator: Optional[torch.Generator] = None,
+    sample_mode: str = "sample",
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
         return encoder_output.latent_dist.sample(generator)
@@ -388,10 +477,12 @@ def retrieve_latents(
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
-        raise AttributeError("Could not access latents of provided encoder_output")
+        raise AttributeError(
+            "Could not access latents of provided encoder_output")
 
 
-# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+# Copied from
+# diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
     num_inference_steps: Optional[int] = None,
@@ -421,7 +512,9 @@ def retrieve_timesteps(
         second element is the number of inference steps.
     """
     if timesteps is not None:
-        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        accepts_timesteps = "timesteps" in set(
+            inspect.signature(scheduler.set_timesteps).parameters.keys()
+        )
         if not accepts_timesteps:
             raise ValueError(
                 f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -542,55 +635,86 @@ def __init__(
             feature_extractor=feature_extractor,
             scheduler=scheduler,
         )
-        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
-        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.register_to_config(
+            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt
+        )
+        self.register_to_config(
+            requires_aesthetics_score=requires_aesthetics_score)
+        self.vae_scale_factor = 2 ** (
+            len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor)
         self.mask_processor = VaeImageProcessor(
-            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+            vae_scale_factor=self.vae_scale_factor,
+            do_normalize=False,
+            do_binarize=True,
+            do_convert_grayscale=True,
         )
 
-        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+        add_watermarker = (
+            add_watermarker
+            if add_watermarker is not None
+            else is_invisible_watermark_available()
+        )
 
         if add_watermarker:
             self.watermark = StableDiffusionXLWatermarker()
         else:
             self.watermark = None
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+    # Copied from
+    # diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(
+        self, image, device, num_images_per_prompt, output_hidden_states=None
+    ):
         dtype = next(self.image_encoder.parameters()).dtype
 
         if not isinstance(image, torch.Tensor):
-            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+            image = self.feature_extractor(
+                image, return_tensors="pt").pixel_values
 
         image = image.to(device=device, dtype=dtype)
         if output_hidden_states:
-            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
-            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            image_enc_hidden_states = self.image_encoder(
+                image, output_hidden_states=True
+            ).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
             uncond_image_enc_hidden_states = self.image_encoder(
                 torch.zeros_like(image), output_hidden_states=True
             ).hidden_states[-2]
-            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
-                num_images_per_prompt, dim=0
+            uncond_image_enc_hidden_states = (
+                uncond_image_enc_hidden_states.repeat_interleave(
+                    num_images_per_prompt, dim=0
+                )
             )
             return image_enc_hidden_states, uncond_image_enc_hidden_states
         else:
             image_embeds = self.image_encoder(image).image_embeds
-            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            image_embeds = image_embeds.repeat_interleave(
+                num_images_per_prompt, dim=0)
             uncond_image_embeds = torch.zeros_like(image_embeds)
 
             return image_embeds, uncond_image_embeds
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    # Copied from
+    # diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
     def prepare_ip_adapter_image_embeds(
-        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+        self,
+        ip_adapter_image,
+        ip_adapter_image_embeds,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
     ):
         if ip_adapter_image_embeds is None:
             if not isinstance(ip_adapter_image, list):
                 ip_adapter_image = [ip_adapter_image]
 
-            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+            if len(ip_adapter_image) != len(
+                self.unet.encoder_hid_proj.image_projection_layers
+            ):
                 raise ValueError(
                     f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                 )
@@ -634,7 +758,8 @@ def prepare_ip_adapter_image_embeds(
 
         return image_embeds
 
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+    # Copied from
+    # diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
     def encode_prompt(
         self,
         prompt: str,
@@ -697,19 +822,23 @@ def encode_prompt(
 
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+        if lora_scale is not None and isinstance(
+            self, StableDiffusionXLLoraLoaderMixin
+        ):
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
             if self.text_encoder is not None:
                 if not USE_PEFT_BACKEND:
-                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+                    adjust_lora_scale_text_encoder(
+                        self.text_encoder, lora_scale)
                 else:
                     scale_lora_layers(self.text_encoder, lora_scale)
 
             if self.text_encoder_2 is not None:
                 if not USE_PEFT_BACKEND:
-                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+                    adjust_lora_scale_text_encoder(
+                        self.text_encoder_2, lora_scale)
                 else:
                     scale_lora_layers(self.text_encoder_2, lora_scale)
 
@@ -721,9 +850,15 @@ def encode_prompt(
             batch_size = prompt_embeds.shape[0]
 
         # Define tokenizers and text encoders
-        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        tokenizers = (
+            [self.tokenizer, self.tokenizer_2]
+            if self.tokenizer is not None
+            else [self.tokenizer_2]
+        )
         text_encoders = (
-            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+            [self.text_encoder, self.text_encoder_2]
+            if self.text_encoder is not None
+            else [self.text_encoder_2]
         )
 
         if prompt_embeds is None:
@@ -733,7 +868,9 @@ def encode_prompt(
             # textual inversion: process multi-vector tokens if necessary
             prompt_embeds_list = []
             prompts = [prompt, prompt_2]
-            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+            for prompt, tokenizer, text_encoder in zip(
+                prompts, tokenizers, text_encoders
+            ):
                 if isinstance(self, TextualInversionLoaderMixin):
                     prompt = self.maybe_convert_prompt(prompt, tokenizer)
 
@@ -746,26 +883,34 @@ def encode_prompt(
                 )
 
                 text_input_ids = text_inputs.input_ids
-                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-                    text_input_ids, untruncated_ids
-                ):
-                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                untruncated_ids = tokenizer(
+                    prompt, padding="longest", return_tensors="pt"
+                ).input_ids
+
+                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and (
+                    not torch.equal(text_input_ids, untruncated_ids)
+                ):
+                    removed_text = tokenizer.batch_decode(
+                        untruncated_ids[:, tokenizer.model_max_length - 1: -1]
+                    )
                     logger.warning(
                         "The following part of your input was truncated because CLIP can only handle sequences up to"
                         f" {tokenizer.model_max_length} tokens: {removed_text}"
                     )
 
-                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+                prompt_embeds = text_encoder(
+                    text_input_ids.to(device), output_hidden_states=True
+                )
 
-                # We are only ALWAYS interested in the pooled output of the final text encoder
+                # We are only ALWAYS interested in the pooled output of the
+                # final text encoder
                 pooled_prompt_embeds = prompt_embeds[0]
                 if clip_skip is None:
                     prompt_embeds = prompt_embeds.hidden_states[-2]
                 else:
                     # "2" because SDXL always indexes from the penultimate layer.
-                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+                    prompt_embeds = prompt_embeds.hidden_states[-(
+                        clip_skip + 2)]
 
                 prompt_embeds_list.append(prompt_embeds)
 
@@ -819,70 +964,99 @@ def encode_prompt(
                     uncond_input.input_ids.to(device),
                     output_hidden_states=True,
                 )
-                # We are only ALWAYS interested in the pooled output of the final text encoder
+                # We are only ALWAYS interested in the pooled output of the
+                # final text encoder
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
-            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+            negative_prompt_embeds = torch.concat(
+                negative_prompt_embeds_list, dim=-1)
 
         if self.text_encoder_2 is not None:
-            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+            prompt_embeds = prompt_embeds.to(
+                dtype=self.text_encoder_2.dtype, device=device
+            )
         else:
-            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+            prompt_embeds = prompt_embeds.to(
+                dtype=self.unet.dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        # duplicate text embeddings for each generation per prompt, using mps
+        # friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        prompt_embeds = prompt_embeds.view(
+            bs_embed * num_images_per_prompt, seq_len, -1
+        )
 
         if do_classifier_free_guidance:
-            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            # duplicate unconditional embeddings for each generation per
+            # prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
             if self.text_encoder_2 is not None:
-                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+                negative_prompt_embeds = negative_prompt_embeds.to(
+                    dtype=self.text_encoder_2.dtype, device=device
+                )
             else:
-                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+                negative_prompt_embeds = negative_prompt_embeds.to(
+                    dtype=self.unet.dtype, device=device
+                )
 
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(
+                1, num_images_per_prompt, 1
+            )
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, -1
+            )
 
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-            bs_embed * num_images_per_prompt, -1
-        )
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(
+            1, num_images_per_prompt
+        ).view(bs_embed * num_images_per_prompt, -1)
         if do_classifier_free_guidance:
-            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-                bs_embed * num_images_per_prompt, -1
-            )
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
+                1, num_images_per_prompt
+            ).view(bs_embed * num_images_per_prompt, -1)
 
         if self.text_encoder is not None:
-            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            if isinstance(
+                    self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder, lora_scale)
 
         if self.text_encoder_2 is not None:
-            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+            if isinstance(
+                    self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
                 # Retrieve the original scale by scaling back the LoRA layers
                 unscale_lora_layers(self.text_encoder_2, lora_scale)
 
-        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+        return (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        )
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    # Copied from
+    # diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]
 
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_eta = "eta" in set(
+            inspect.signature(self.scheduler.step).parameters.keys()
+        )
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
         # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_generator = "generator" in set(
+            inspect.signature(self.scheduler.step).parameters.keys()
+        )
         if accepts_generator:
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
@@ -908,19 +1082,26 @@ def check_inputs(
         padding_mask_crop=None,
     ):
         if strength < 0 or strength > 1:
-            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+            raise ValueError(
+                f"The value of strength should in [0.0, 1.0] but is {strength}"
+            )
 
         if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+            raise ValueError(
+                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
+            )
 
-        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+        if callback_steps is not None and (
+            not isinstance(callback_steps, int) or callback_steps <= 0
+        ):
             raise ValueError(
                 f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                 f" {type(callback_steps)}."
             )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
-            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+            k in self._callback_tensor_inputs
+            for k in callback_on_step_end_tensor_inputs
         ):
             raise ValueError(
                 f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@@ -940,10 +1121,18 @@ def check_inputs(
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
             )
-        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
-            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
-        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
-            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        elif prompt is not None and (
+            not isinstance(prompt, str) and not isinstance(prompt, list)
+        ):
+            raise ValueError(
+                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+            )
+        elif prompt_2 is not None and (
+            not isinstance(prompt_2, str) and not isinstance(prompt_2, list)
+        ):
+            raise ValueError(
+                f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}"
+            )
 
         if negative_prompt is not None and negative_prompt_embeds is not None:
             raise ValueError(
@@ -966,7 +1155,8 @@ def check_inputs(
         if padding_mask_crop is not None:
             if not isinstance(image, PIL.Image.Image):
                 raise ValueError(
-                    f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
+                    f"The image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(image)}."
                 )
             if not isinstance(mask_image, PIL.Image.Image):
                 raise ValueError(
@@ -974,7 +1164,10 @@ def check_inputs(
                     f" {type(mask_image)}."
                 )
             if output_type != "pil":
-                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
+                raise ValueError(
+                    f"The output type should be PIL when inpainting mask crop, but is"
+                    f" {output_type}."
+                )
 
         if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
             raise ValueError(
@@ -1008,7 +1201,12 @@ def prepare_latents(
         return_noise=False,
         return_image_latents=False,
     ):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        shape = (
+            batch_size,
+            num_channels_latents,
+            height // self.vae_scale_factor,
+            width // self.vae_scale_factor,
+        )
         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@@ -1023,23 +1221,46 @@ def prepare_latents(
 
         if image.shape[1] == 4:
             image_latents = image.to(device=device, dtype=dtype)
-            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+            image_latents = image_latents.repeat(
+                batch_size // image_latents.shape[0], 1, 1, 1
+            )
         elif return_image_latents or (latents is None and not is_strength_max):
             image = image.to(device=device, dtype=dtype)
-            image_latents = self._encode_vae_image(image=image, generator=generator)
-            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+            image_latents = self._encode_vae_image(
+                image=image, generator=generator)
+            image_latents = image_latents.repeat(
+                batch_size // image_latents.shape[0], 1, 1, 1
+            )
 
         if latents is None and add_noise:
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            # if strength is 1. then initialise the latents to noise, else initial to image + noise
-            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
-            # if pure noise then scale the initial latents by the  Scheduler's init sigma
-            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+            noise = randn_tensor(
+                shape,
+                generator=generator,
+                device=device,
+                dtype=dtype)
+            # if strength is 1.0 then initialise the latents to noise, else
+            # initialise them to image + noise
+            latents = (
+                noise
+                if is_strength_max
+                else self.scheduler.add_noise(image_latents, noise, timestep)
+            )
+            # if pure noise then scale the initial latents by the scheduler's
+            # init sigma
+            latents = (
+                latents * self.scheduler.init_noise_sigma
+                if is_strength_max
+                else latents
+            )
         elif add_noise:
             noise = latents.to(device)
             latents = noise * self.scheduler.init_noise_sigma
         else:
-            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            noise = randn_tensor(
+                shape,
+                generator=generator,
+                device=device,
+                dtype=dtype)
             latents = image_latents.to(device)
 
         outputs = (latents,)
@@ -1052,7 +1273,8 @@ def prepare_latents(
 
         return outputs
 
-    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+    def _encode_vae_image(self, image: torch.Tensor,
+                          generator: torch.Generator):
         dtype = image.dtype
         if self.vae.config.force_upcast:
             image = image.float()
@@ -1060,12 +1282,16 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
 
         if isinstance(generator, list):
             image_latents = [
-                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                retrieve_latents(
+                    self.vae.encode(image[i: i + 1]), generator=generator[i]
+                )
                 for i in range(image.shape[0])
             ]
             image_latents = torch.cat(image_latents, dim=0)
         else:
-            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+            image_latents = retrieve_latents(
+                self.vae.encode(image), generator=generator
+            )
 
         if self.vae.config.force_upcast:
             self.vae.to(dtype)
@@ -1076,18 +1302,28 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         return image_latents
 
     def prepare_mask_latents(
-        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+        self,
+        mask,
+        masked_image,
+        batch_size,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        do_classifier_free_guidance,
     ):
         # resize the mask to latents shape as we concatenate the mask to the latents
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
         # and half precision
-        #mask = torch.nn.functional.interpolate(
+        # mask = torch.nn.functional.interpolate(
         #    mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
-        #)
-        mask = torch.nn.functional.max_pool2d(mask, (8,8)).round()
+        # )
+        mask = torch.nn.functional.max_pool2d(mask, (8, 8)).round()
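+        # Unlike the interpolation used by the stock inpainting pipeline
+        # (commented out above), an 8x8 max-pool (the VAE scale factor for
+        # SDXL) marks a latent cell as masked whenever any pixel it covers is
+        # masked, slightly dilating the mask instead of thinning it.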
         mask = mask.to(device=device, dtype=dtype)
 
-        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+        # duplicate mask and masked_image_latents for each generation per
+        # prompt, using mps friendly method
         if mask.shape[0] < batch_size:
             if not batch_size % mask.shape[0] == 0:
                 raise ValueError(
@@ -1107,7 +1343,9 @@ def prepare_mask_latents(
         if masked_image is not None:
             if masked_image_latents is None:
                 masked_image = masked_image.to(device=device, dtype=dtype)
-                masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+                masked_image_latents = self._encode_vae_image(
+                    masked_image, generator=generator
+                )
 
             if masked_image_latents.shape[0] < batch_size:
                 if not batch_size % masked_image_latents.shape[0] == 0:
@@ -1121,16 +1359,23 @@ def prepare_mask_latents(
                 )
 
             masked_image_latents = (
-                torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+                torch.cat([masked_image_latents] * 2)
+                if do_classifier_free_guidance
+                else masked_image_latents
             )
 
-            # aligning device to prevent device errors when concating it with the latent model input
-            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+            # aligning device to prevent device errors when concatenating it
+            # with the latent model input
+            masked_image_latents = masked_image_latents.to(
+                device=device, dtype=dtype)
 
         return mask, masked_image_latents
 
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
-    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+    # Copied from
+    # diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+    def get_timesteps(
+        self, num_inference_steps, strength, device, denoising_start=None
+    ):
         # get the original timestep using init_timestep
         if denoising_start is None:
             init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
@@ -1138,7 +1383,7 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
         else:
             t_start = 0
 
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
 
         # Strength is irrelevant if we directly request a timestep to start at;
         # that is, strength is determined by the denoising_start instead.
@@ -1157,16 +1402,19 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N
                 # (except the highest one) is duplicated. If `num_inference_steps` is even it would
                 # mean that we cut the timesteps in the middle of the denoising step
                 # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
-                # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler
+                # we ensure that the denoising process always ends after the
+                # 2nd derivative step of the scheduler
                 num_inference_steps = num_inference_steps + 1
 
-            # because t_n+1 >= t_n, we slice the timesteps starting from the end
+            # because t_n+1 >= t_n, we slice the timesteps starting from the
+            # end
             timesteps = timesteps[-num_inference_steps:]
             return timesteps, num_inference_steps
 
         return timesteps, num_inference_steps - t_start
 
-    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+    # Copied from
+    # diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
     def _get_add_time_ids(
         self,
         original_size,
@@ -1218,7 +1466,8 @@ def _get_add_time_ids(
 
         return add_time_ids, add_neg_time_ids
 
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+    # Copied from
+    # diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
     def upcast_vae(self):
         dtype = self.vae.dtype
         self.vae.to(dtype=torch.float32)
@@ -1238,8 +1487,10 @@ def upcast_vae(self):
             self.vae.decoder.conv_in.to(dtype)
             self.vae.decoder.mid_block.to(dtype)
 
-    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
-    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+    # Copied from
+    # diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(
+            self, w, embedding_dim=512, dtype=torch.float32):
         """
         See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
 
@@ -1279,18 +1530,21 @@ def guidance_rescale(self):
     def clip_skip(self):
         return self._clip_skip
 
-
     @property
-    def do_self_attention_redirection_guidance(self): #SARG
+    def do_self_attention_redirection_guidance(self):  # SARG
         return self._rm_guidance_scale > 1 and self._AAS
-    
+
     # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
     # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
     # corresponds to doing no classifier free guidance.
     @property
     def do_classifier_free_guidance(self):
-        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None and self.do_self_attention_redirection_guidance==False #CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not
-    
+        return (
+            self._guidance_scale > 1
+            and self.unet.config.time_cond_proj_dim is None
+            and not self.do_self_attention_redirection_guidance
+        )  # CFG is disabled when SARG is used; experiments showed little difference in the result whether CFG was applied or not
+
     @property
     def cross_attention_kwargs(self):
         return self._cross_attention_kwargs
@@ -1313,25 +1567,28 @@ def interrupt(self):
 
     @torch.no_grad()
     def image2latent(self, image: torch.Tensor, generator: torch.Generator):
-        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        DEVICE = (
+            torch.device("cuda")
+            if torch.cuda.is_available()
+            else torch.device("cpu")
+        )
         if type(image) is Image:
             image = np.array(image)
             image = torch.from_numpy(image).float() / 127.5 - 1
             image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
         # input image density range [-1, 1]
-        #latents = self.vae.encode(image)['latent_dist'].mean
+        # latents = self.vae.encode(image)['latent_dist'].mean
         latents = self._encode_vae_image(image, generator)
-        #latents = retrieve_latents(self.vae.encode(image))
-        #latents = latents * self.vae.config.scaling_factor
+        # latents = retrieve_latents(self.vae.encode(image))
+        # latents = latents * self.vae.config.scaling_factor
         return latents
-    
+
     def next_step(
         self,
         model_output: torch.FloatTensor,
         timestep: int,
         x: torch.FloatTensor,
-        eta=0.,
-        verbose=False
+        eta=0.0,
+        verbose=False,
     ):
         """
         Inverse sampling for DDIM Inversion
@@ -1339,15 +1596,24 @@ def next_step(
         if verbose:
             print("timestep: ", timestep)
         next_step = timestep
-        timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
-        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+        timestep = min(
+            timestep
+            - self.scheduler.config.num_train_timesteps
+            // self.scheduler.num_inference_steps,
+            999,
+        )
+        alpha_prod_t = (
+            self.scheduler.alphas_cumprod[timestep]
+            if timestep >= 0
+            else self.scheduler.final_alpha_cumprod
+        )
         alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
         beta_prod_t = 1 - alpha_prod_t
         pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
-        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
+        pred_dir = (1 - alpha_prod_t_next) ** 0.5 * model_output
         x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
         return x_next, pred_x0
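+    # In formula form (epsilon-parameterization, with a_t = alphas_cumprod[t]):
+    #   x0_pred = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
+    #   x_{t+1} = sqrt(a_{t+1}) * x0_pred + sqrt(1 - a_{t+1}) * eps
+    # Iterating this over increasing t deterministically maps a clean latent
+    # to the noise that would have produced it under DDIM sampling.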
-    
+
     @torch.no_grad()
     def invert(
         self,
@@ -1362,11 +1628,15 @@ def invert(
         aesthetic_score: float = 6.0,
         negative_aesthetic_score: float = 2.5,
         return_intermediates=False,
-        **kwds):
+        **kwds,
+    ):
         """
         Invert a real image into a noise map with deterministic DDIM inversion.
         """
-        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        DEVICE = (
+            torch.device("cuda")
+            if torch.cuda.is_available()
+            else torch.device("cpu")
+        )
         batch_size = image.shape[0]
         if isinstance(prompt, list):
             if batch_size == 1:
@@ -1376,9 +1646,15 @@ def invert(
                 prompt = [prompt] * batch_size
 
         # Define tokenizers and text encoders
-        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        tokenizers = (
+            [self.tokenizer, self.tokenizer_2]
+            if self.tokenizer is not None
+            else [self.tokenizer_2]
+        )
         text_encoders = (
-            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+            [self.text_encoder, self.text_encoder_2]
+            if self.text_encoder is not None
+            else [self.text_encoder_2]
         )
 
         prompt_2 = prompt
@@ -1387,7 +1663,8 @@ def invert(
         # textual inversion: process multi-vector tokens if necessary
         prompt_embeds_list = []
         prompts = [prompt, prompt_2]
-        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+        for prompt, tokenizer, text_encoder in zip(
+                prompts, tokenizers, text_encoders):
             if isinstance(self, TextualInversionLoaderMixin):
                 prompt = self.maybe_convert_prompt(prompt, tokenizer)
 
@@ -1400,20 +1677,27 @@ def invert(
             )
 
             text_input_ids = text_inputs.input_ids
-            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-                text_input_ids, untruncated_ids
-            ):
-                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+            untruncated_ids = tokenizer(
+                prompt, padding="longest", return_tensors="pt"
+            ).input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and (
+                not torch.equal(text_input_ids, untruncated_ids)
+            ):
+                removed_text = tokenizer.batch_decode(
+                    untruncated_ids[:, tokenizer.model_max_length - 1: -1]
+                )
                 logger.warning(
                     "The following part of your input was truncated because CLIP can only handle sequences up to"
                     f" {tokenizer.model_max_length} tokens: {removed_text}"
                 )
 
-            prompt_embeds = text_encoder(text_input_ids.to(DEVICE), output_hidden_states=True)
+            prompt_embeds = text_encoder(
+                text_input_ids.to(DEVICE), output_hidden_states=True
+            )
 
-            # We are only ALWAYS interested in the pooled output of the final text encoder
+            # We are only ALWAYS interested in the pooled output of the final
+            # text encoder
             pooled_prompt_embeds = prompt_embeds[0]
             prompt_embeds = prompt_embeds.hidden_states[-2]
             prompt_embeds_list.append(prompt_embeds)
@@ -1421,7 +1705,7 @@ def invert(
         prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=DEVICE)
 
         # define initial latents
-        latents = self.image2latent(image,generator=None)
+        latents = self.image2latent(image, generator=None)
 
         start_latents = latents
         height, width = latents.shape[-2:]
@@ -1454,32 +1738,41 @@ def invert(
         self.scheduler.set_timesteps(num_inference_steps)
         latents_list = [latents]
         pred_x0_list = []
-        #for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
+        # for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps),
+        # desc="DDIM Inversion")):
         for i, t in enumerate(reversed(self.scheduler.timesteps)):
             model_inputs = latents
 
             # predict the noise
-            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
-            noise_pred = self.unet(model_inputs, t, encoder_hidden_states=prompt_embeds,added_cond_kwargs=added_cond_kwargs).sample
+            added_cond_kwargs = {
+                "text_embeds": add_text_embeds,
+                "time_ids": add_time_ids,
+            }
+            noise_pred = self.unet(
+                model_inputs,
+                t,
+                encoder_hidden_states=prompt_embeds,
+                added_cond_kwargs=added_cond_kwargs,
+            ).sample
 
             # compute the previous noise sample x_t-1 -> x_t
             latents, pred_x0 = self.next_step(noise_pred, t, latents)
-            """             
+            """
             if t >= 1 and t < 41:
                 latents, pred_x0 = self.next_step_degrade(noise_pred, t, latents, mask)
             else:
                 latents, pred_x0 = self.next_step(noise_pred, t, latents) """
-            
+
             latents_list.append(latents)
             pred_x0_list.append(pred_x0)
 
         if return_intermediates:
             # return the intermediate latents during inversion
-            #pred_x0_list = [self.latent2image(img, return_type="np") for img in pred_x0_list]
-            #latents_list = [self.latent2image(img, return_type="np") for img in latents_list]
+            # pred_x0_list = [self.latent2image(img, return_type="np") for img in pred_x0_list]
+            # latents_list = [self.latent2image(img, return_type="np") for img in latents_list]
             return latents, latents_list, pred_x0_list
         return latents, start_latents
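+    # Hypothetical call (only parameters visible in this method are assumed):
+    #   noise_latent, start_latents = pipe.invert(
+    #       image=image_tensor, prompt="", num_inference_steps=50
+    #   )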
-    
+
     def opt(
         self,
         model_output: torch.FloatTensor,
@@ -1489,19 +1782,27 @@ def opt(
         """
         Predict the sample at the next step of the denoising process.
         """
-        ref_noise = model_output[:1,:,:,:].expand(model_output.shape)
+        ref_noise = model_output[:1, :, :, :].expand(model_output.shape)
         alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
         beta_prod_t = 1 - alpha_prod_t
         pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
-        x_opt = alpha_prod_t**0.5 * pred_x0 + (1 - alpha_prod_t)**0.5 * ref_noise
+        x_opt = (
+            alpha_prod_t**0.5 * pred_x0 + (1 - alpha_prod_t) ** 0.5 * ref_noise
+        )
         return x_opt, pred_x0
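+    # `opt` keeps the model's x0 prediction but re-noises it with the noise
+    # predicted for the first (reference) batch element:
+    #   x_opt = sqrt(a_t) * x0_pred + sqrt(1 - a_t) * eps_ref
+    # which steers the removal branch back toward the reference trajectory.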
-    
+
     def regiter_attention_editor_diffusers(self, unet, editor: AttentionBase):
         """
         Register an attention editor to the Diffusers pipeline, adapted from [Prompt-to-Prompt].
         """
+
         def ca_forward(self, place_in_unet):
-            def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
+            def forward(
+                x,
+                encoder_hidden_states=None,
+                attention_mask=None,
+                context=None,
+                mask=None,
+            ):
                 """
                 The attention computation is similar to the original LDM
                 CrossAttention implementation, except for some modifications to the attention
@@ -1523,22 +1824,33 @@ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, ma
                 context = context if is_cross else x
                 k = self.to_k(context)
                 v = self.to_v(context)
-                q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+                q, k, v = map(
+                    lambda t: rearrange(
+                        t, "b n (h d) -> (b h) n d", h=h), (q, k, v)
+                )
 
-                sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
+                sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
 
                 if mask is not None:
-                    mask = rearrange(mask, 'b ... -> b (...)')
+                    mask = rearrange(mask, "b ... -> b (...)")
                     max_neg_value = -torch.finfo(sim.dtype).max
-                    mask = repeat(mask, 'b j -> (b h) () j', h=h)
+                    mask = repeat(mask, "b j -> (b h) () j", h=h)
                     mask = mask[:, None, :].repeat(h, 1, 1)
                     sim.masked_fill_(~mask, max_neg_value)
 
                 attn = sim.softmax(dim=-1)
                 # the only difference
                 out = editor(
-                    q, k, v, sim, attn, is_cross, place_in_unet,
-                    self.heads, scale=self.scale)
+                    q,
+                    k,
+                    v,
+                    sim,
+                    attn,
+                    is_cross,
+                    place_in_unet,
+                    self.heads,
+                    scale=self.scale,
+                )
 
                 return to_out(out)
 
@@ -1546,10 +1858,10 @@ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, ma
 
         def register_editor(net, count, place_in_unet):
             for name, subnet in net.named_children():
-                if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
+                if net.__class__.__name__ == "Attention":  # spatial Transformer layer
                     net.forward = ca_forward(net, place_in_unet)
                     return count + 1
-                elif hasattr(net, 'children'):
+                elif hasattr(net, "children"):
                     count = register_editor(subnet, count, place_in_unet)
             return count
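+        # register_editor recursively walks the U-Net's children and
+        # monkey-patches the forward of every spatial `Attention` module,
+        # returning the number of layers patched.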
 
@@ -1577,12 +1889,12 @@ def __call__(
         padding_mask_crop: Optional[int] = None,
         strength: float = 0.9999,
         AAS: bool = True,  # AE parameter
-        rm_guidance_scale: float = 7.0, # AE parameter
+        rm_guidance_scale: float = 7.0,  # AE parameter
         ss_steps: int = 9,  # AE parameter
-        ss_scale: float = 0.3, # AE parameter
-        AAS_start_step: int = 0, # AE parameter
-        AAS_start_layer: int = 34, # AE parameter
-        AAS_end_layer: int = 70, # AE parameter
+        ss_scale: float = 0.3,  # AE parameter
+        AAS_start_step: int = 0,  # AE parameter
+        AAS_start_layer: int = 34,  # AE parameter
+        AAS_end_layer: int = 70,  # AE parameter
         num_inference_steps: int = 50,
         timesteps: List[int] = None,
         denoising_start: Optional[float] = None,
@@ -1592,7 +1904,8 @@ def __call__(
         negative_prompt_2: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        generator: Optional[Union[torch.Generator,
+                                  List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -1613,7 +1926,8 @@ def __call__(
         aesthetic_score: float = 6.0,
         negative_aesthetic_score: float = 2.5,
         clip_skip: Optional[int] = None,
-        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end: Optional[Callable[[
+            int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         **kwargs,
     ):
@@ -1843,7 +2157,7 @@ def __call__(
         self._denoising_start = denoising_start
         self._interrupt = False
 
-        ########### AE parameters
+        # AE parameters
         self._num_timesteps = num_inference_steps
         self._rm_guidance_scale = rm_guidance_scale
         self._AAS = AAS
@@ -1866,7 +2180,9 @@ def __call__(
 
         # 3. Encode input prompt
         text_encoder_lora_scale = (
-            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+            self.cross_attention_kwargs.get("scale", None)
+            if self.cross_attention_kwargs is not None
+            else None
         )
 
         (
@@ -1894,27 +2210,39 @@ def __call__(
         def denoising_value_valid(dnv):
             return isinstance(dnv, float) and 0 < dnv < 1
 
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, timesteps
+        )
         timesteps, num_inference_steps = self.get_timesteps(
             num_inference_steps,
             strength,
             device,
-            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+            denoising_start=(
+                self.denoising_start
+                if denoising_value_valid(self.denoising_start)
+                else None
+            ),
         )
-        # check that number of inference steps is not < 1 - as this doesn't make sense
+        # check that number of inference steps is not < 1 - as this doesn't
+        # make sense
         if num_inference_steps < 1:
             raise ValueError(
                 f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
                 f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
             )
-        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
-        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+        # at which timestep to set the initial noise (n.b. 50% if strength is
+        # 0.5)
+        latent_timestep = timesteps[:1].repeat(
+            batch_size * num_images_per_prompt)
+        # create a boolean to check if the strength is set to 1. if so then
+        # initialise the latents with pure noise
         is_strength_max = strength == 1.0
 
         # 5. Preprocess mask and image
         if padding_mask_crop is not None:
-            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            crops_coords = self.mask_processor.get_crop_region(
+                mask_image, width, height, pad=padding_mask_crop
+            )
             resize_mode = "fill"
         else:
             crops_coords = None
@@ -1922,12 +2250,20 @@ def denoising_value_valid(dnv):
 
         original_image = image
         init_image = self.image_processor.preprocess(
-            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+            image,
+            height=height,
+            width=width,
+            crops_coords=crops_coords,
+            resize_mode=resize_mode,
         )
         init_image = init_image.to(dtype=torch.float32)
 
         mask = self.mask_processor.preprocess(
-            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+            mask_image,
+            height=height,
+            width=width,
+            resize_mode=resize_mode,
+            crops_coords=crops_coords,
         )
 
         if masked_image_latents is not None:
@@ -1984,7 +2320,10 @@ def denoising_value_valid(dnv):
             # default case for runwayml/stable-diffusion-inpainting
             num_channels_mask = mask.shape[1]
             num_channels_masked_image = masked_image_latents.shape[1]
-            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+            if (
+                num_channels_latents + num_channels_mask + num_channels_masked_image
+                != self.unet.config.in_channels
+            ):
                 raise ValueError(
                     f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                     f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
@@ -1999,7 +2338,8 @@ def denoising_value_valid(dnv):
         # 8.1 Prepare extra step kwargs.
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be
+        # moved out of the pipeline
         height, width = latents.shape[-2:]
         height = height * self.vae_scale_factor
         width = width * self.vae_scale_factor
@@ -2031,19 +2371,25 @@ def denoising_value_valid(dnv):
             dtype=prompt_embeds.dtype,
             text_encoder_projection_dim=text_encoder_projection_dim,
         )
-        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+        add_time_ids = add_time_ids.repeat(
+            batch_size * num_images_per_prompt, 1)
 
         if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
-            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+            prompt_embeds = torch.cat(
+                [negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat(
+                [negative_pooled_prompt_embeds, add_text_embeds], dim=0
+            )
+            add_neg_time_ids = add_neg_time_ids.repeat(
+                batch_size * num_images_per_prompt, 1
+            )
             add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
 
-
         ###########
         if self.do_self_attention_redirection_guidance:
             prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)
-            add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0)
+            add_text_embeds = torch.cat(
+                [add_text_embeds, add_text_embeds], dim=0)
             add_neg_time_ids = add_neg_time_ids.repeat(2, 1)
             add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
         ############
@@ -2064,13 +2410,24 @@ def denoising_value_valid(dnv):
         # apply AAS to modify the attention module
         if self.do_self_attention_redirection_guidance:
             self._AAS_end_step = int(strength * self._num_timesteps)
-            layer_idx=list(range(self._AAS_start_layer, self._AAS_end_layer))
-            editor = AAS_XL(self._AAS_start_step, self._AAS_end_step, self._AAS_start_layer, self._AAS_end_layer, layer_idx= layer_idx, mask=mask_image,model_type="SDXL",ss_steps=self._ss_steps,ss_scale=self._ss_scale)
+            layer_idx = list(range(self._AAS_start_layer, self._AAS_end_layer))
+            editor = AAS_XL(
+                self._AAS_start_step,
+                self._AAS_end_step,
+                self._AAS_start_layer,
+                self._AAS_end_layer,
+                layer_idx=layer_idx,
+                mask=mask_image,
+                model_type="SDXL",
+                ss_steps=self._ss_steps,
+                ss_scale=self._ss_scale,
+            )
             self.regiter_attention_editor_diffusers(self.unet, editor)
 
-
         # 11. Denoising loop
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        num_warmup_steps = max(
+            len(timesteps) - num_inference_steps * self.scheduler.order, 0
+        )
 
         if (
             self.denoising_end is not None
@@ -2083,20 +2440,26 @@ def denoising_value_valid(dnv):
                 f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
                 + f" {self.denoising_end} when using type float."
             )
-        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+        elif self.denoising_end is not None and denoising_value_valid(
+            self.denoising_end
+        ):
             discrete_timestep_cutoff = int(
                 round(
                     self.scheduler.config.num_train_timesteps
                     - (self.denoising_end * self.scheduler.config.num_train_timesteps)
                 )
             )
-            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            num_inference_steps = len(
+                list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))
+            )
             timesteps = timesteps[:num_inference_steps]
 
         # 11.1 Optionally get Guidance Scale Embedding
         timestep_cond = None
         if self.unet.config.time_cond_proj_dim is not None:
-            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
+                batch_size * num_images_per_prompt
+            )
             timestep_cond = self.get_guidance_scale_embedding(
                 guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
             ).to(device=device, dtype=latents.dtype)
@@ -2109,21 +2472,37 @@ def denoising_value_valid(dnv):
                     continue
 
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                latent_model_input = (
+                    torch.cat([latents] * 2)
+                    if self.do_classifier_free_guidance
+                    else latents
+                )
 
-                #removal guidance
-                latent_model_input = torch.cat([latents]*2) if self.do_self_attention_redirection_guidance else latents #CFG was disabled when SARG was used, and experiments proved that there was little difference in the effect of whether CFG was used or not
-                #latent_model_input_rm = torch.cat([latents]*2) if self.do_self_attention_redirection_guidance else latents
-                
-                # concat latents, mask, masked_image_latents in the channel dimension
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-                #latent_model_input = self.scheduler.scale_model_input(latent_model_input_rm, t)
+                # removal guidance
+                latent_model_input = (
+                    torch.cat([latents] * 2)
+                    if self.do_self_attention_redirection_guidance
+                    else latents
+                )  # CFG is disabled when SARG is used; experiments showed little difference in results whether or not CFG is applied
+                # latent_model_input_rm = torch.cat([latents]*2) if self.do_self_attention_redirection_guidance else latents
+
+                # concat latents, mask, masked_image_latents in the channel
+                # dimension
+                latent_model_input = self.scheduler.scale_model_input(
+                    latent_model_input, t
+                )
+                # latent_model_input = self.scheduler.scale_model_input(latent_model_input_rm, t)
 
                 if num_channels_unet == 9:
-                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+                    latent_model_input = torch.cat(
+                        [latent_model_input, mask, masked_image_latents], dim=1
+                    )
 
                 # predict the noise residual
-                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                added_cond_kwargs = {
+                    "text_embeds": add_text_embeds,
+                    "time_ids": add_time_ids,
+                }
                 if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
                     added_cond_kwargs["image_embeds"] = image_embeds
                 noise_pred = self.unet(
@@ -2145,16 +2524,22 @@ def denoising_value_valid(dnv):
                 # perform guidance
                 if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (
+                        noise_pred_text - noise_pred_uncond
+                    )
 
                 if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
-
-
+                    noise_pred = rescale_noise_cfg(
+                        noise_pred,
+                        noise_pred_text,
+                        guidance_rescale=self.guidance_rescale,
+                    )
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
+                )[0]
 
                 if num_channels_unet == 4:
                     init_latents_proper = image_latents
@@ -2166,28 +2551,42 @@ def denoising_value_valid(dnv):
                     if i < len(timesteps) - 1:
                         noise_timestep = timesteps[i + 1]
                         init_latents_proper = self.scheduler.add_noise(
-                            init_latents_proper, noise, torch.tensor([noise_timestep])
+                            init_latents_proper, noise, torch.tensor(
+                                [noise_timestep])
                         )
 
-                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+                    latents = (
+                        1 - init_mask
+                    ) * init_latents_proper + init_mask * latents
 
                 if callback_on_step_end is not None:
                     callback_kwargs = {}
                     for k in callback_on_step_end_tensor_inputs:
                         callback_kwargs[k] = locals()[k]
-                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                    callback_outputs = callback_on_step_end(
+                        self, i, t, callback_kwargs)
 
                     latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+                    prompt_embeds = callback_outputs.pop(
+                        "prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop(
+                        "negative_prompt_embeds", negative_prompt_embeds
+                    )
+                    add_text_embeds = callback_outputs.pop(
+                        "add_text_embeds", add_text_embeds
+                    )
                     negative_pooled_prompt_embeds = callback_outputs.pop(
                         "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
                     )
-                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
-                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+                    add_time_ids = callback_outputs.pop(
+                        "add_time_ids", add_time_ids)
+                    add_neg_time_ids = callback_outputs.pop(
+                        "add_neg_time_ids", add_neg_time_ids
+                    )
                     mask = callback_outputs.pop("mask", mask)
-                    masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+                    masked_image_latents = callback_outputs.pop(
+                        "masked_image_latents", masked_image_latents
+                    )
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -2201,25 +2600,42 @@ def denoising_value_valid(dnv):
 
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
-            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+            needs_upcasting = (
+                self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+            )
             latents = latents[-1:]
 
             if needs_upcasting:
                 self.upcast_vae()
-                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+                latents = latents.to(
+                    next(iter(self.vae.post_quant_conv.parameters())).dtype
+                )
 
             # unscale/denormalize the latents
             # denormalize with the mean and std if available and not None
-            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
-            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            has_latents_mean = (
+                hasattr(self.vae.config, "latents_mean")
+                and self.vae.config.latents_mean is not None
+            )
+            has_latents_std = (
+                hasattr(self.vae.config, "latents_std")
+                and self.vae.config.latents_std is not None
+            )
             if has_latents_mean and has_latents_std:
                 latents_mean = (
-                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                    torch.tensor(self.vae.config.latents_mean)
+                    .view(1, 4, 1, 1)
+                    .to(latents.device, latents.dtype)
                 )
                 latents_std = (
-                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                    torch.tensor(self.vae.config.latents_std)
+                    .view(1, 4, 1, 1)
+                    .to(latents.device, latents.dtype)
+                )
+                latents = (
+                    latents * latents_std / self.vae.config.scaling_factor
+                    + latents_mean
                 )
-                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
             else:
                 latents = latents / self.vae.config.scaling_factor
 
@@ -2235,10 +2651,16 @@ def denoising_value_valid(dnv):
         if self.watermark is not None:
             image = self.watermark.apply_watermark(image)
 
-        image = self.image_processor.postprocess(image, output_type=output_type)
+        image = self.image_processor.postprocess(
+            image, output_type=output_type)
 
         if padding_mask_crop is not None:
-            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+            image = [
+                self.image_processor.apply_overlay(
+                    mask_image, original_image, i, crops_coords
+                )
+                for i in image
+            ]
 
         # Offload all models
         self.maybe_free_model_hooks()
From 717e2c711963e587678736b9f572317b90ae235d Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Wed, 15 Jan 2025 17:04:41 +0800
Subject: [PATCH 3/7] make style and add example output
---
 examples/community/README.md                  |   92 +-
 ...ne_stable_diffusion_xl_attentive_eraser.py | 2285 +++++++++++++++++
 2 files changed, 2375 insertions(+), 2 deletions(-)
 mode change 100755 => 100644 examples/community/README.md
 create mode 100644 examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
diff --git a/examples/community/README.md b/examples/community/README.md
old mode 100755
new mode 100644
index c7c40c46ef2d..02777636d2d3
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -77,6 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
 PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
 | HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
 | [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
+| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3)|
 
 To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
 
@@ -4585,8 +4586,8 @@ image = pipe(
 ```
 
 |  |  |  |
-| ------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- |
-| Gradient                                                                                   | Input                                                                                   | Output                                                                                   |
+| -------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ |
+| Gradient                                                                                     | Input                                                                                     | Output                                                                                     |
 
 A colab notebook demonstrating all results can be found [here](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing). Depth Maps have also been added in the same colab.
 
@@ -4634,6 +4635,93 @@ make_image_grid(image, rows=1, cols=len(image))
 # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
 ```
 
+### Stable Diffusion XL Attentive Eraser Pipeline
+
+**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. This pipeline uses **self-attention redirection guidance** to modify the model’s self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. If you are interested in more detailed information and have any questions, please refer to the [paper](https://arxiv.org/abs/2412.12974) and [official implementation](https://github.com/Anonym0u3/AttentiveEraser).
+
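+At its core, the redirection masks self-attention logits so that pixels inside the erased region can only attend to, and therefore copy appearance from, the surrounding background. The following is a minimal, self-contained sketch of that idea; it is illustrative only and independent of the pipeline's actual `AAS_XL` editor:
+
+```py
+import torch
+
+
+def redirect_self_attention(sim, mask_flat):
+    # sim: attention logits with shape (heads, q_tokens, k_tokens)
+    # mask_flat: (k_tokens,) binary mask, 1 = region to erase
+    # Block the keys inside the erased region so queries gather
+    # appearance only from the background.
+    blocked = mask_flat.masked_fill(mask_flat == 1, torch.finfo(sim.dtype).min)
+    return (sim + blocked).softmax(dim=-1)
+
+
+heads, tokens = 8, 16 * 16
+sim = torch.randn(heads, tokens, tokens)
+mask_flat = torch.zeros(tokens)
+mask_flat[:64] = 1  # pretend the first 64 tokens are the object to remove
+attn = redirect_self_attention(sim, mask_flat)
+assert attn[..., :64].sum().item() == 0  # no attention lands on the erased region
+```
+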
+#### Key features
+
+- **Tuning-Free**: No additional training is required, making it easy to integrate and use.
+- **Flexible Mask Support**: Works with different types of masks for targeted object removal.
+- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.
+
+#### Usage example
+To use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:
+```py
+import torch
+from diffusers import DDIMScheduler, DiffusionPipeline
+from diffusers.utils import load_image
+import torch.nn.functional as F
+from torchvision.transforms.functional import to_tensor, gaussian_blur
+
+dtype = torch.float16
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 
+
+scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+pipeline = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
+    scheduler=scheduler,
+    variant="fp16",
+    use_safetensors=True,
+    torch_dtype=dtype,
+).to(device)
+
+
+def preprocess_image(image_path, device):
+    image = to_tensor(load_image(image_path))
+    image = image.unsqueeze_(0).float() * 2 - 1  # [0,1] --> [-1,1]
+    if image.shape[1] != 3:
+        image = image.expand(-1, 3, -1, -1)
+    image = F.interpolate(image, (1024, 1024))
+    image = image.to(dtype).to(device)
+    return image
+
+def preprocess_mask(mask_path, device):
+    mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
+    mask = mask.unsqueeze_(0).float()  # 0 or 1
+    mask = F.interpolate(mask, (1024, 1024))
+    mask = gaussian_blur(mask, kernel_size=(77, 77))
+    mask[mask < 0.1] = 0
+    mask[mask >= 0.1] = 1
+    mask = mask.to(dtype).to(device)
+    return mask
+
+prompt = "" # Set prompt to null
+seed=123 
+generator = torch.Generator(device=device).manual_seed(seed)
+source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
+mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
+source_image = preprocess_image(source_image_path, device)
+mask = preprocess_mask(mask_path, device)
+
+image = pipeline(
+    prompt=prompt, 
+    image=source_image,
+    mask_image=mask,
+    height=1024,
+    width=1024,
+    AAS=True,  # enable AAS
+    strength=0.8,  # inpainting strength
+    rm_guidance_scale=9,  # removal guidance scale
+    ss_steps=9,  # similarity suppression steps
+    ss_scale=0.3,  # similarity suppression scale
+    AAS_start_step=0,  # AAS start step
+    AAS_start_layer=34,  # AAS start layer
+    AAS_end_layer=70,  # AAS end layer
+    num_inference_steps=50,  # number of inference steps; AAS_end_step = int(strength * num_inference_steps)
+    generator=generator,
+    guidance_scale=1,
+).images[0]
+image.save("./removed_img.png")
+print("Object removal completed")
+```
+
+| Source Image                                                                                   | Mask                                                                                        | Output                                                                                              |
+| ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
+|  |  |  |
+
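+Masks do not need to be precise segmentation maps. If you only have a bounding box, a rough rectangular mask can be built directly as a tensor and passed as `mask_image`. The `bbox_to_mask` helper below is hypothetical (it is not part of the pipeline) and reuses the `device` and `dtype` variables from the example above:
+
+```py
+def bbox_to_mask(x0, y0, x1, y1, size=1024):
+    # Build a (1, 1, size, size) binary mask with 1 inside the box.
+    box_mask = torch.zeros(1, 1, size, size, dtype=dtype, device=device)
+    box_mask[:, :, y0:y1, x0:x1] = 1
+    return box_mask
+
+
+# e.g. erase whatever sits in the 300x300 box whose top-left corner is (400, 350)
+mask = bbox_to_mask(400, 350, 700, 650)
+```
+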
 # Perturbed-Attention Guidance
 
 [Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
new file mode 100644
index 000000000000..48c01318991a
--- /dev/null
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -0,0 +1,2285 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from PIL import Image
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    is_invisible_watermark_available,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_invisible_watermark_available():
+    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import StableDiffusionXLInpaintPipeline
+        >>> from diffusers.utils import load_image
+
+        >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+        ...     "stabilityai/stable-diffusion-xl-base-1.0",
+        ...     torch_dtype=torch.float16,
+        ...     variant="fp16",
+        ...     use_safetensors=True,
+        ... )
+        >>> pipe.to("cuda")
+
+        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+        >>> init_image = load_image(img_url).convert("RGB")
+        >>> mask_image = load_image(mask_url).convert("RGB")
+
+        >>> prompt = "A majestic tiger sitting on a bench"
+        >>> image = pipe(
+        ...     prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
+        ... ).images[0]
+        ```
+"""
+
+
+class AttentionBase:
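+    """
+    Base attention hook. Tracks the current denoising step and the index of the
+    attention layer currently executing, so that subclasses can gate their edits
+    to specific steps and layers by overriding `forward`.
+    """
+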
+    def __init__(self):
+        self.cur_step = 0
+        self.num_att_layers = -1
+        self.cur_att_layer = 0
+
+    def after_step(self):
+        pass
+
+    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+        self.cur_att_layer += 1
+        if self.cur_att_layer == self.num_att_layers:
+            self.cur_att_layer = 0
+            self.cur_step += 1
+            # after step
+            self.after_step()
+        return out
+
+    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+        out = torch.einsum("b i j, b j d -> b i d", attn, v)
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=num_heads)
+        return out
+
+    def reset(self):
+        self.cur_step = 0
+        self.cur_att_layer = 0
+
+
+class AAS_XL(AttentionBase):
+    MODEL_TYPE = {"SD": 16, "SDXL": 70}
+
+    def __init__(
+        self,
+        start_step=4,
+        end_step=50,
+        start_layer=10,
+        end_layer=16,
+        layer_idx=None,
+        step_idx=None,
+        total_steps=50,
+        mask=None,
+        model_type="SD",
+        ss_steps=9,
+        ss_scale=1.0,
+    ):
+        """
+        Args:
+            start_step: the step to start AAS
+            start_layer: the layer to start AAS
+            layer_idx: list of the layers to apply AAS
+            step_idx: list the steps to apply AAS
+            total_steps: the total number of steps
+            mask: source mask with shape (h, w)
+            model_type: the model type, SD or SDXL
+        """
+        super().__init__()
+        self.total_steps = total_steps
+        self.total_layers = self.MODEL_TYPE.get(model_type, 16)
+        self.start_step = start_step
+        self.end_step = end_step
+        self.start_layer = start_layer
+        self.end_layer = end_layer
+        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, end_layer))
+        self.step_idx = step_idx if step_idx is not None else list(range(start_step, end_step))
+        self.mask = mask  # mask with shape (1, 1, h, w)
+        self.ss_steps = ss_steps
+        self.ss_scale = ss_scale
+        print("AAS at denoising steps: ", self.step_idx)
+        print("AAS at U-Net layers: ", self.layer_idx)
+        print("start AAS")
+        self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
+        self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
+        self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
+        self.mask_128 = F.max_pool2d(mask, (1024 // 128, 1024 // 128)).round().squeeze().squeeze()
+
+    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, is_mask_attn, mask, **kwargs):
+        B = q.shape[0] // num_heads
+        if is_mask_attn:
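+            # Keys inside the mask get -inf logits, so the erased region can only
+            # attend to the background; for the first ss_steps steps the foreground
+            # branch also scales its logits by ss_scale (similarity suppression).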
+            mask_flatten = mask.flatten(0)
+            if self.cur_step <= self.ss_steps:
+                # background
+                sim_bg = sim + mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+
+                # object
+                sim_fg = self.ss_scale * sim
+                sim_fg += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+                sim = torch.cat([sim_fg, sim_bg], dim=0)
+            else:
+                sim += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+
+        attn = sim.softmax(-1)
+        if len(attn) == 2 * len(v):
+            v = torch.cat([v] * 2)
+        out = torch.einsum("h i j, h j d -> h i d", attn, v)
+        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
+        return out
+
+    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+        """
+        Attention forward function
+        """
+        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
+            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+        # B = q.shape[0] // num_heads // 2
+        H = int(np.sqrt(q.shape[1]))
+        # H = W = int(np.sqrt(q.shape[1]))
+        if H == 16:
+            mask = self.mask_16.to(sim.device)
+        elif H == 32:
+            mask = self.mask_32.to(sim.device)
+        elif H == 64:
+            mask = self.mask_64.to(sim.device)
+        else:
+            mask = self.mask_128.to(sim.device)
+
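+        # SARG runs a duplicated batch: the first half ("wo") uses plain
+        # self-attention, the second half ("w") uses mask-redirected attention.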
+        q_wo, q_w = q.chunk(2)
+        k_wo, k_w = k.chunk(2)
+        v_wo, v_w = v.chunk(2)
+        sim_wo, sim_w = sim.chunk(2)
+        attn_wo, attn_w = attn.chunk(2)
+
+        out_source = self.attn_batch(
+            q_wo,
+            k_wo,
+            v_wo,
+            sim_wo,
+            attn_wo,
+            is_cross,
+            place_in_unet,
+            num_heads,
+            is_mask_attn=False,
+            mask=None,
+            **kwargs,
+        )
+        out_target = self.attn_batch(
+            q_w, k_w, v_w, sim_w, attn_w, is_cross, place_in_unet, num_heads, is_mask_attn=True, mask=mask, **kwargs
+        )
+
+        if self.mask is not None:
+            if out_target.shape[0] == 2:
+                out_target_fg, out_target_bg = out_target.chunk(2, 0)
+                mask = mask.reshape(-1, 1)  # (hw, 1)
+                out_target = out_target_fg * mask + out_target_bg * (1 - mask)
+
+        out = torch.cat([out_source, out_target], dim=0)
+        return out
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+def mask_pil_to_torch(mask, height, width):
+    # preprocess mask
+    if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+        mask = [mask]
+
+    if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+        mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+        mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+        mask = mask.astype(np.float32) / 255.0
+    elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+        mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+    mask = torch.from_numpy(mask)
+    return mask
+
+
+def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+    """
+    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+    ``image`` and ``1`` for the ``mask``.
+
+    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+    Args:
+        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+    Raises:
+        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+            (or the other way around).
+
+    Returns:
+        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+            dimensions: ``batch x channels x height x width``.
+    """
+
+    # TODO(Yiyi) - need to clean this up later
+    deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
+    deprecate(
+        "prepare_mask_and_masked_image",
+        "0.30.0",
+        deprecation_message,
+    )
+    if image is None:
+        raise ValueError("`image` input cannot be undefined.")
+
+    if mask is None:
+        raise ValueError("`mask_image` input cannot be undefined.")
+
+    if isinstance(image, torch.Tensor):
+        if not isinstance(mask, torch.Tensor):
+            mask = mask_pil_to_torch(mask, height, width)
+
+        if image.ndim == 3:
+            image = image.unsqueeze(0)
+
+        # Batch and add channel dim for single mask
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0).unsqueeze(0)
+
+        # Batch single mask or add channel dim
+        if mask.ndim == 3:
+            # Single batched mask, no channel dim or single mask not batched but channel dim
+            if mask.shape[0] == 1:
+                mask = mask.unsqueeze(0)
+
+            # Batched masks no channel dim
+            else:
+                mask = mask.unsqueeze(1)
+
+        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+        # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+        # Check image is in [-1, 1]
+        # if image.min() < -1 or image.max() > 1:
+        #    raise ValueError("Image should be in [-1, 1] range")
+
+        # Check mask is in [0, 1]
+        if mask.min() < 0 or mask.max() > 1:
+            raise ValueError("Mask should be in [0, 1] range")
+
+        # Binarize mask
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+
+        # Image as float32
+        image = image.to(dtype=torch.float32)
+    elif isinstance(mask, torch.Tensor):
+        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
+    else:
+        # preprocess image
+        if isinstance(image, (PIL.Image.Image, np.ndarray)):
+            image = [image]
+        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+            # resize all images w.r.t. the passed height and width
+            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+            image = [np.array(i.convert("RGB"))[None, :] for i in image]
+            image = np.concatenate(image, axis=0)
+        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+            image = np.concatenate([i[None, :] for i in image], axis=0)
+
+        image = image.transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+        mask = mask_pil_to_torch(mask, height, width)
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+
+    if image.shape[1] == 4:
+        # images are in latent space and thus can't be masked,
+        # so set masked_image to None. We assume that the
+        # checkpoint is not an inpainting checkpoint.
+        # TODO(Yiyi) - need to clean this up later
+        masked_image = None
+    else:
+        masked_image = image * (mask < 0.5)
+
+    # n.b. ensure backwards compatibility as old function does not return image
+    if return_image:
+        return mask, masked_image, image
+
+    return mask, masked_image
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+            must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+class StableDiffusionXL_AE_Pipeline(
+    DiffusionPipeline,
+    StableDiffusionMixin,
+    TextualInversionLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    FromSingleFileMixin,
+    IPAdapterMixin,
+):
+    r"""
+    Pipeline for object removal using Stable Diffusion XL.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion XL uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+            Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+            specifically the
+            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+            variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`CLIPTokenizer`):
+            Second Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
+            Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config
+            of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
+            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+            `stabilityai/stable-diffusion-xl-base-1-0`.
+        add_watermarker (`bool`, *optional*):
+            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+            watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+            watermarker will be used.
+    """
+
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "image_encoder",
+        "feature_extractor",
+    ]
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+        "negative_pooled_prompt_embeds",
+        "add_neg_time_ids",
+        "mask",
+        "masked_image_latents",
+    ]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        text_encoder_2: CLIPTextModelWithProjection,
+        tokenizer: CLIPTokenizer,
+        tokenizer_2: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: KarrasDiffusionSchedulers,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
+        requires_aesthetics_score: bool = False,
+        force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
+    ):
+        super().__init__()
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
+            tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
+            unet=unet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+            scheduler=scheduler,
+        )
+        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.mask_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+        )
+
+        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+        if add_watermarker:
+            self.watermark = StableDiffusionXLWatermarker()
+        else:
+            self.watermark = None
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            repeat_dims = [1]
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                else:
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt: str,
+        prompt_2: Optional[str] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: Optional[str] = None,
+        negative_prompt_2: Optional[str] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+        """
+        device = device or self._execution_device
+
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None:
+                if not USE_PEFT_BACKEND:
+                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+                else:
+                    scale_lora_layers(self.text_encoder, lora_scale)
+
+            if self.text_encoder_2 is not None:
+                if not USE_PEFT_BACKEND:
+                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+                else:
+                    scale_lora_layers(self.text_encoder_2, lora_scale)
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # Define tokenizers and text encoders
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+
+        if prompt_embeds is None:
+            prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+            # textual inversion: process multi-vector tokens if necessary
+            prompt_embeds_list = []
+            prompts = [prompt, prompt_2]
+            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+                if isinstance(self, TextualInversionLoaderMixin):
+                    prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+                text_inputs = tokenizer(
+                    prompt,
+                    padding="max_length",
+                    max_length=tokenizer.model_max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                text_input_ids = text_inputs.input_ids
+                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                    text_input_ids, untruncated_ids
+                ):
+                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                    logger.warning(
+                        "The following part of your input was truncated because CLIP can only handle sequences up to"
+                        f" {tokenizer.model_max_length} tokens: {removed_text}"
+                    )
+
+                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+                # We are only ALWAYS interested in the pooled output of the final text encoder
+                pooled_prompt_embeds = prompt_embeds[0]
+                if clip_skip is None:
+                    prompt_embeds = prompt_embeds.hidden_states[-2]
+                else:
+                    # "2" because SDXL always indexes from the penultimate layer.
+                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+                prompt_embeds_list.append(prompt_embeds)
+
+            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+        # get unconditional embeddings for classifier free guidance
+        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+        elif do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
+            uncond_tokens: List[str]
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = [negative_prompt, negative_prompt_2]
+
+            negative_prompt_embeds_list = []
+            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+                if isinstance(self, TextualInversionLoaderMixin):
+                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+                max_length = prompt_embeds.shape[1]
+                uncond_input = tokenizer(
+                    negative_prompt,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                negative_prompt_embeds = text_encoder(
+                    uncond_input.input_ids.to(device),
+                    output_hidden_states=True,
+                )
+                # We are only ALWAYS interested in the pooled output of the final text encoder
+                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+                negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+        if self.text_encoder_2 is not None:
+            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+        else:
+            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            if self.text_encoder_2 is not None:
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+            else:
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+            bs_embed * num_images_per_prompt, -1
+        )
+        if do_classifier_free_guidance:
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+                bs_embed * num_images_per_prompt, -1
+            )
+
+        if self.text_encoder is not None:
+            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
+        if self.text_encoder_2 is not None:
+            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def check_inputs(
+        self,
+        prompt,
+        prompt_2,
+        image,
+        mask_image,
+        height,
+        width,
+        strength,
+        callback_steps,
+        output_type,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        padding_mask_crop=None,
+    ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+        if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when inpainting mask crop, but is of type"
+                    f" {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+
+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+        image=None,
+        timestep=None,
+        is_strength_max=True,
+        add_noise=True,
+        return_noise=False,
+        return_image_latents=False,
+    ):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if (image is None or timestep is None) and not is_strength_max:
+            raise ValueError(
+                "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+                "However, either the image or the noise timestep has not been provided."
+            )
+
+        if image.shape[1] == 4:
+            image_latents = image.to(device=device, dtype=dtype)
+            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+        elif return_image_latents or (latents is None and not is_strength_max):
+            image = image.to(device=device, dtype=dtype)
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+        if latents is None and add_noise:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            # if strength is 1. then initialise the latents to noise, else initial to image + noise
+            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+            # if pure noise then scale the initial latents by the scheduler's init sigma
+            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+        elif add_noise:
+            noise = latents.to(device)
+            latents = noise * self.scheduler.init_noise_sigma
+        else:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = image_latents.to(device)
+
+        outputs = (latents,)
+
+        if return_noise:
+            outputs += (noise,)
+
+        if return_image_latents:
+            outputs += (image_latents,)
+
+        return outputs
+
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        dtype = image.dtype
+        if self.vae.config.force_upcast:
+            image = image.float()
+            self.vae.to(dtype=torch.float32)
+
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        if self.vae.config.force_upcast:
+            self.vae.to(dtype)
+
+        image_latents = image_latents.to(dtype)
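+        # scale into the unit-variance latent space expected by the UNet (0.13025 for the SDXL VAE)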
+        image_latents = self.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
+    def prepare_mask_latents(
+        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+    ):
+        # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
+        mask = torch.nn.functional.max_pool2d(mask, (8, 8)).round()
+        mask = mask.to(device=device, dtype=dtype)
+
+        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    " of masks that you pass is divisible by the total requested batch size."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+        if masked_image is not None and masked_image.shape[1] == 4:
+            masked_image_latents = masked_image
+        else:
+            masked_image_latents = None
+
+        if masked_image is not None:
+            if masked_image_latents is None:
+                masked_image = masked_image.to(device=device, dtype=dtype)
+                masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+            if masked_image_latents.shape[0] < batch_size:
+                if not batch_size % masked_image_latents.shape[0] == 0:
+                    raise ValueError(
+                        "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                        f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                        " Make sure the number of images that you pass is divisible by the total requested batch size."
+                    )
+                masked_image_latents = masked_image_latents.repeat(
+                    batch_size // masked_image_latents.shape[0], 1, 1, 1
+                )
+
+            masked_image_latents = (
+                torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+            )
+
+            # aligning device to prevent device errors when concatenating it with the latent model input
+            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+        return mask, masked_image_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+        # get the original timestep using init_timestep
+        if denoising_start is None:
+            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+            t_start = max(num_inference_steps - init_timestep, 0)
+        else:
+            t_start = 0
+
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+        # Strength is irrelevant if we directly request a timestep to start at;
+        # that is, strength is determined by the denoising_start instead.
+        if denoising_start is not None:
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_start * self.scheduler.config.num_train_timesteps)
+                )
+            )
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+            if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+                # if the scheduler is a 2nd order scheduler we might have to do +1
+                # because `num_inference_steps` might be even given that every timestep
+                # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+                # mean that we cut the timesteps in the middle of the denoising step
+                # (between the 1st and 2nd derivative) which leads to incorrect results. By adding 1
+                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+                num_inference_steps = num_inference_steps + 1
+
+            # because t_n+1 >= t_n, we slice the timesteps starting from the end
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps
+
+        return timesteps, num_inference_steps - t_start
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+    def _get_add_time_ids(
+        self,
+        original_size,
+        crops_coords_top_left,
+        target_size,
+        aesthetic_score,
+        negative_aesthetic_score,
+        negative_original_size,
+        negative_crops_coords_top_left,
+        negative_target_size,
+        dtype,
+        text_encoder_projection_dim=None,
+    ):
+        if self.config.requires_aesthetics_score:
+            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+            add_neg_time_ids = list(
+                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+            )
+        else:
+            add_time_ids = list(original_size + crops_coords_top_left + target_size)
+            add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
+
+        passed_add_embed_dim = (
+            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+        )
+        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+        if (
+            expected_add_embed_dim > passed_add_embed_dim
+            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+        ):
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+            )
+        elif (
+            expected_add_embed_dim < passed_add_embed_dim
+            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+        ):
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+            )
+        elif expected_add_embed_dim != passed_add_embed_dim:
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+            )
+
+        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+        return add_time_ids, add_neg_time_ids
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                guidance scale values at which to generate the embedding vectors
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
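+    # Self-Attention Redirection Guidance (SARG) is the Attentive Eraser mechanism: it is
+    # active when the removal guidance scale exceeds 1 and Attention Activation and
+    # Suppression (AAS) is enabled.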
+    @property
+    def do_self_attention_redirection_guidance(self):  # SARG
+        return self._rm_guidance_scale > 1 and self._AAS
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return (
+            self._guidance_scale > 1
+            and self.unet.config.time_cond_proj_dim is None
+            and not self.do_self_attention_redirection_guidance
+        )  # CFG is disabled when SARG is in use; experiments showed little difference with or without it
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def denoising_end(self):
+        return self._denoising_end
+
+    @property
+    def denoising_start(self):
+        return self._denoising_start
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    def image2latent(self, image: torch.Tensor, generator: torch.Generator):
+        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        if isinstance(image, PIL.Image.Image):
+            image = np.array(image)
+            image = torch.from_numpy(image).float() / 127.5 - 1
+            image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
+        # input image is expected to be in the [-1, 1] range
+        latents = self._encode_vae_image(image, generator)
+        return latents
+
+    def next_step(self, model_output: torch.FloatTensor, timestep: int, x: torch.FloatTensor, eta=0.0, verbose=False):
+        """
+        Inverse sampling for DDIM Inversion
+        """
+        if verbose:
+            print("timestep: ", timestep)
+        next_step = timestep
+        timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
+        beta_prod_t = 1 - alpha_prod_t
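+        # deterministic DDIM inversion update (eta = 0):
+        #   x0_pred = (x_t - sqrt(1 - alpha_t) * eps) / sqrt(alpha_t)
+        #   x_{t+1} = sqrt(alpha_{t+1}) * x0_pred + sqrt(1 - alpha_{t+1}) * eps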
+        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+        pred_dir = (1 - alpha_prod_t_next) ** 0.5 * model_output
+        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
+        return x_next, pred_x0
+
+    @torch.no_grad()
+    def invert(
+        self,
+        image: torch.Tensor,
+        prompt,
+        num_inference_steps=50,
+        eta=0.0,
+        original_size: Tuple[int, int] = None,
+        target_size: Tuple[int, int] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+        aesthetic_score: float = 6.0,
+        negative_aesthetic_score: float = 2.5,
+        return_intermediates=False,
+        **kwds,
+    ):
+        """
+        Invert a real image into a noise map using deterministic DDIM inversion.
+        """
+        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        batch_size = image.shape[0]
+        if isinstance(prompt, list):
+            if batch_size == 1:
+                image = image.expand(len(prompt), -1, -1, -1)
+        elif isinstance(prompt, str):
+            if batch_size > 1:
+                prompt = [prompt] * batch_size
+
+        # Define tokenizers and text encoders
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+
+        prompt_2 = prompt
+        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+        # textual inversion: process multi-vector tokens if necessary
+        prompt_embeds_list = []
+        prompts = [prompt, prompt_2]
+        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+            text_inputs = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            prompt_embeds = text_encoder(text_input_ids.to(DEVICE), output_hidden_states=True)
+
+            # We are only ALWAYS interested in the pooled output of the final text encoder
+            pooled_prompt_embeds = prompt_embeds[0]
+            prompt_embeds = prompt_embeds.hidden_states[-2]
+            prompt_embeds_list.append(prompt_embeds)
+        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=DEVICE)
+
+        # define initial latents
+        latents = self.image2latent(image, generator=None)
+
+        start_latents = latents
+        height, width = latents.shape[-2:]
+        height = height * self.vae_scale_factor
+        width = width * self.vae_scale_factor
+
+        original_size = (height, width)
+        target_size = (height, width)
+        negative_original_size = original_size
+        negative_target_size = target_size
+
+        add_text_embeds = pooled_prompt_embeds
+        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            aesthetic_score,
+            negative_aesthetic_score,
+            negative_original_size,
+            negative_crops_coords_top_left,
+            negative_target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+
+        add_time_ids = add_time_ids.repeat(batch_size, 1).to(DEVICE)
+
+        # iterative sampling
+        self.scheduler.set_timesteps(num_inference_steps)
+        latents_list = [latents]
+        pred_x0_list = []
+        for i, t in enumerate(reversed(self.scheduler.timesteps)):
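+            # inversion traverses timesteps in reverse order (t: 0 -> T), mapping x_t to the noisier x_{t+1}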
+            model_inputs = latents
+
+            # predict the noise
+            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+            noise_pred = self.unet(
+                model_inputs, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs
+            ).sample
+
+            # compute the previous noise sample x_t-1 -> x_t
+            latents, pred_x0 = self.next_step(noise_pred, t, latents)
+            """
+            if t >= 1 and t < 41:
+                latents, pred_x0 = self.next_step_degrade(noise_pred, t, latents, mask)
+            else:
+                latents, pred_x0 = self.next_step(noise_pred, t, latents) """
+
+            latents_list.append(latents)
+            pred_x0_list.append(pred_x0)
+
+        if return_intermediates:
+            # return the intermediate latents collected during inversion
+            return latents, latents_list, pred_x0_list
+        return latents, start_latents
+
+    def opt(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: int,
+        x: torch.FloatTensor,
+    ):
+        """
+        Predict the sample of the next step in the denoising process.
+        """
+        ref_noise = model_output[:1, :, :, :].expand(model_output.shape)
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+        beta_prod_t = 1 - alpha_prod_t
+        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+        x_opt = alpha_prod_t**0.5 * pred_x0 + (1 - alpha_prod_t) ** 0.5 * ref_noise
+        return x_opt, pred_x0
+
+    def regiter_attention_editor_diffusers(self, unet, editor: AttentionBase):
+        """
+        Register an attention editor with the Diffusers pipeline (adapted from [Prompt-to-Prompt])
+        """
+
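+        # monkey-patch the forward pass of every Attention module in the UNet so that the
+        # editor receives (q, k, v, sim, attn) and can redirect the self-attention maps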
+        def ca_forward(self, place_in_unet):
+            def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
+                """
+                The attention computation is similar to the original LDM CrossAttention class,
+                except for some modifications applied to the attention output
+                """
+                if encoder_hidden_states is not None:
+                    context = encoder_hidden_states
+                if attention_mask is not None:
+                    mask = attention_mask
+
+                to_out = self.to_out
+                if isinstance(to_out, nn.ModuleList):
+                    to_out = to_out[0]
+
+                h = self.heads
+                q = self.to_q(x)
+                is_cross = context is not None
+                context = context if is_cross else x
+                k = self.to_k(context)
+                v = self.to_v(context)
+                q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=h) for t in (q, k, v))
+
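+                # scaled dot-product attention scores, shape (b*h, n_q, n_k)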
+                sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+                if mask is not None:
+                    mask = rearrange(mask, "b ... -> b (...)")
+                    max_neg_value = -torch.finfo(sim.dtype).max
+                    # shape (b*h, 1, j) broadcasts against `sim` of shape (b*h, i, j)
+                    mask = repeat(mask, "b j -> (b h) () j", h=h)
+                    sim.masked_fill_(~mask, max_neg_value)
+
+                attn = sim.softmax(dim=-1)
+                # the only difference from the standard forward: delegate attention to the editor
+                out = editor(q, k, v, sim, attn, is_cross, place_in_unet, self.heads, scale=self.scale)
+
+                return to_out(out)
+
+            return forward
+
+        def register_editor(net, count, place_in_unet):
+            for name, subnet in net.named_children():
+                if net.__class__.__name__ == "Attention":  # spatial Transformer layer
+                    net.forward = ca_forward(net, place_in_unet)
+                    return count + 1
+                elif hasattr(net, "children"):
+                    count = register_editor(subnet, count, place_in_unet)
+            return count
+
+        cross_att_count = 0
+        for net_name, net in unet.named_children():
+            if "down" in net_name:
+                cross_att_count += register_editor(net, 0, "down")
+            elif "mid" in net_name:
+                cross_att_count += register_editor(net, 0, "mid")
+            elif "up" in net_name:
+                cross_att_count += register_editor(net, 0, "up")
+        editor.num_att_layers = cross_att_count
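+        # record the number of patched attention layers so the editor can track a full UNet pass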
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        image: PipelineImageInput = None,
+        mask_image: PipelineImageInput = None,
+        masked_image_latents: torch.FloatTensor = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
+        strength: float = 0.9999,
+        AAS: bool = True,  # Attentive Eraser: enable Attention Activation and Suppression (AAS)
+        rm_guidance_scale: float = 7.0,  # Attentive Eraser: guidance scale for self-attention redirection
+        ss_steps: int = 9,  # Attentive Eraser: similarity suppression steps
+        ss_scale: float = 0.3,  # Attentive Eraser: similarity suppression scale
+        AAS_start_step: int = 0,  # Attentive Eraser: denoising step at which AAS starts
+        AAS_start_layer: int = 34,  # Attentive Eraser: first UNet attention layer AAS is applied to
+        AAS_end_layer: int = 70,  # Attentive Eraser: last UNet attention layer AAS is applied to
+        num_inference_steps: int = 50,
+        timesteps: List[int] = None,
+        denoising_start: Optional[float] = None,
+        denoising_end: Optional[float] = None,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
+        original_size: Tuple[int, int] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        target_size: Tuple[int, int] = None,
+        negative_original_size: Optional[Tuple[int, int]] = None,
+        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+        negative_target_size: Optional[Tuple[int, int]] = None,
+        aesthetic_score: float = 6.0,
+        negative_aesthetic_score: float = 2.5,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
+            image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+                be masked out with `mask_image` and repainted according to `prompt`.
+            mask_image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+                instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+                Anything below 512 pixels won't work well for
+                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+                and checkpoints that are not specifically fine-tuned on low resolutions.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+                Anything below 512 pixels won't work well for
+                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+                and checkpoints that are not specifically fine-tuned on low resolutions.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+                image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+                with the same aspect ratio as the image that contains all masked areas, and then expand that area based
+                on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+                resizing to the original image size for inpainting. This is useful when the masked area is small while
+                the image is large and contains information irrelevant to inpainting, such as background.
+            strength (`float`, *optional*, defaults to 0.9999):
+                Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+                between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+                `strength`. The number of denoising steps depends on the amount of noise initially added. When
+                `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+                iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+                portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
+                integer, the value of `strength` will be ignored.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            denoising_start (`float`, *optional*):
+                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+                strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+                is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+            denoising_end (`float`, *optional*):
+                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+                completed before it is intentionally prematurely terminated. As a result, the returned sample will
+                still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
+                denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
+                final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
+                forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list with length equal to the number
+                of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It
+                should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If
+                not provided, embeddings are computed from the `ip_adapter_image` input argument.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+                explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a target image resolution. It should be the
+                same as `target_size` in most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            aesthetic_score (`float`, *optional*, defaults to 6.0):
+                Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+                Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+                Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+                simulate an aesthetic score of the generated image by influencing the negative text condition.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function called at the end of each denoising step during inference, with the following arguments:
+                `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+            )
+
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 1. Check inputs
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            image,
+            mask_image,
+            height,
+            width,
+            strength,
+            callback_steps,
+            output_type,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
+            callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+        self._denoising_end = denoising_end
+        self._denoising_start = denoising_start
+        self._interrupt = False
+
+        ########### AE parameters
+        self._num_timesteps = num_inference_steps
+        self._rm_guidance_scale = rm_guidance_scale
+        self._AAS = AAS
+        self._ss_steps = ss_steps
+        self._ss_scale = ss_scale
+        self._AAS_start_step = AAS_start_step
+        self._AAS_start_layer = AAS_start_layer
+        self._AAS_end_layer = AAS_end_layer
+        ###########
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+            clip_skip=self.clip_skip,
+        )
+
+        # 4. set timesteps
+        def denoising_value_valid(dnv):
+            return isinstance(dnv, float) and 0 < dnv < 1
+
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps,
+            strength,
+            device,
+            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+        )
+        # check that number of inference steps is not < 1 - as this doesn't make sense
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
+        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+        is_strength_max = strength == 1.0
+
+        # 5. Preprocess mask and image
+        if padding_mask_crop is not None:
+            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
+        original_image = image
+        init_image = self.image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        )
+        init_image = init_image.to(dtype=torch.float32)
+
+        mask = self.mask_processor.preprocess(
+            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )
+
+        if masked_image_latents is not None:
+            masked_image = masked_image_latents
+        elif init_image.shape[1] == 4:
+            # if images are in latent space, we can't mask them
+            masked_image = None
+        else:
+            masked_image = init_image * (mask < 0.5)
+
+        # 6. Prepare latent variables
+        num_channels_latents = self.vae.config.latent_channels
+        num_channels_unet = self.unet.config.in_channels
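+        # a 4-channel UNet is a standard text-to-image UNet and needs the image latents
+        # for mask blending in the denoising loop; a 9-channel inpainting UNet instead
+        # takes the mask and masked-image latents as extra input channels (see step 8)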
+        return_image_latents = num_channels_unet == 4
+
+        add_noise = self.denoising_start is None
+        latents_outputs = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+            image=init_image,
+            timestep=latent_timestep,
+            is_strength_max=is_strength_max,
+            add_noise=add_noise,
+            return_noise=True,
+            return_image_latents=return_image_latents,
+        )
+
+        if return_image_latents:
+            latents, noise, image_latents = latents_outputs
+        else:
+            latents, noise = latents_outputs
+
+        # 7. Prepare mask latent variables
+        mask, masked_image_latents = self.prepare_mask_latents(
+            mask,
+            masked_image,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            self.do_classifier_free_guidance,
+        )
+
+        # 8. Check that sizes of mask, masked image and latents match
+        if num_channels_unet == 9:
+            # default case for runwayml/stable-diffusion-inpainting
+            num_channels_mask = mask.shape[1]
+            num_channels_masked_image = masked_image_latents.shape[1]
+            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+                raise ValueError(
+                    f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+                    f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+                    f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                    " `pipeline.unet` or your `mask_image` or `image` input."
+                )
+        elif num_channels_unet != 4:
+            raise ValueError(
+                f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+            )
+        # 8.1 Prepare extra step kwargs.
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 9. Derive height and width from the latents. TODO: Logic should ideally just be moved out of the pipeline
+        height, width = latents.shape[-2:]
+        height = height * self.vae_scale_factor
+        width = width * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        # 10. Prepare added time ids & embeddings
+        if negative_original_size is None:
+            negative_original_size = original_size
+        if negative_target_size is None:
+            negative_target_size = target_size
+
+        add_text_embeds = pooled_prompt_embeds
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            aesthetic_score,
+            negative_aesthetic_score,
+            negative_original_size,
+            negative_crops_coords_top_left,
+            negative_target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+        ###########
+        if self.do_self_attention_redirection_guidance:
+            prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(2, 1)
+            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+        ############
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device)
+
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+        # apply AAS to modify the attention module
+        if self.do_self_attention_redirection_guidance:
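+            # AAS is active from AAS_start_step up to the last step actually executed,
+            # i.e. AAS_end_step = int(strength * num_inference_steps)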
+            self._AAS_end_step = int(strength * self._num_timesteps)
+            layer_idx = list(range(self._AAS_start_layer, self._AAS_end_layer))
+            editor = AAS_XL(
+                self._AAS_start_step,
+                self._AAS_end_step,
+                self._AAS_start_layer,
+                self._AAS_end_layer,
+                layer_idx=layer_idx,
+                mask=mask_image,
+                model_type="SDXL",
+                ss_steps=self._ss_steps,
+                ss_scale=self._ss_scale,
+            )
+            self.regiter_attention_editor_diffusers(self.unet, editor)
+
+        # 11. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        if (
+            self.denoising_end is not None
+            and self.denoising_start is not None
+            and denoising_value_valid(self.denoising_end)
+            and denoising_value_valid(self.denoising_start)
+            and self.denoising_start >= self.denoising_end
+        ):
+            raise ValueError(
+                f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+                + f" {self.denoising_end} when using type float."
+            )
+        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
+        # 11.1 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
+        self._num_timesteps = len(timesteps)
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                # removal guidance: this replaces the CFG input above, since CFG is disabled
+                # when SARG is used; experiments showed little difference in the results
+                # whether or not CFG was applied
+                latent_model_input = (
+                    torch.cat([latents] * 2) if self.do_self_attention_redirection_guidance else latents
+                )
+
+                # concat latents, mask, masked_image_latents in the channel dimension
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                if num_channels_unet == 9:
+                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+                # predict the noise residual
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # perform SARG
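+                # the batch stacks a prediction without (wo) and with (w) redirected
+                # self-attention; extrapolate from wo towards w by rm_guidance_scale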
+                if self.do_self_attention_redirection_guidance:
+                    noise_pred_wo, noise_pred_w = noise_pred.chunk(2)
+                    delta = noise_pred_w - noise_pred_wo
+                    noise_pred = noise_pred_wo + self._rm_guidance_scale * delta
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if num_channels_unet == 4:
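+                    # preserve the background: re-noise the original image latents to the
+                    # next timestep and blend them with the denoised latents outside the mask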
+                    init_latents_proper = image_latents
+                    if self.do_classifier_free_guidance:
+                        init_mask, _ = mask.chunk(2)
+                    else:
+                        init_mask = mask
+
+                    if i < len(timesteps) - 1:
+                        noise_timestep = timesteps[i + 1]
+                        init_latents_proper = self.scheduler.add_noise(
+                            init_latents_proper, noise, torch.tensor([noise_timestep])
+                        )
+
+                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
+                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+                    mask = callback_outputs.pop("mask", mask)
+                    masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        if not output_type == "latent":
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
+            latents = latents[-1:]
+
+            if needs_upcasting:
+                self.upcast_vae()
+                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            image = self.vae.decode(latents, return_dict=False)[0]
+
+            # cast back to fp16 if needed
+            if needs_upcasting:
+                self.vae.to(dtype=torch.float16)
+        else:
+            return StableDiffusionXLPipelineOutput(images=latents)
+
+        # apply watermark if available
+        if self.watermark is not None:
+            image = self.watermark.apply_watermark(image)
+
+        image = self.image_processor.postprocess(image, output_type=output_type)
+
+        if padding_mask_crop is not None:
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionXLPipelineOutput(images=image)
From 1b47ff7e00d77822cdda353c09f60170de3805d7 Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Thu, 16 Jan 2025 22:00:09 +0800
Subject: [PATCH 4/7] update Docs
Co-authored-by: Other Contributor 
---
 examples/community/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/community/README.md b/examples/community/README.md
index 02777636d2d3..a62d49f0d2c0 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -77,7 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
 PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
 | HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
 | [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
-| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3)|
+| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
 
 To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
 
From 72ee35e4be022d1f8f287480ab31bfe09022b67a Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Sat, 18 Jan 2025 19:21:43 +0800
Subject: [PATCH 5/7] add Oral
Co-authored-by: Other Contributor 
---
 examples/community/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/community/README.md b/examples/community/README.md
index a62d49f0d2c0..4c593a004893 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -77,7 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
 PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
 | HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
 | [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
-| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
+| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
 
 To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
 
From e5fc2e40e43bd132b70d1f236dc0c9cfbd6fd2bc Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Fri, 24 Jan 2025 20:46:31 +0800
Subject: [PATCH 6/7] update_review
Co-authored-by: Other Contributor 
---
 ...ne_stable_diffusion_xl_attentive_eraser.py | 96 ++++++++++++-------
 1 file changed, 64 insertions(+), 32 deletions(-)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
index 48c01318991a..f78cedb9965a 100644
--- a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -81,27 +81,72 @@
     Examples:
         ```py
         >>> import torch
-        >>> from diffusers import StableDiffusionXLInpaintPipeline
+        >>> from diffusers import DDIMScheduler, DiffusionPipeline
         >>> from diffusers.utils import load_image
-
-        >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
-        ...     "stabilityai/stable-diffusion-xl-base-1.0",
-        ...     torch_dtype=torch.float16,
-        ...     variant="fp16",
-        ...     use_safetensors=True,
-        ... )
-        >>> pipe.to("cuda")
-
-        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
-        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
-
-        >>> init_image = load_image(img_url).convert("RGB")
-        >>> mask_image = load_image(mask_url).convert("RGB")
-
-        >>> prompt = "A majestic tiger sitting on a bench"
-        >>> image = pipe(
-        ...     prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
+        >>> import torch.nn.functional as F
+        >>> from torchvision.transforms.functional import to_tensor, gaussian_blur
+
+        >>> dtype = torch.float16
+        >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 
+        >>> scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+
+        >>> pipeline = DiffusionPipeline.from_pretrained(
+        ...    "stabilityai/stable-diffusion-xl-base-1.0",
+        ...    custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
+        ...    scheduler=scheduler,
+        ...    variant="fp16",
+        ...    use_safetensors=True,
+        ...    torch_dtype=dtype,
+        ... ).to(device)
+
+
+        >>> def preprocess_image(image_path, device):
+        ...     image = to_tensor(load_image(image_path))
+        ...     image = image.unsqueeze_(0).float() * 2 - 1  # [0,1] --> [-1,1]
+        ...     if image.shape[1] != 3:
+        ...         image = image.expand(-1, 3, -1, -1)
+        ...     image = F.interpolate(image, (1024, 1024))
+        ...     image = image.to(dtype).to(device)
+        ...     return image
+
+        >>> def preprocess_mask(mask_path, device):
+        ...     mask = to_tensor(load_image(mask_path, convert_method=lambda img: img.convert('L')))
+        ...     mask = mask.unsqueeze_(0).float()  # 0 or 1
+        ...     mask = F.interpolate(mask, (1024, 1024))
+        ...     mask = gaussian_blur(mask, kernel_size=(77, 77))
+        ...     mask[mask < 0.1] = 0
+        ...     mask[mask >= 0.1] = 1
+        ...     mask = mask.to(dtype).to(device)
+        ...     return mask
+
+        >>> prompt = "" # Set prompt to null
+        >>> seed=123 
+        >>> generator = torch.Generator(device=device).manual_seed(seed)
+        >>> source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
+        >>> mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
+        >>> source_image = preprocess_image(source_image_path, device)
+        >>> mask = preprocess_mask(mask_path, device)
+
+        >>> image = pipeline(
+        ...     prompt=prompt, 
+        ...     image=source_image,
+        ...     mask_image=mask,
+        ...     height=1024,
+        ...     width=1024,
+        ...     AAS=True, # enable AAS
+        ...     strength=0.8, # inpainting strength
+        ...     rm_guidance_scale=9, # removal guidance scale
+        ...     ss_steps=9, # similarity suppression steps
+        ...     ss_scale=0.3, # similarity suppression scale
+        ...     AAS_start_step=0, # AAS start step
+        ...     AAS_start_layer=34, # AAS start layer
+        ...     AAS_end_layer=70, # AAS end layer
+        ...     num_inference_steps=50, # number of inference steps; AAS_end_step = int(strength * num_inference_steps)
+        ...     generator=generator,
+        ...     guidance_scale=1,
         ... ).images[0]
+        >>> image.save('./removed_img.png')
+        >>> print("Object removal completed")
         ```
 """
 
@@ -174,9 +219,6 @@ def __init__(
         self.mask = mask  # mask with shape (1, 1 ,h, w)
         self.ss_steps = ss_steps
         self.ss_scale = ss_scale
-        print("AAS at denoising steps: ", self.step_idx)
-        print("AAS at U-Net layers: ", self.layer_idx)
-        print("start AAS")
         self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
         self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
         self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
@@ -209,10 +251,7 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar
         Attention forward function
         """
         if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
-            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
-        # B = q.shape[0] // num_heads // 2
         H = int(np.sqrt(q.shape[1]))
-        # H = W = int(np.sqrt(q.shape[1]))
         if H == 16:
             mask = self.mask_16.to(sim.device)
         elif H == 32:
@@ -317,13 +356,6 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
             dimensions: ``batch x channels x height x width``.
     """
 
-    # checkpoint. TOD(Yiyi) - need to clean this up later
-    deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
-    deprecate(
-        "prepare_mask_and_masked_image",
-        "0.30.0",
-        deprecation_message,
-    )
     if image is None:
         raise ValueError("`image` input cannot be undefined.")
 
From bb52aeb717c7dd8d85e249fbf57564687e1f3333 Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Fri, 24 Jan 2025 21:34:26 +0800
Subject: [PATCH 7/7] update_review_ms
Co-authored-by: Other Contributor 
---
 .../pipeline_stable_diffusion_xl_attentive_eraser.py       | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
index f78cedb9965a..1269a69f0dc3 100644
--- a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -87,7 +87,7 @@
         >>> from torchvision.transforms.functional import to_tensor, gaussian_blur
 
         >>> dtype = torch.float16
-        >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 
+        >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         >>> scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
 
         >>> pipeline = DiffusionPipeline.from_pretrained(
@@ -120,7 +120,7 @@
         ...     return mask
 
         >>> prompt = "" # Set prompt to null
-        >>> seed=123 
+        >>> seed=123
         >>> generator = torch.Generator(device=device).manual_seed(seed)
         >>> source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
         >>> mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
@@ -128,7 +128,7 @@
         >>> mask = preprocess_mask(mask_path, device)
 
         >>> image = pipeline(
-        ...     prompt=prompt, 
+        ...     prompt=prompt,
         ...     image=source_image,
         ...     mask_image=mask,
         ...     height=1024,
@@ -251,6 +251,7 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar
         Attention forward function
         """
         if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
+            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
         H = int(np.sqrt(q.shape[1]))
         if H == 16:
             mask = self.mask_16.to(sim.device)
+
+**Stable Diffusion XL Attentive Eraser Pipeline** is an advanced object removal pipeline that leverages SDXL for precise content suppression and seamless region completion. This pipeline uses **self-attention redirection guidance** to modify the model’s self-attention mechanism, allowing for effective removal and inpainting across various levels of mask precision, including semantic segmentation masks, bounding boxes, and hand-drawn masks. If you are interested in more detailed information and have any questions, please refer to the [paper](https://arxiv.org/abs/2412.12974) and [official implementation](https://github.com/Anonym0u3/AttentiveEraser).
+
+#### Key features
+
+- **Tuning-Free**: No additional training is required, making it easy to integrate and use.
+- **Flexible Mask Support**: Works with different types of masks for targeted object removal.
+- **High-Quality Results**: Utilizes the inherent generative power of diffusion models for realistic content completion.
+
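+Under the hood, self-attention redirection guidance combines two noise predictions per step: one computed with the original self-attention and one with self-attention redirected away from the masked region. As a minimal sketch (the `sarg_combine` helper name and the tensor shapes below are illustrative, not pipeline API), the combination mirrors the rule used inside the pipeline's denoising loop:
+
+```py
+import torch
+
+
+def sarg_combine(noise_pred: torch.Tensor, rm_guidance_scale: float) -> torch.Tensor:
+    # first half of the batch: prediction without attention redirection (wo)
+    # second half: prediction with redirected self-attention (w)
+    noise_pred_wo, noise_pred_w = noise_pred.chunk(2)
+    # extrapolate from wo towards w, weighted by the removal guidance scale
+    return noise_pred_wo + rm_guidance_scale * (noise_pred_w - noise_pred_wo)
+
+
+# toy check with illustrative shapes: two stacked predictions over 4 latent channels
+noise_pred = torch.randn(2, 4, 128, 128)
+print(sarg_combine(noise_pred, rm_guidance_scale=9.0).shape)  # torch.Size([1, 4, 128, 128])
+```
+
+A larger `rm_guidance_scale` pushes the result further towards the redirected prediction and therefore towards stronger removal; the example below uses `rm_guidance_scale=9`.
+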
+#### Usage example
+To use the Stable Diffusion XL Attentive Eraser Pipeline, you can initialize it as follows:
+```py
+import torch
+from diffusers import DDIMScheduler, DiffusionPipeline
+from diffusers.utils import load_image
+import torch.nn.functional as F
+from torchvision.transforms.functional import to_tensor, gaussian_blur
+
+dtype = torch.float16
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
+scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+pipeline = DiffusionPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
+    scheduler=scheduler,
+    variant="fp16",
+    use_safetensors=True,
+    torch_dtype=dtype,
+).to(device)
+
+
+def preprocess_image(image_path, device):
+    image = to_tensor(load_image(image_path))
+    image = image.unsqueeze_(0).float() * 2 - 1  # [0,1] --> [-1,1]
+    if image.shape[1] != 3:
+        image = image.expand(-1, 3, -1, -1)
+    image = F.interpolate(image, (1024, 1024))
+    image = image.to(dtype).to(device)
+    return image
+
+def preprocess_mask(mask_path, device):
+    mask = to_tensor(load_image(mask_path, convert_method=lambda img: img.convert('L')))
+    mask = mask.unsqueeze_(0).float()  # 0 or 1
+    mask = F.interpolate(mask, (1024, 1024))
+    mask = gaussian_blur(mask, kernel_size=(77, 77))
+    mask[mask < 0.1] = 0
+    mask[mask >= 0.1] = 1
+    mask = mask.to(dtype).to(device)
+    return mask
+
+prompt = "" # Set prompt to null
+seed=123 
+generator = torch.Generator(device=device).manual_seed(seed)
+source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
+mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
+source_image = preprocess_image(source_image_path, device)
+mask = preprocess_mask(mask_path, device)
+
+image = pipeline(
+    prompt=prompt,
+    image=source_image,
+    mask_image=mask,
+    height=1024,
+    width=1024,
+    AAS=True, # enable AAS
+    strength=0.8, # inpainting strength
+    rm_guidance_scale=9, # removal guidance scale
+    ss_steps=9, # similarity suppression steps
+    ss_scale=0.3, # similarity suppression scale
+    AAS_start_step=0, # AAS start step
+    AAS_start_layer=34, # AAS start layer
+    AAS_end_layer=70, # AAS end layer
+    num_inference_steps=50, # number of inference steps; AAS_end_step = int(strength * num_inference_steps)
+    generator=generator,
+    guidance_scale=1,
+).images[0]
+image.save('./removed_img.png')
+print("Object removal completed")
+```
+
+| Source Image                                                                                   | Mask                                                                                        | Output                                                                                              |
+| ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------- |
+|  |  |  |
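+
+The mask only needs to roughly cover the object: segmentation masks, hand-drawn masks, and plain bounding boxes all work. As a minimal sketch (the `make_box_mask` helper and the box coordinates are illustrative, not part of the pipeline), a bounding-box mask can be built directly as a tensor and passed as `mask_image`, reusing `dtype` and `device` from the example above:
+
+```py
+def make_box_mask(x0, y0, x1, y1, size=1024):
+    # 1.0 inside the box (the region to erase), 0.0 elsewhere; shape (1, 1, H, W)
+    box_mask = torch.zeros(1, 1, size, size, dtype=dtype, device=device)
+    box_mask[:, :, y0:y1, x0:x1] = 1.0
+    return box_mask
+
+
+mask = make_box_mask(300, 400, 700, 900)  # illustrative box around the object to remove
+```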
+
 # Perturbed-Attention Guidance
 
 [Project](https://ku-cvlab.github.io/Perturbed-Attention-Guidance/) / [arXiv](https://arxiv.org/abs/2403.17377) / [GitHub](https://github.com/KU-CVLAB/Perturbed-Attention-Guidance)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
new file mode 100644
index 000000000000..48c01318991a
--- /dev/null
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -0,0 +1,2285 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+from PIL import Image
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTextModelWithProjection,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    is_invisible_watermark_available,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from diffusers.utils.torch_utils import randn_tensor
+
+
+if is_invisible_watermark_available():
+    from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import StableDiffusionXLInpaintPipeline
+        >>> from diffusers.utils import load_image
+
+        >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+        ...     "stabilityai/stable-diffusion-xl-base-1.0",
+        ...     torch_dtype=torch.float16,
+        ...     variant="fp16",
+        ...     use_safetensors=True,
+        ... )
+        >>> pipe.to("cuda")
+
+        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+        >>> init_image = load_image(img_url).convert("RGB")
+        >>> mask_image = load_image(mask_url).convert("RGB")
+
+        >>> prompt = "A majestic tiger sitting on a bench"
+        >>> image = pipe(
+        ...     prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
+        ... ).images[0]
+        ```
+"""
+
+
+class AttentionBase:
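+    """
+    Minimal bookkeeping base for attention editors: tracks the current attention layer
+    (cur_att_layer) and denoising step (cur_step) so subclasses know where in the U-Net
+    and when in the schedule they are invoked.
+    """
+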
+    def __init__(self):
+        self.cur_step = 0
+        self.num_att_layers = -1
+        self.cur_att_layer = 0
+
+    def after_step(self):
+        pass
+
+    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+        self.cur_att_layer += 1
+        if self.cur_att_layer == self.num_att_layers:
+            self.cur_att_layer = 0
+            self.cur_step += 1
+            # after step
+            self.after_step()
+        return out
+
+    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+        out = torch.einsum("b i j, b j d -> b i d", attn, v)
+        out = rearrange(out, "(b h) n d -> b n (h d)", h=num_heads)
+        return out
+
+    def reset(self):
+        self.cur_step = 0
+        self.cur_att_layer = 0
+
+
+class AAS_XL(AttentionBase):
+    MODEL_TYPE = {"SD": 16, "SDXL": 70}
+
+    def __init__(
+        self,
+        start_step=4,
+        end_step=50,
+        start_layer=10,
+        end_layer=16,
+        layer_idx=None,
+        step_idx=None,
+        total_steps=50,
+        mask=None,
+        model_type="SD",
+        ss_steps=9,
+        ss_scale=1.0,
+    ):
+        """
+        Args:
+            start_step: the step to start AAS
+            start_layer: the layer to start AAS
+            layer_idx: list of the layers to apply AAS
+            step_idx: list the steps to apply AAS
+            total_steps: the total number of steps
+            mask: source mask with shape (h, w)
+            model_type: the model type, SD or SDXL
+        """
+        super().__init__()
+        self.total_steps = total_steps
+        self.total_layers = self.MODEL_TYPE.get(model_type, 16)
+        self.start_step = start_step
+        self.end_step = end_step
+        self.start_layer = start_layer
+        self.end_layer = end_layer
+        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, end_layer))
+        self.step_idx = step_idx if step_idx is not None else list(range(start_step, end_step))
+        self.mask = mask  # mask with shape (1, 1, h, w)
+        self.ss_steps = ss_steps
+        self.ss_scale = ss_scale
+        print("AAS at denoising steps: ", self.step_idx)
+        print("AAS at U-Net layers: ", self.layer_idx)
+        print("start AAS")
+        self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
+        self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
+        self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
+        self.mask_128 = F.max_pool2d(mask, (1024 // 128, 1024 // 128)).round().squeeze().squeeze()
+
+    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, is_mask_attn, mask, **kwargs):
+        B = q.shape[0] // num_heads
+        if is_mask_attn:
+            mask_flatten = mask.flatten(0)
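+            # positions inside the removal mask are filled with -inf similarity so that,
+            # after the softmax, queries can only attend to background (unmasked) tokens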
+            if self.cur_step <= self.ss_steps:
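+                # during the first ss_steps, foreground query similarities are additionally
+                # scaled by ss_scale (similarity suppression) before masking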
+                # background
+                sim_bg = sim + mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+
+                # object
+                sim_fg = self.ss_scale * sim
+                sim_fg += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+                sim = torch.cat([sim_fg, sim_bg], dim=0)
+            else:
+                sim += mask_flatten.masked_fill(mask_flatten == 1, torch.finfo(sim.dtype).min)
+
+        attn = sim.softmax(-1)
+        if len(attn) == 2 * len(v):
+            v = torch.cat([v] * 2)
+        out = torch.einsum("h i j, h j d -> h i d", attn, v)
+        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
+        return out
+
+    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+        """
+        Attention forward function
+        """
+        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
+            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+        H = int(np.sqrt(q.shape[1]))  # attention maps are square: seq_len == H * W with H == W
+        if H == 16:
+            mask = self.mask_16.to(sim.device)
+        elif H == 32:
+            mask = self.mask_32.to(sim.device)
+        elif H == 64:
+            mask = self.mask_64.to(sim.device)
+        else:
+            mask = self.mask_128.to(sim.device)
+
+        q_wo, q_w = q.chunk(2)
+        k_wo, k_w = k.chunk(2)
+        v_wo, v_w = v.chunk(2)
+        sim_wo, sim_w = sim.chunk(2)
+        attn_wo, attn_w = attn.chunk(2)
+
+        out_source = self.attn_batch(
+            q_wo,
+            k_wo,
+            v_wo,
+            sim_wo,
+            attn_wo,
+            is_cross,
+            place_in_unet,
+            num_heads,
+            is_mask_attn=False,
+            mask=None,
+            **kwargs,
+        )
+        out_target = self.attn_batch(
+            q_w, k_w, v_w, sim_w, attn_w, is_cross, place_in_unet, num_heads, is_mask_attn=True, mask=mask, **kwargs
+        )
+
+        if self.mask is not None and out_target.shape[0] == 2:
+            out_target_fg, out_target_bg = out_target.chunk(2, 0)
+            mask = mask.reshape(-1, 1)  # (hw, 1)
+            out_target = out_target_fg * mask + out_target_bg * (1 - mask)
+
+        out = torch.cat([out_source, out_target], dim=0)
+        return out
+
+
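+# A minimal usage sketch for `AAS_XL` (illustrative only, not executed here): the controller is
+# registered on the U-Net's self-attention layers and then consulted on every attention call.
+# `register_attention_control` is a hypothetical helper, named after the one in the official
+# AttentiveEraser repository:
+#
+#     mask = ...  # binary (1, 1, 1024, 1024) tensor, 1 = object to erase
+#     controller = AAS_XL(model_type="SDXL", mask=mask, ss_steps=9, ss_scale=1.0)
+#     register_attention_control(pipe, controller)
+
+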
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+def mask_pil_to_torch(mask, height, width):
+    # preprocess mask
+    if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+        mask = [mask]
+
+    if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+        mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+        mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+        mask = mask.astype(np.float32) / 255.0
+    elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+        mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+    mask = torch.from_numpy(mask)
+    return mask
+
+
+def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
+    """
+    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+    ``image`` and ``1`` for the ``mask``.
+
+    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+    Args:
+        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+        mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
+            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+    Raises:
+        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
+        ValueError: ``torch.Tensor`` mask should be in the ``[0, 1]`` range.
+        ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).
+
+    Returns:
+        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+            dimensions: ``batch x channels x height x width``.
+    """
+
+    # TODO(Yiyi) - need to clean this up later
+    deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
+    deprecate(
+        "prepare_mask_and_masked_image",
+        "0.30.0",
+        deprecation_message,
+    )
+    if image is None:
+        raise ValueError("`image` input cannot be undefined.")
+
+    if mask is None:
+        raise ValueError("`mask_image` input cannot be undefined.")
+
+    if isinstance(image, torch.Tensor):
+        if not isinstance(mask, torch.Tensor):
+            mask = mask_pil_to_torch(mask, height, width)
+
+        if image.ndim == 3:
+            image = image.unsqueeze(0)
+
+        # Batch and add channel dim for single mask
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0).unsqueeze(0)
+
+        # Batch single mask or add channel dim
+        if mask.ndim == 3:
+            # Single batched mask, no channel dim or single mask not batched but channel dim
+            if mask.shape[0] == 1:
+                mask = mask.unsqueeze(0)
+
+            # Batched masks no channel dim
+            else:
+                mask = mask.unsqueeze(1)
+
+        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+        # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+        # Check image is in [-1, 1]
+        # if image.min() < -1 or image.max() > 1:
+        #    raise ValueError("Image should be in [-1, 1] range")
+
+        # Check mask is in [0, 1]
+        if mask.min() < 0 or mask.max() > 1:
+            raise ValueError("Mask should be in [0, 1] range")
+
+        # Binarize mask
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+
+        # Image as float32
+        image = image.to(dtype=torch.float32)
+    elif isinstance(mask, torch.Tensor):
+        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
+    else:
+        # preprocess image
+        if isinstance(image, (PIL.Image.Image, np.ndarray)):
+            image = [image]
+        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+            # resize all images w.r.t. the passed height and width
+            image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+            image = [np.array(i.convert("RGB"))[None, :] for i in image]
+            image = np.concatenate(image, axis=0)
+        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+            image = np.concatenate([i[None, :] for i in image], axis=0)
+
+        image = image.transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+        mask = mask_pil_to_torch(mask, height, width)
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+
+    if image.shape[1] == 4:
+        # images are already in latent space and thus can't be masked;
+        # set masked_image to None. We assume the checkpoint is not an
+        # inpainting checkpoint. TODO(Yiyi) - need to clean this up later
+        masked_image = None
+    else:
+        masked_image = image * (mask < 0.5)
+
+    # n.b. ensure backwards compatibility as old function does not return image
+    if return_image:
+        return mask, masked_image, image
+
+    return mask, masked_image
+
+
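+# Illustrative example (shapes only, assuming a 512x512 PIL image and mask):
+#
+#     mask, masked_image = prepare_mask_and_masked_image(init_image, mask_image, 512, 512)
+#     # mask:         (1, 1, 512, 512) float32, binarized to {0, 1}
+#     # masked_image: (1, 3, 512, 512) float32 in [-1, 1], zeroed where mask >= 0.5
+
+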
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+    scheduler,
+    num_inference_steps: Optional[int] = None,
+    device: Optional[Union[str, torch.device]] = None,
+    timesteps: Optional[List[int]] = None,
+    **kwargs,
+):
+    """
+    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+    Args:
+        scheduler (`SchedulerMixin`):
+            The scheduler to get timesteps from.
+        num_inference_steps (`int`):
+            The number of diffusion steps used when generating samples with a pre-trained model. If used,
+            `timesteps` must be `None`.
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, the default
+            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+            must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
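+# Usage sketch (assuming an already-configured scheduler):
+#
+#     timesteps, num_inference_steps = retrieve_timesteps(pipe.scheduler, num_inference_steps=50, device="cuda")
+
+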
+class StableDiffusionXL_AE_Pipeline(
+    DiffusionPipeline,
+    StableDiffusionMixin,
+    TextualInversionLoaderMixin,
+    StableDiffusionXLLoraLoaderMixin,
+    FromSingleFileMixin,
+    IPAdapterMixin,
+):
+    r"""
+    Pipeline for object removal using Stable Diffusion XL.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion XL uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+            Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+            specifically the
+            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+            variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`CLIPTokenizer`):
+            Second Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        requires_aesthetics_score (`bool`, *optional*, defaults to `False`):
+            Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the
+            config of `stabilityai/stable-diffusion-xl-refiner-1-0`.
+        force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
+            Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
+            `stabilityai/stable-diffusion-xl-base-1-0`.
+        add_watermarker (`bool`, *optional*):
+            Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
+            watermark output images. If not defined, it will default to True if the package is installed, otherwise no
+            watermarker will be used.
+    """
+
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
+
+    _optional_components = [
+        "tokenizer",
+        "tokenizer_2",
+        "text_encoder",
+        "text_encoder_2",
+        "image_encoder",
+        "feature_extractor",
+    ]
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+        "negative_prompt_embeds",
+        "add_text_embeds",
+        "add_time_ids",
+        "negative_pooled_prompt_embeds",
+        "add_neg_time_ids",
+        "mask",
+        "masked_image_latents",
+    ]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        text_encoder_2: CLIPTextModelWithProjection,
+        tokenizer: CLIPTokenizer,
+        tokenizer_2: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: KarrasDiffusionSchedulers,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
+        requires_aesthetics_score: bool = False,
+        force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
+    ):
+        super().__init__()
+
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            text_encoder_2=text_encoder_2,
+            tokenizer=tokenizer,
+            tokenizer_2=tokenizer_2,
+            unet=unet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+            scheduler=scheduler,
+        )
+        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
+        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.mask_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+        )
+
+        add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
+
+        if add_watermarker:
+            self.watermark = StableDiffusionXLWatermarker()
+        else:
+            self.watermark = None
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        if output_hidden_states:
+            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_enc_hidden_states = self.image_encoder(
+                torch.zeros_like(image), output_hidden_states=True
+            ).hidden_states[-2]
+            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
+            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        else:
+            image_embeds = self.image_encoder(image).image_embeds
+            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            uncond_image_embeds = torch.zeros_like(image_embeds)
+
+            return image_embeds, uncond_image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            repeat_dims = [1]
+            image_embeds = []
+            for single_image_embeds in ip_adapter_image_embeds:
+                if do_classifier_free_guidance:
+                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                    single_negative_image_embeds = single_negative_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                    )
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                else:
+                    single_image_embeds = single_image_embeds.repeat(
+                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                    )
+                image_embeds.append(single_image_embeds)
+
+        return image_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt: str,
+        prompt_2: Optional[str] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+        negative_prompt: Optional[str] = None,
+        negative_prompt_2: Optional[str] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
+            device (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+        """
+        device = device or self._execution_device
+
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None:
+                if not USE_PEFT_BACKEND:
+                    adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+                else:
+                    scale_lora_layers(self.text_encoder, lora_scale)
+
+            if self.text_encoder_2 is not None:
+                if not USE_PEFT_BACKEND:
+                    adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+                else:
+                    scale_lora_layers(self.text_encoder_2, lora_scale)
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt is not None:
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        # Define tokenizers and text encoders
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+
+        if prompt_embeds is None:
+            prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+            # textual inversion: process multi-vector tokens if necessary
+            prompt_embeds_list = []
+            prompts = [prompt, prompt_2]
+            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+                if isinstance(self, TextualInversionLoaderMixin):
+                    prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+                text_inputs = tokenizer(
+                    prompt,
+                    padding="max_length",
+                    max_length=tokenizer.model_max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                text_input_ids = text_inputs.input_ids
+                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+                if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                    text_input_ids, untruncated_ids
+                ):
+                    removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                    logger.warning(
+                        "The following part of your input was truncated because CLIP can only handle sequences up to"
+                        f" {tokenizer.model_max_length} tokens: {removed_text}"
+                    )
+
+                prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
+
+                # We always use only the pooled output of the final text encoder
+                pooled_prompt_embeds = prompt_embeds[0]
+                if clip_skip is None:
+                    prompt_embeds = prompt_embeds.hidden_states[-2]
+                else:
+                    # "2" because SDXL always indexes from the penultimate layer.
+                    prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
+
+                prompt_embeds_list.append(prompt_embeds)
+
+            prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+
+        # get unconditional embeddings for classifier free guidance
+        zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
+        if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
+            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
+            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
+        elif do_classifier_free_guidance and negative_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+            # normalize str to list
+            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_2 = (
+                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+            )
+
+            uncond_tokens: List[str]
+            if prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = [negative_prompt, negative_prompt_2]
+
+            negative_prompt_embeds_list = []
+            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
+                if isinstance(self, TextualInversionLoaderMixin):
+                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
+
+                max_length = prompt_embeds.shape[1]
+                uncond_input = tokenizer(
+                    negative_prompt,
+                    padding="max_length",
+                    max_length=max_length,
+                    truncation=True,
+                    return_tensors="pt",
+                )
+
+                negative_prompt_embeds = text_encoder(
+                    uncond_input.input_ids.to(device),
+                    output_hidden_states=True,
+                )
+                # We always use only the pooled output of the final text encoder
+                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
+                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
+
+                negative_prompt_embeds_list.append(negative_prompt_embeds)
+
+            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
+
+        if self.text_encoder_2 is not None:
+            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+        else:
+            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            if self.text_encoder_2 is not None:
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+            else:
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+            bs_embed * num_images_per_prompt, -1
+        )
+        if do_classifier_free_guidance:
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+                bs_embed * num_images_per_prompt, -1
+            )
+
+        if self.text_encoder is not None:
+            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
+        if self.text_encoder_2 is not None:
+            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+    def prepare_extra_step_kwargs(self, generator, eta):
+        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+        # and should be between [0, 1]
+
+        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        extra_step_kwargs = {}
+        if accepts_eta:
+            extra_step_kwargs["eta"] = eta
+
+        # check if the scheduler accepts generator
+        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        if accepts_generator:
+            extra_step_kwargs["generator"] = generator
+        return extra_step_kwargs
+
+    def check_inputs(
+        self,
+        prompt,
+        prompt_2,
+        image,
+        mask_image,
+        height,
+        width,
+        strength,
+        callback_steps,
+        output_type,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        prompt_embeds=None,
+        negative_prompt_embeds=None,
+        ip_adapter_image=None,
+        ip_adapter_image_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+        padding_mask_crop=None,
+    ):
+        if strength < 0 or strength > 1:
+            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+        if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+            raise ValueError(
+                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+                f" {type(callback_steps)}."
+            )
+
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+        if padding_mask_crop is not None:
+            if not isinstance(image, PIL.Image.Image):
+                raise ValueError(
+                    f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
+                )
+            if not isinstance(mask_image, PIL.Image.Image):
+                raise ValueError(
+                    f"The mask image should be a PIL image when inpainting mask crop, but is of type {type(mask_image)}."
+                )
+            if output_type != "pil":
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
+
+        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+            raise ValueError(
+                "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`, not both."
+            )
+
+        if ip_adapter_image_embeds is not None:
+            if not isinstance(ip_adapter_image_embeds, list):
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+                )
+            elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+                raise ValueError(
+                    f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+                )
+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+        image=None,
+        timestep=None,
+        is_strength_max=True,
+        add_noise=True,
+        return_noise=False,
+        return_image_latents=False,
+    ):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if (image is None or timestep is None) and not is_strength_max:
+            raise ValueError(
+                "Since strength < 1, the initial latents need to be initialised as a combination of image + noise."
+                " However, either the image or the noise timestep has not been provided."
+            )
+
+        if image.shape[1] == 4:
+            image_latents = image.to(device=device, dtype=dtype)
+            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+        elif return_image_latents or (latents is None and not is_strength_max):
+            image = image.to(device=device, dtype=dtype)
+            image_latents = self._encode_vae_image(image=image, generator=generator)
+            image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
+
+        if latents is None and add_noise:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            # if strength is 1, initialise the latents to pure noise, else to image + noise
+            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
+            # if pure noise, scale the initial latents by the scheduler's init sigma
+            latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
+        elif add_noise:
+            noise = latents.to(device)
+            latents = noise * self.scheduler.init_noise_sigma
+        else:
+            noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = image_latents.to(device)
+
+        outputs = (latents,)
+
+        if return_noise:
+            outputs += (noise,)
+
+        if return_image_latents:
+            outputs += (image_latents,)
+
+        return outputs
+
+    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+        dtype = image.dtype
+        if self.vae.config.force_upcast:
+            image = image.float()
+            self.vae.to(dtype=torch.float32)
+
+        if isinstance(generator, list):
+            image_latents = [
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(image.shape[0])
+            ]
+            image_latents = torch.cat(image_latents, dim=0)
+        else:
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
+
+        if self.vae.config.force_upcast:
+            self.vae.to(dtype)
+
+        image_latents = image_latents.to(dtype)
+        image_latents = self.vae.config.scaling_factor * image_latents
+
+        return image_latents
+
+    def prepare_mask_latents(
+        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
+    ):
+        # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
+        # Unlike the standard inpainting pipeline, which bilinearly interpolates the mask down to
+        # the latent resolution, max-pool it so that every masked pixel stays covered, then round
+        # back to a binary mask.
+        mask = torch.nn.functional.max_pool2d(mask, (8, 8)).round()
+        mask = mask.to(device=device, dtype=dtype)
+
+        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    " of masks that you pass is divisible by the total requested batch size."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+
+        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+
+        if masked_image is not None and masked_image.shape[1] == 4:
+            masked_image_latents = masked_image
+        else:
+            masked_image_latents = None
+
+        if masked_image is not None:
+            if masked_image_latents is None:
+                masked_image = masked_image.to(device=device, dtype=dtype)
+                masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+            if masked_image_latents.shape[0] < batch_size:
+                if not batch_size % masked_image_latents.shape[0] == 0:
+                    raise ValueError(
+                        "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                        f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                        " Make sure the number of images that you pass is divisible by the total requested batch size."
+                    )
+                masked_image_latents = masked_image_latents.repeat(
+                    batch_size // masked_image_latents.shape[0], 1, 1, 1
+                )
+
+            masked_image_latents = (
+                torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
+            )
+
+            # aligning device to prevent device errors when concatenating it with the latent model input
+            masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+
+        return mask, masked_image_latents
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
+    def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
+        # get the original timestep using init_timestep
+        if denoising_start is None:
+            init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+            t_start = max(num_inference_steps - init_timestep, 0)
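+            # e.g. num_inference_steps=50, strength=0.8 -> init_timestep=40, t_start=10,
+            # so denoising skips the first 10 steps and runs the remaining 40.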
+        else:
+            t_start = 0
+
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+        # Strength is irrelevant if we directly request a timestep to start at;
+        # that is, strength is determined by the denoising_start instead.
+        if denoising_start is not None:
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_start * self.scheduler.config.num_train_timesteps)
+                )
+            )
+
+            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+            if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
+                # if the scheduler is a 2nd order scheduler we might have to do +1
+                # because `num_inference_steps` might be even given that every timestep
+                # (except the highest one) is duplicated. If `num_inference_steps` is even it would
+                # mean that we cut the timesteps in the middle of the denoising step
+                # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
+                # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
+                num_inference_steps = num_inference_steps + 1
+
+            # because t_n+1 >= t_n, we slice the timesteps starting from the end
+            timesteps = timesteps[-num_inference_steps:]
+            return timesteps, num_inference_steps
+
+        return timesteps, num_inference_steps - t_start
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
+    def _get_add_time_ids(
+        self,
+        original_size,
+        crops_coords_top_left,
+        target_size,
+        aesthetic_score,
+        negative_aesthetic_score,
+        negative_original_size,
+        negative_crops_coords_top_left,
+        negative_target_size,
+        dtype,
+        text_encoder_projection_dim=None,
+    ):
+        if self.config.requires_aesthetics_score:
+            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
+            add_neg_time_ids = list(
+                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
+            )
+        else:
+            add_time_ids = list(original_size + crops_coords_top_left + target_size)
+            add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
+
+        passed_add_embed_dim = (
+            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
+        )
+        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
+
+        if (
+            expected_add_embed_dim > passed_add_embed_dim
+            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
+        ):
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
+            )
+        elif (
+            expected_add_embed_dim < passed_add_embed_dim
+            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
+        ):
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
+            )
+        elif expected_add_embed_dim != passed_add_embed_dim:
+            raise ValueError(
+                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
+            )
+
+        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
+        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
+
+        return add_time_ids, add_neg_time_ids
+
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
+    def upcast_vae(self):
+        dtype = self.vae.dtype
+        self.vae.to(dtype=torch.float32)
+        use_torch_2_0_or_xformers = isinstance(
+            self.vae.decoder.mid_block.attentions[0].processor,
+            (
+                AttnProcessor2_0,
+                XFormersAttnProcessor,
+                LoRAXFormersAttnProcessor,
+                LoRAAttnProcessor2_0,
+            ),
+        )
+        # if xformers or torch_2_0 is used attention block does not need
+        # to be in float32 which can save lots of memory
+        if use_torch_2_0_or_xformers:
+            self.vae.post_quant_conv.to(dtype)
+            self.vae.decoder.conv_in.to(dtype)
+            self.vae.decoder.mid_block.to(dtype)
+
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                Guidance scales with shape `(batch_size,)` to generate embedding vectors for.
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    @property
+    def do_self_attention_redirection_guidance(self):  # SARG
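+        # Self-Attention Redirection Guidance (SARG) is active when the removal guidance
+        # scale exceeds 1 and an AAS attention controller has been set on the pipeline.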
+        return self._rm_guidance_scale > 1 and self._AAS
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        # CFG is disabled while SARG is active; experiments showed little difference in the
+        # result whether or not CFG is applied on top of SARG.
+        return (
+            self._guidance_scale > 1
+            and self.unet.config.time_cond_proj_dim is None
+            and not self.do_self_attention_redirection_guidance
+        )
+
+    @property
+    def cross_attention_kwargs(self):
+        return self._cross_attention_kwargs
+
+    @property
+    def denoising_end(self):
+        return self._denoising_end
+
+    @property
+    def denoising_start(self):
+        return self._denoising_start
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
+    @torch.no_grad()
+    def image2latent(self, image: torch.Tensor, generator: torch.Generator):
+        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        if isinstance(image, Image.Image):
+            image = np.array(image)
+            image = torch.from_numpy(image).float() / 127.5 - 1
+            image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
+        # input image is expected in the range [-1, 1]
+        latents = self._encode_vae_image(image, generator)
+        return latents
+
+    def next_step(self, model_output: torch.FloatTensor, timestep: int, x: torch.FloatTensor, eta=0.0, verbose=False):
+        """
+        Inverse sampling for DDIM Inversion
+        """
+        if verbose:
+            print("timestep: ", timestep)
+        next_step = timestep
+        timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
+        beta_prod_t = 1 - alpha_prod_t
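+        # DDIM inversion (eta = 0): recover the predicted x0 from the noise estimate,
+        # then step deterministically to the next, noisier timestep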
+        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+        pred_dir = (1 - alpha_prod_t_next) ** 0.5 * model_output
+        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
+        return x_next, pred_x0
+
+    @torch.no_grad()
+    def invert(
+        self,
+        image: torch.Tensor,
+        prompt,
+        num_inference_steps=50,
+        eta=0.0,
+        original_size: Tuple[int, int] = None,
+        target_size: Tuple[int, int] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+        aesthetic_score: float = 6.0,
+        negative_aesthetic_score: float = 2.5,
+        return_intermediates=False,
+        **kwds,
+    ):
+        """
+        Invert a real image into a noise map with deterministic DDIM inversion
+        """
+        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        batch_size = image.shape[0]
+        if isinstance(prompt, list):
+            if batch_size == 1:
+                image = image.expand(len(prompt), -1, -1, -1)
+        elif isinstance(prompt, str):
+            if batch_size > 1:
+                prompt = [prompt] * batch_size
+
+        # Define tokenizers and text encoders
+        tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
+        text_encoders = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )
+
+        prompt_2 = prompt
+        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+        # textual inversion: process multi-vector tokens if necessary
+        prompt_embeds_list = []
+        prompts = [prompt, prompt_2]
+        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
+            if isinstance(self, TextualInversionLoaderMixin):
+                prompt = self.maybe_convert_prompt(prompt, tokenizer)
+
+            text_inputs = tokenizer(
+                prompt,
+                padding="max_length",
+                max_length=tokenizer.model_max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            text_input_ids = text_inputs.input_ids
+            untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+                text_input_ids, untruncated_ids
+            ):
+                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
+                logger.warning(
+                    "The following part of your input was truncated because CLIP can only handle sequences up to"
+                    f" {tokenizer.model_max_length} tokens: {removed_text}"
+                )
+
+            prompt_embeds = text_encoder(text_input_ids.to(DEVICE), output_hidden_states=True)
+
+            # We are only interested in the pooled output of the final text encoder
+            pooled_prompt_embeds = prompt_embeds[0]
+            prompt_embeds = prompt_embeds.hidden_states[-2]
+            prompt_embeds_list.append(prompt_embeds)
+        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
+        prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=DEVICE)
+
+        # define initial latents
+        latents = self.image2latent(image, generator=None)
+
+        start_latents = latents
+        height, width = latents.shape[-2:]
+        height = height * self.vae_scale_factor
+        width = width * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+        negative_original_size = original_size
+        negative_target_size = target_size
+
+        add_text_embeds = pooled_prompt_embeds
+        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            aesthetic_score,
+            negative_aesthetic_score,
+            negative_original_size,
+            negative_crops_coords_top_left,
+            negative_target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+
+        add_time_ids = add_time_ids.repeat(batch_size, 1).to(DEVICE)
+
+        # iterative sampling
+        self.scheduler.set_timesteps(num_inference_steps)
+        latents_list = [latents]
+        pred_x0_list = []
+        for i, t in enumerate(reversed(self.scheduler.timesteps)):
+            model_inputs = latents
+
+            # predict the noise
+            added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+            noise_pred = self.unet(
+                model_inputs, t, encoder_hidden_states=prompt_embeds, added_cond_kwargs=added_cond_kwargs
+            ).sample
+
+            # step the latent forward along the inversion trajectory: x_t -> x_t+1
+            latents, pred_x0 = self.next_step(noise_pred, t, latents)
+            """
+            if t >= 1 and t < 41:
+                latents, pred_x0 = self.next_step_degrade(noise_pred, t, latents, mask)
+            else:
+                latents, pred_x0 = self.next_step(noise_pred, t, latents) """
+
+            latents_list.append(latents)
+            pred_x0_list.append(pred_x0)
+
+        if return_intermediates:
+            # return the intermediate latents recorded during inversion
+            return latents, latents_list, pred_x0_list
+        return latents, start_latents
+
+    def opt(
+        self,
+        model_output: torch.FloatTensor,
+        timestep: int,
+        x: torch.FloatTensor,
+    ):
+        """
+        Predict the sample at the next step in the denoising process.
+        """
+        ref_noise = model_output[:1, :, :, :].expand(model_output.shape)
+        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+        beta_prod_t = 1 - alpha_prod_t
+        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
+        x_opt = alpha_prod_t**0.5 * pred_x0 + (1 - alpha_prod_t) ** 0.5 * ref_noise
+        return x_opt, pred_x0
+
+    def register_attention_editor_diffusers(self, unet, editor: AttentionBase):
+        """
+        Register an attention editor with a diffusers UNet; adapted from [Prompt-to-Prompt]
+        """
+
+        def ca_forward(self, place_in_unet):
+            def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
+                """
+                The attention is similar to the original implementation of LDM CrossAttention class
+                except adding some modifications on the attention
+                """
+                if encoder_hidden_states is not None:
+                    context = encoder_hidden_states
+                if attention_mask is not None:
+                    mask = attention_mask
+
+                to_out = self.to_out[0] if isinstance(self.to_out, nn.ModuleList) else self.to_out
+
+                h = self.heads
+                q = self.to_q(x)
+                is_cross = context is not None
+                context = context if is_cross else x
+                k = self.to_k(context)
+                v = self.to_v(context)
+                q, k, v = (rearrange(t, "b n (h d) -> (b h) n d", h=h) for t in (q, k, v))
+
+                sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
+
+                if mask is not None:
+                    mask = rearrange(mask, "b ... -> b (...)")
+                    max_neg_value = -torch.finfo(sim.dtype).max
+                    mask = repeat(mask, "b j -> (b h) () j", h=h)
+                    sim.masked_fill_(~mask, max_neg_value)
+
+                attn = sim.softmax(dim=-1)
+                # the only change from the stock forward: route the attention through the editor
+                out = editor(q, k, v, sim, attn, is_cross, place_in_unet, self.heads, scale=self.scale)
+
+                return to_out(out)
+
+            return forward
+
+        def register_editor(net, count, place_in_unet):
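+            # Recursively walk the module tree and patch the forward of every
+            # `Attention` module so the editor can intercept its attention maps.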
+            for name, subnet in net.named_children():
+                if net.__class__.__name__ == "Attention":  # spatial Transformer layer
+                    net.forward = ca_forward(net, place_in_unet)
+                    return count + 1
+                elif hasattr(net, "children"):
+                    count = register_editor(subnet, count, place_in_unet)
+            return count
+
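+        # Register the editor on the down, mid, and up blocks separately so each
+        # patched layer knows which part of the UNet it belongs to.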
+        cross_att_count = 0
+        for net_name, net in unet.named_children():
+            if "down" in net_name:
+                cross_att_count += register_editor(net, 0, "down")
+            elif "mid" in net_name:
+                cross_att_count += register_editor(net, 0, "mid")
+            elif "up" in net_name:
+                cross_att_count += register_editor(net, 0, "up")
+        editor.num_att_layers = cross_att_count
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        image: PipelineImageInput = None,
+        mask_image: PipelineImageInput = None,
+        masked_image_latents: torch.FloatTensor = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        padding_mask_crop: Optional[int] = None,
+        strength: float = 0.9999,
+        AAS: bool = True,  # AE parameter
+        rm_guidance_scale: float = 7.0,  # AE parameter
+        ss_steps: int = 9,  # AE parameter
+        ss_scale: float = 0.3,  # AE parameter
+        AAS_start_step: int = 0,  # AE parameter
+        AAS_start_layer: int = 34,  # AE parameter
+        AAS_end_layer: int = 70,  # AE parameter
+        num_inference_steps: int = 50,
+        timesteps: List[int] = None,
+        denoising_start: Optional[float] = None,
+        denoising_end: Optional[float] = None,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
+        original_size: Tuple[int, int] = None,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        target_size: Tuple[int, int] = None,
+        negative_original_size: Optional[Tuple[int, int]] = None,
+        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
+        negative_target_size: Optional[Tuple[int, int]] = None,
+        aesthetic_score: float = 6.0,
+        negative_aesthetic_score: float = 2.5,
+        clip_skip: Optional[int] = None,
+        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+        **kwargs,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
+            image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+                be masked out with `mask_image` and repainted according to `prompt`.
+            mask_image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+                instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image. This is set to 1024 by default for the best results.
+                Anything below 512 pixels won't work well for
+                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+                and checkpoints that are not specifically fine-tuned on low resolutions.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image. This is set to 1024 by default for the best results.
+                Anything below 512 pixels won't work well for
+                [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
+                and checkpoints that are not specifically fine-tuned on low resolutions.
+            padding_mask_crop (`int`, *optional*, defaults to `None`):
+                The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
+                image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
+                with the same aspect ratio as the image that contains all masked areas, and then expand that area based
+                on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
+                resizing to the original image size for inpainting. This is useful when the masked area is small while
+                the image is large and contains information irrelevant to inpainting, such as background.
+            strength (`float`, *optional*, defaults to 0.9999):
+                Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
+                between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
+                `strength`. The number of denoising steps depends on the amount of noise initially added. When
+                `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
+                iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
+                portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
+                integer, the value of `strength` will be ignored.
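+            AAS (`bool`, *optional*, defaults to `True`):
+                Attentive Eraser parameter. Whether to apply Attention Activation and Suppression (AAS) to the
+                self-attention layers during denoising; required for self-attention redirection guidance (SARG). See
+                the [paper](https://arxiv.org/abs/2412.12974) for details.
+            rm_guidance_scale (`float`, *optional*, defaults to 7.0):
+                Attentive Eraser parameter. Guidance scale for SARG. SARG is active when `rm_guidance_scale > 1` and
+                `AAS` is enabled, in which case classifier-free guidance is disabled.
+            ss_steps (`int`, *optional*, defaults to 9):
+                Attentive Eraser parameter. Number of initial denoising steps during which similarity suppression is
+                applied.
+            ss_scale (`float`, *optional*, defaults to 0.3):
+                Attentive Eraser parameter. Strength of the similarity suppression.
+            AAS_start_step (`int`, *optional*, defaults to 0):
+                Attentive Eraser parameter. Denoising step at which AAS begins; it ends at the effective number of
+                denoising steps implied by `strength`.
+            AAS_start_layer (`int`, *optional*, defaults to 34):
+                Attentive Eraser parameter. First self-attention layer index to which AAS is applied.
+            AAS_end_layer (`int`, *optional*, defaults to 70):
+                Attentive Eraser parameter. Last self-attention layer index to which AAS is applied.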
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            timesteps (`List[int]`, *optional*):
+                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+                passed will be used. Must be in descending order.
+            denoising_start (`float`, *optional*):
+                When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
+                bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
+                it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
+                strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
+                is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+            denoising_end (`float`, *optional*):
+                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+                completed before it is intentionally prematurely terminated. As a result, the returned sample will
+                still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
+                denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
+                final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
+                forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+                less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
+                input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
+                of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
+                contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
+                explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
+                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a specific image resolution. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
+                To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
+                micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
+                To negatively condition the generation process based on a target image resolution. It should be the same
+                as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
+                information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
+            aesthetic_score (`float`, *optional*, defaults to 6.0):
+                Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
+                Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
+                Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
+                simulate an aesthetic score of the generated image by influencing the negative text condition.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
+            )
+
+        # 0. Default height and width to unet
+        height = height or self.unet.config.sample_size * self.vae_scale_factor
+        width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+        # 1. Check inputs
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            image,
+            mask_image,
+            height,
+            width,
+            strength,
+            callback_steps,
+            output_type,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            ip_adapter_image,
+            ip_adapter_image_embeds,
+            callback_on_step_end_tensor_inputs,
+            padding_mask_crop,
+        )
+
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._clip_skip = clip_skip
+        self._cross_attention_kwargs = cross_attention_kwargs
+        self._denoising_end = denoising_end
+        self._denoising_start = denoising_start
+        self._interrupt = False
+
+        # Attentive Eraser (AE) parameters
+        self._num_timesteps = num_inference_steps
+        self._rm_guidance_scale = rm_guidance_scale
+        self._AAS = AAS
+        self._ss_steps = ss_steps
+        self._ss_scale = ss_scale
+        self._AAS_start_step = AAS_start_step
+        self._AAS_start_layer = AAS_start_layer
+        self._AAS_end_layer = AAS_end_layer
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+        )
+
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+            clip_skip=self.clip_skip,
+        )
+
+        # 4. set timesteps
+        def denoising_value_valid(dnv):
+            return isinstance(dnv, float) and 0 < dnv < 1
+
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps,
+            strength,
+            device,
+            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
+        )
+        # check that number of inference steps is not < 1 - as this doesn't make sense
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
+        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+        is_strength_max = strength == 1.0
+
+        # 5. Preprocess mask and image
+        if padding_mask_crop is not None:
+            crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
+        original_image = image
+        init_image = self.image_processor.preprocess(
+            image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
+        )
+        init_image = init_image.to(dtype=torch.float32)
+
+        mask = self.mask_processor.preprocess(
+            mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
+        )
+
+        if masked_image_latents is not None:
+            masked_image = masked_image_latents
+        elif init_image.shape[1] == 4:
+            # if the image is already in latent space, it cannot be masked
+            masked_image = None
+        else:
+            masked_image = init_image * (mask < 0.5)
+
+        # 6. Prepare latent variables
+        num_channels_latents = self.vae.config.latent_channels
+        num_channels_unet = self.unet.config.in_channels
+        return_image_latents = num_channels_unet == 4
+
+        add_noise = self.denoising_start is None
+        latents_outputs = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+            image=init_image,
+            timestep=latent_timestep,
+            is_strength_max=is_strength_max,
+            add_noise=add_noise,
+            return_noise=True,
+            return_image_latents=return_image_latents,
+        )
+
+        if return_image_latents:
+            latents, noise, image_latents = latents_outputs
+        else:
+            latents, noise = latents_outputs
+
+        # 7. Prepare mask latent variables
+        mask, masked_image_latents = self.prepare_mask_latents(
+            mask,
+            masked_image,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            self.do_classifier_free_guidance,
+        )
+
+        # 8. Check that sizes of mask, masked image and latents match
+        if num_channels_unet == 9:
+            # default case for runwayml/stable-diffusion-inpainting
+            num_channels_mask = mask.shape[1]
+            num_channels_masked_image = masked_image_latents.shape[1]
+            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+                raise ValueError(
+                    f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+                    f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+                    f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                    " `pipeline.unet` or your `mask_image` or `image` input."
+                )
+        elif num_channels_unet != 4:
+            raise ValueError(
+                f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+            )
+        # 8.1 Prepare extra step kwargs.
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        height, width = latents.shape[-2:]
+        height = height * self.vae_scale_factor
+        width = width * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        # 10. Prepare added time ids & embeddings
+        if negative_original_size is None:
+            negative_original_size = original_size
+        if negative_target_size is None:
+            negative_target_size = target_size
+
+        add_text_embeds = pooled_prompt_embeds
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            aesthetic_score,
+            negative_aesthetic_score,
+            negative_original_size,
+            negative_crops_coords_top_left,
+            negative_target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
+        if self.do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+        # For SARG, duplicate the conditional embeddings so the batch carries both the
+        # unmodified branch and the AAS-redirected branch.
+        if self.do_self_attention_redirection_guidance:
+            prompt_embeds = torch.cat([prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([add_text_embeds, add_text_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(2, 1)
+            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device)
+
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+        # apply AAS to modify the attention module
+        if self.do_self_attention_redirection_guidance:
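+            # AAS is applied from `AAS_start_step` up to the effective number of
+            # denoising steps implied by `strength`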
+            self._AAS_end_step = int(strength * self._num_timesteps)
+            layer_idx = list(range(self._AAS_start_layer, self._AAS_end_layer))
+            editor = AAS_XL(
+                self._AAS_start_step,
+                self._AAS_end_step,
+                self._AAS_start_layer,
+                self._AAS_end_layer,
+                layer_idx=layer_idx,
+                mask=mask_image,
+                model_type="SDXL",
+                ss_steps=self._ss_steps,
+                ss_scale=self._ss_scale,
+            )
+            self.register_attention_editor_diffusers(self.unet, editor)
+
+        # 11. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        if (
+            self.denoising_end is not None
+            and self.denoising_start is not None
+            and denoising_value_valid(self.denoising_end)
+            and denoising_value_valid(self.denoising_start)
+            and self.denoising_start >= self.denoising_end
+        ):
+            raise ValueError(
+                f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
+                + f" {self.denoising_end} when using type float."
+            )
+        elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (self.denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
+        # 11.1 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
+        self._num_timesteps = len(timesteps)
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+
+                # expand the latents if we are doing classifier free guidance or
+                # self-attention redirection guidance (the two are mutually exclusive:
+                # CFG is disabled when SARG is active, and experiments showed little
+                # difference with or without CFG in that case)
+                latent_model_input = (
+                    torch.cat([latents] * 2)
+                    if self.do_classifier_free_guidance or self.do_self_attention_redirection_guidance
+                    else latents
+                )
+
+                # concat latents, mask, masked_image_latents in the channel dimension
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                if num_channels_unet == 9:
+                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+                # predict the noise residual
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                )[0]
+
+                # perform SARG
+                if self.do_self_attention_redirection_guidance:
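+                    # split into the prediction without (`_wo`) and with (`_w`) the
+                    # AAS-redirected self-attention, then extrapolate toward the latter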
+                    noise_pred_wo, noise_pred_w = noise_pred.chunk(2)
+                    delta = noise_pred_w - noise_pred_wo
+                    noise_pred = noise_pred_wo + self._rm_guidance_scale * delta
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if num_channels_unet == 4:
+                    init_latents_proper = image_latents
+                    if self.do_classifier_free_guidance:
+                        init_mask, _ = mask.chunk(2)
+                    else:
+                        init_mask = mask
+
+                    if i < len(timesteps) - 1:
+                        noise_timestep = timesteps[i + 1]
+                        init_latents_proper = self.scheduler.add_noise(
+                            init_latents_proper, noise, torch.tensor([noise_timestep])
+                        )
+
+                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
+                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+                    mask = callback_outputs.pop("mask", mask)
+                    masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
+        if not output_type == "latent":
+            # make sure the VAE is in float32 mode, as it overflows in float16
+            needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
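+            # decode only the last sample in the latent batch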
+            latents = latents[-1:]
+
+            if needs_upcasting:
+                self.upcast_vae()
+                latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+            # unscale/denormalize the latents
+            # denormalize with the mean and std if available and not None
+            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+            if has_latents_mean and has_latents_std:
+                latents_mean = (
+                    torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents_std = (
+                    torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
+                )
+                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+            else:
+                latents = latents / self.vae.config.scaling_factor
+
+            image = self.vae.decode(latents, return_dict=False)[0]
+
+            # cast back to fp16 if needed
+            if needs_upcasting:
+                self.vae.to(dtype=torch.float16)
+        else:
+            return StableDiffusionXLPipelineOutput(images=latents)
+
+        # apply watermark if available
+        if self.watermark is not None:
+            image = self.watermark.apply_watermark(image)
+
+        image = self.image_processor.postprocess(image, output_type=output_type)
+
+        if padding_mask_crop is not None:
+            image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionXLPipelineOutput(images=image)
From 1b47ff7e00d77822cdda353c09f60170de3805d7 Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Thu, 16 Jan 2025 22:00:09 +0800
Subject: [PATCH 4/7] update Docs
Co-authored-by: Other Contributor 
---
 examples/community/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/community/README.md b/examples/community/README.md
index 02777636d2d3..a62d49f0d2c0 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -77,7 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
 PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
 | HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
 | [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
-| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3)|
+| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
 
 To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
 
From 72ee35e4be022d1f8f287480ab31bfe09022b67a Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Sat, 18 Jan 2025 19:21:43 +0800
Subject: [PATCH 5/7] add Oral
Co-authored-by: Other Contributor 
---
 examples/community/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/examples/community/README.md b/examples/community/README.md
index a62d49f0d2c0..4c593a004893 100644
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -77,7 +77,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
 PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
 | HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
 | [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [](https://huggingface.co/spaces/pcuenq/mdm) [](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
-| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
+| Stable Diffusion XL Attentive Eraser Pipeline |[[AAAI2025 Oral] Attentive Eraser](https://github.com/Anonym0u3/AttentiveEraser) is a novel tuning-free method that enhances object removal capabilities in pre-trained diffusion models.|[Stable Diffusion XL Attentive Eraser Pipeline](#stable-diffusion-xl-attentive-eraser-pipeline)|-|[Wenhao Sun](https://github.com/Anonym0u3) and [Benlei Cui](https://github.com/Benny079)|
 
 To load a custom pipeline you just need to pass the `custom_pipeline` argument to `DiffusionPipeline`, as one of the files in `diffusers/examples/community`. Feel free to send a PR with your own pipelines, we will merge them quickly.
 
From e5fc2e40e43bd132b70d1f236dc0c9cfbd6fd2bc Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Fri, 24 Jan 2025 20:46:31 +0800
Subject: [PATCH 6/7] update_review
Co-authored-by: Other Contributor 
---
 ...ne_stable_diffusion_xl_attentive_eraser.py | 96 ++++++++++++-------
 1 file changed, 64 insertions(+), 32 deletions(-)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
index 48c01318991a..f78cedb9965a 100644
--- a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -81,27 +81,72 @@
     Examples:
         ```py
         >>> import torch
-        >>> from diffusers import StableDiffusionXLInpaintPipeline
+        >>> from diffusers import DDIMScheduler, DiffusionPipeline
         >>> from diffusers.utils import load_image
-
-        >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
-        ...     "stabilityai/stable-diffusion-xl-base-1.0",
-        ...     torch_dtype=torch.float16,
-        ...     variant="fp16",
-        ...     use_safetensors=True,
-        ... )
-        >>> pipe.to("cuda")
-
-        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
-        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
-
-        >>> init_image = load_image(img_url).convert("RGB")
-        >>> mask_image = load_image(mask_url).convert("RGB")
-
-        >>> prompt = "A majestic tiger sitting on a bench"
-        >>> image = pipe(
-        ...     prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
+        >>> import torch.nn.functional as F
+        >>> from torchvision.transforms.functional import to_tensor, gaussian_blur
+
+        >>> dtype = torch.float16
+        >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 
+        >>> scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
+
+        >>> pipeline = DiffusionPipeline.from_pretrained(
+        ...    "stabilityai/stable-diffusion-xl-base-1.0",
+        ...    custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
+        ...    scheduler=scheduler,
+        ...    variant="fp16",
+        ...    use_safetensors=True,
+        ...    torch_dtype=dtype,
+        ... ).to(device)
+
+
+        >>> def preprocess_image(image_path, device):
+        ...     image = to_tensor((load_image(image_path)))
+        ...     image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
+        ...     if image.shape[1] != 3:
+        ...         image = image.expand(-1, 3, -1, -1)
+        ...     image = F.interpolate(image, (1024, 1024))
+        ...     image = image.to(dtype).to(device)
+        ...     return image
+
+        >>> def preprocess_mask(mask_path, device):
+        ...     mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
+        ...     mask = mask.unsqueeze_(0).float()  # 0 or 1
+        ...     mask = F.interpolate(mask, (1024, 1024))
+        ...     mask = gaussian_blur(mask, kernel_size=(77, 77))
+        ...     mask[mask < 0.1] = 0
+        ...     mask[mask >= 0.1] = 1
+        ...     mask = mask.to(dtype).to(device)
+        ...     return mask
+
+        >>> prompt = "" # Set prompt to null
+        >>> seed=123 
+        >>> generator = torch.Generator(device=device).manual_seed(seed)
+        >>> source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
+        >>> mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
+        >>> source_image = preprocess_image(source_image_path, device)
+        >>> mask = preprocess_mask(mask_path, device)
+
+        >>> image = pipeline(
+        ...     prompt=prompt, 
+        ...     image=source_image,
+        ...     mask_image=mask,
+        ...     height=1024,
+        ...     width=1024,
+        ...     AAS=True, # enable AAS
+        ...     strength=0.8, # inpainting strength
+        ...     rm_guidance_scale=9, # removal guidance scale
+        ...     ss_steps=9, # similarity suppression steps
+        ...     ss_scale=0.3, # similarity suppression scale
+        ...     AAS_start_step=0, # AAS start step
+        ...     AAS_start_layer=34, # AAS start layer
+        ...     AAS_end_layer=70, # AAS end layer
+        ...     num_inference_steps=50, # number of inference steps; AAS ends at step int(strength * num_inference_steps)
+        ...     generator=generator,
+        ...     guidance_scale=1,
         ... ).images[0]
+        >>> image.save('./removed_img.png')
+        >>> print("Object removal completed")
         ```
 """
 
@@ -174,9 +219,6 @@ def __init__(
         self.mask = mask  # mask with shape (1, 1 ,h, w)
         self.ss_steps = ss_steps
         self.ss_scale = ss_scale
-        print("AAS at denoising steps: ", self.step_idx)
-        print("AAS at U-Net layers: ", self.layer_idx)
-        print("start AAS")
         self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
         self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
         self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
@@ -209,10 +251,7 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar
         Attention forward function
         """
         if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
-            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
-        # B = q.shape[0] // num_heads // 2
         H = int(np.sqrt(q.shape[1]))
-        # H = W = int(np.sqrt(q.shape[1]))
         if H == 16:
             mask = self.mask_16.to(sim.device)
         elif H == 32:
@@ -317,13 +356,6 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
             dimensions: ``batch x channels x height x width``.
     """
 
-    # checkpoint. TOD(Yiyi) - need to clean this up later
-    deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
-    deprecate(
-        "prepare_mask_and_masked_image",
-        "0.30.0",
-        deprecation_message,
-    )
     if image is None:
         raise ValueError("`image` input cannot be undefined.")
 
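The hunks above contain the core bookkeeping of this patch: `__init__` max-pools the 1024x1024 removal mask down to each self-attention resolution (16, 32, and 64 tokens per side), and `forward` recovers the spatial side length from the flattened query to select the matching mask. A self-contained sketch of that logic, with illustrative names rather than the pipeline's actual class:

```py
import math

import torch
import torch.nn.functional as F

# Illustrative sketch (not the pipeline's actual class) of the mask bookkeeping
# in the hunks above: a binary 1x1x1024x1024 removal mask is max-pooled once per
# self-attention resolution, then looked up from the flattened token count.
mask = torch.zeros(1, 1, 1024, 1024)
mask[:, :, 300:700, 400:800] = 1.0  # hypothetical removal region

# Kernel size 1024 // side collapses the mask to side x side, as in the
# `__init__` hunk; max-pooling keeps a token active if any pixel under it
# is masked, so coverage of the object errs on the generous side.
pooled = {
    side: F.max_pool2d(mask, (1024 // side, 1024 // side)).round().squeeze()
    for side in (16, 32, 64)
}

def mask_for_query(q: torch.Tensor) -> torch.Tensor:
    # q is (batch * heads, tokens, dim); for square attention maps the side
    # length is sqrt(tokens), mirroring `H = int(np.sqrt(q.shape[1]))` above.
    side = int(math.sqrt(q.shape[1]))
    return pooled[side]

q = torch.randn(16, 32 * 32, 64)  # e.g. queries of a 32x32 self-attention map
print(mask_for_query(q).shape)    # torch.Size([32, 32])
```
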
From bb52aeb717c7dd8d85e249fbf57564687e1f3333 Mon Sep 17 00:00:00 2001
From: Anonym0u3 <306794924@qq.com>
Date: Fri, 24 Jan 2025 21:34:26 +0800
Subject: [PATCH 7/7] update_review_ms
Co-authored-by: Other Contributor 
---
 .../pipeline_stable_diffusion_xl_attentive_eraser.py       | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
index f78cedb9965a..1269a69f0dc3 100644
--- a/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
+++ b/examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py
@@ -87,7 +87,7 @@
         >>> from torchvision.transforms.functional import to_tensor, gaussian_blur
 
         >>> dtype = torch.float16
-        >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 
+        >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         >>> scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
 
         >>> pipeline = DiffusionPipeline.from_pretrained(
@@ -120,7 +120,7 @@
         ...     return mask
 
         >>> prompt = "" # Set prompt to null
-        >>> seed=123 
+        >>> seed=123
         >>> generator = torch.Generator(device=device).manual_seed(seed)
         >>> source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
         >>> mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
@@ -128,7 +128,7 @@
         >>> mask = preprocess_mask(mask_path, device)
 
         >>> image = pipeline(
-        ...     prompt=prompt, 
+        ...     prompt=prompt,
         ...     image=source_image,
         ...     mask_image=mask,
         ...     height=1024,
@@ -251,6 +251,7 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar
         Attention forward function
         """
         if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
+            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
         H = int(np.sqrt(q.shape[1]))
         if H == 16:
             mask = self.mask_16.to(sim.device)
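The restored `return super().forward(...)` is what keeps attention redirection scoped: outside the configured denoising steps and U-Net layers, every call must fall through to the default attention path. As patch 6 left it, the `if` had no body at all, so the one-line fix here is load-bearing. A hedged sketch of the guard pattern, with illustrative names and a stub in place of the actual AAS computation:

```py
import torch
import torch.nn.functional as F

# Illustrative guard (names are not the pipeline's API; the redirected branch
# is a stub for the real AAS computation in the pipeline file): redirection
# runs only inside the configured step/layer windows, and everything else
# falls back to plain attention via the early return that patch 7 restores.
class GuardedAttention:
    def __init__(self, step_idx, layer_idx):
        self.step_idx = set(step_idx)    # active denoising steps
        self.layer_idx = set(layer_idx)  # active U-Net attention layers
        self.cur_step = 0
        self.cur_att_layer = 0

    def default_forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v)

    def redirected_forward(self, q, k, v):
        # Stub: stands in for the AAS-modified attention.
        return F.scaled_dot_product_attention(q, k, v)

    def forward(self, q, k, v, is_cross):
        # Without this early return the `if` has no body and the module
        # cannot even be parsed, hence the one-line fix in the final hunk.
        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
            return self.default_forward(q, k, v)
        return self.redirected_forward(q, k, v)

# Windows loosely modeled on the docstring example (AAS_start_layer=34,
# AAS_end_layer=70, AAS end step = int(strength * num_inference_steps)).
attn = GuardedAttention(step_idx=range(40), layer_idx=range(34, 70))
q = k = v = torch.randn(2, 8, 1024, 64)
out = attn.forward(q, k, v, is_cross=False)  # layer 0 is outside the window, so default path
```
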