| 81 | 81 |     Examples: | 
| 82 | 82 |         ```py | 
| 83 | 83 |         >>> import torch | 
| 84 |  | -        >>> from diffusers import StableDiffusionXLInpaintPipeline | 
|  | 84 | +        >>> from diffusers import DDIMScheduler, DiffusionPipeline | 
| 85 | 85 |         >>> from diffusers.utils import load_image | 
| 86 |  | -
 | 
| 87 |  | -        >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained( | 
| 88 |  | -        ...     "stabilityai/stable-diffusion-xl-base-1.0", | 
| 89 |  | -        ...     torch_dtype=torch.float16, | 
| 90 |  | -        ...     variant="fp16", | 
| 91 |  | -        ...     use_safetensors=True, | 
| 92 |  | -        ... ) | 
| 93 |  | -        >>> pipe.to("cuda") | 
| 94 |  | -
 | 
| 95 |  | -        >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" | 
| 96 |  | -        >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" | 
| 97 |  | -
 | 
| 98 |  | -        >>> init_image = load_image(img_url).convert("RGB") | 
| 99 |  | -        >>> mask_image = load_image(mask_url).convert("RGB") | 
| 100 |  | -
 | 
| 101 |  | -        >>> prompt = "A majestic tiger sitting on a bench" | 
| 102 |  | -        >>> image = pipe( | 
| 103 |  | -        ...     prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80 | 
|  | 86 | +        >>> import torch.nn.functional as F | 
|  | 87 | +        >>> from torchvision.transforms.functional import to_tensor, gaussian_blur | 
|  | 88 | +
 | 
|  | 89 | +        >>> dtype = torch.float16 | 
|  | 90 | +        >>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | 
|  | 91 | +        >>> scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False) | 
|  | 92 | +
 | 
|  | 93 | +        >>> pipeline = DiffusionPipeline.from_pretrained( | 
|  | 94 | +        ...     "stabilityai/stable-diffusion-xl-base-1.0", | 
|  | 95 | +        ...     custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser", | 
|  | 96 | +        ...     scheduler=scheduler, | 
|  | 97 | +        ...     variant="fp16", | 
|  | 98 | +        ...     use_safetensors=True, | 
|  | 99 | +        ...     torch_dtype=dtype, | 
|  | 100 | +        ... ).to(device) | 
|  | 101 | +
 | 
|  | 102 | +
 | 
|  | 103 | +        >>> def preprocess_image(image_path, device): | 
|  | 104 | +        ...     image = to_tensor(load_image(image_path)) | 
|  | 105 | +        ...     image = image.unsqueeze_(0).float() * 2 - 1  # [0,1] --> [-1,1] | 
|  | 106 | +        ...     if image.shape[1] != 3: | 
|  | 107 | +        ...         image = image.expand(-1, 3, -1, -1)  # grayscale --> RGB | 
|  | 108 | +        ...     image = F.interpolate(image, (1024, 1024)) | 
|  | 109 | +        ...     image = image.to(dtype).to(device) | 
|  | 110 | +        ...     return image | 
|  | 111 | +
 | 
|  | 112 | +        >>> def preprocess_mask(mask_path, device): | 
|  | 113 | +        ...     mask = to_tensor(load_image(mask_path, convert_method=lambda img: img.convert("L"))) | 
|  | 114 | +        ...     mask = mask.unsqueeze_(0).float()  # values in {0, 1} | 
|  | 115 | +        ...     mask = F.interpolate(mask, (1024, 1024)) | 
|  | 116 | +        ...     mask = gaussian_blur(mask, kernel_size=(77, 77)) | 
|  | 117 | +        ...     mask[mask < 0.1] = 0 | 
|  | 118 | +        ...     mask[mask >= 0.1] = 1 | 
|  | 119 | +        ...     mask = mask.to(dtype).to(device) | 
|  | 120 | +        ...     return mask | 
|  | 121 | +
 | 
|  | 122 | +        >>> prompt = ""  # use an empty prompt for object removal | 
|  | 123 | +        >>> seed = 123 | 
|  | 124 | +        >>> generator = torch.Generator(device=device).manual_seed(seed) | 
|  | 125 | +        >>> source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png" | 
|  | 126 | +        >>> mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png" | 
|  | 127 | +        >>> source_image = preprocess_image(source_image_path, device) | 
|  | 128 | +        >>> mask = preprocess_mask(mask_path, device) | 
|  | 129 | +
 | 
|  | 130 | +        >>> image = pipeline( | 
|  | 131 | +        ...     prompt=prompt,  | 
|  | 132 | +        ...     image=source_image, | 
|  | 133 | +        ...     mask_image=mask, | 
|  | 134 | +        ...     height=1024, | 
|  | 135 | +        ...     width=1024, | 
|  | 136 | +        ...     AAS=True,  # enable AAS | 
|  | 137 | +        ...     strength=0.8,  # inpainting strength | 
|  | 138 | +        ...     rm_guidance_scale=9,  # removal guidance scale | 
|  | 139 | +        ...     ss_steps=9,  # similarity suppression steps | 
|  | 140 | +        ...     ss_scale=0.3,  # similarity suppression scale | 
|  | 141 | +        ...     AAS_start_step=0,  # AAS start step | 
|  | 142 | +        ...     AAS_start_layer=34,  # AAS start layer | 
|  | 143 | +        ...     AAS_end_layer=70,  # AAS end layer | 
|  | 144 | +        ...     num_inference_steps=50,  # AAS_end_step = int(strength * num_inference_steps) | 
|  | 145 | +        ...     generator=generator, | 
|  | 146 | +        ...     guidance_scale=1, | 
| 104 | 147 |         ... ).images[0] | 
|  | 148 | +        >>> image.save("./removed_img.png") | 
|  | 149 | +        >>> print("Object removal completed") | 
| 105 | 150 |         ``` | 
| 106 | 151 | """ | 
| 107 | 152 | 
 | 
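A quick note on the example above: the two helpers are expected to yield a three-channel image in [-1, 1] and a single-channel binary mask, both at 1024×1024. A minimal, hypothetical sanity check, reusing the names defined in the docstring example, might look like:

```py
# Hypothetical check; assumes preprocess_image/preprocess_mask, source_image_path,
# mask_path, and device from the docstring example above are in scope.
source_image = preprocess_image(source_image_path, device)
mask = preprocess_mask(mask_path, device)

assert source_image.shape == (1, 3, 1024, 1024)  # RGB, scaled to [-1, 1]
assert mask.shape == (1, 1, 1024, 1024)          # single channel, values in {0, 1}
assert source_image.min() >= -1 and source_image.max() <= 1
```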
| @@ -174,9 +219,6 @@ def __init__( | 
| 174 | 219 |         self.mask = mask  # mask with shape (1, 1, h, w) | 
| 175 | 220 |         self.ss_steps = ss_steps | 
| 176 | 221 |         self.ss_scale = ss_scale | 
| 177 |  | -        print("AAS at denoising steps: ", self.step_idx) | 
| 178 |  | -        print("AAS at U-Net layers: ", self.layer_idx) | 
| 179 |  | -        print("start AAS") | 
| 180 | 222 |         self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze() | 
| 181 | 223 |         self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze() | 
| 182 | 224 |         self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze() | 
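The three pooled masks kept in this hunk gate attention at the U-Net's 16×16, 32×32, and 64×64 resolutions. A self-contained sketch of the same downsampling (the square mask here is purely illustrative):

```py
import torch
import torch.nn.functional as F

# Illustrative 1024x1024 binary mask with a square region to erase.
mask = torch.zeros(1, 1, 1024, 1024)
mask[..., 256:768, 256:768] = 1.0

# Max-pooling marks a low-res cell as masked if ANY pixel inside it is masked;
# round() keeps the result binary after pooling.
mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze()
mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze()
mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze()

print(mask_16.shape, mask_32.shape, mask_64.shape)
# torch.Size([16, 16]) torch.Size([32, 32]) torch.Size([64, 64])
```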
| @@ -209,10 +251,8 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs): | 
| 209 | 251 |         Attention forward function | 
| 210 | 252 |         """ | 
| 211 | 253 |         if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx: | 
| 212 | 254 |             return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs) | 
| 213 |  | -        # B = q.shape[0] // num_heads // 2 | 
| 214 | 255 |         H = int(np.sqrt(q.shape[1])) | 
| 215 |  | -        # H = W = int(np.sqrt(q.shape[1])) | 
| 216 | 256 |         if H == 16: | 
| 217 | 257 |             mask = self.mask_16.to(sim.device) | 
| 218 | 258 |         elif H == 32: | 
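For context on the dispatch above: self-attention over an H×W latent grid flattens to H·W tokens, so the grid side is recovered as the square root of the token count and used to pick the pooled mask of matching resolution. A toy version (names are illustrative, not the pipeline's API):

```py
import numpy as np

def select_mask(num_tokens, masks):
    # masks is assumed to be {16: mask_16, 32: mask_32, 64: mask_64}.
    side = int(np.sqrt(num_tokens))  # e.g. 4096 tokens -> a 64x64 grid
    return masks[side]
```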
| @@ -317,13 +357,6 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False): | 
| 317 | 357 |             dimensions: ``batch x channels x height x width``. | 
| 318 | 358 |     """ | 
| 319 | 359 | 
 | 
| 320 |  | -    # checkpoint. TOD(Yiyi) - need to clean this up later | 
| 321 |  | -    deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead" | 
| 322 |  | -    deprecate( | 
| 323 |  | -        "prepare_mask_and_masked_image", | 
| 324 |  | -        "0.30.0", | 
| 325 |  | -        deprecation_message, | 
| 326 |  | -    ) | 
| 327 | 360 |     if image is None: | 
| 328 | 361 |         raise ValueError("`image` input cannot be undefined.") | 
| 329 | 362 | 
 | 
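The deprecation shim deleted in the last hunk pointed users to `VaeImageProcessor.preprocess`. For reference, a minimal sketch of that replacement path (the scale factor of 8 is the usual SDXL VAE default, assumed here; the processor flags mirror what inpainting pipelines construct internally):

```py
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import load_image

image_processor = VaeImageProcessor(vae_scale_factor=8)
mask_processor = VaeImageProcessor(
    vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)

init_image = load_image("https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png")
mask_image = load_image("https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png")

image = image_processor.preprocess(init_image, height=1024, width=1024)  # float tensor in [-1, 1]
mask = mask_processor.preprocess(mask_image, height=1024, width=1024)    # binary tensor in {0, 1}
```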