Skip to content

Commit e5fc2e4

Browse files
Anonym0u3 and Benny079 committed
update_review
Co-authored-by: Other Contributor <[email protected]>
1 parent c2d5b91 commit e5fc2e4

File tree

1 file changed

+64
-32
lines changed

1 file changed

+64
-32
lines changed

examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py

Lines changed: 64 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -81,27 +81,72 @@
8181
Examples:
8282
```py
8383
>>> import torch
84-
>>> from diffusers import StableDiffusionXLInpaintPipeline
84+
>>> from diffusers import DDIMScheduler, DiffusionPipeline
8585
>>> from diffusers.utils import load_image
86-
87-
>>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
88-
... "stabilityai/stable-diffusion-xl-base-1.0",
89-
... torch_dtype=torch.float16,
90-
... variant="fp16",
91-
... use_safetensors=True,
92-
... )
93-
>>> pipe.to("cuda")
94-
95-
>>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
96-
>>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
97-
98-
>>> init_image = load_image(img_url).convert("RGB")
99-
>>> mask_image = load_image(mask_url).convert("RGB")
100-
101-
>>> prompt = "A majestic tiger sitting on a bench"
102-
>>> image = pipe(
103-
... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
86+
>>> import torch.nn.functional as F
87+
>>> from torchvision.transforms.functional import to_tensor, gaussian_blur
88+
89+
>>> dtype = torch.float16
90+
>>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
91+
>>> scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
92+
93+
>>> pipeline = DiffusionPipeline.from_pretrained(
94+
... "stabilityai/stable-diffusion-xl-base-1.0",
95+
... custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",
96+
... scheduler=scheduler,
97+
... variant="fp16",
98+
... use_safetensors=True,
99+
... torch_dtype=dtype,
100+
... ).to(device)
101+
102+
103+
>>> def preprocess_image(image_path, device):
104+
... image = to_tensor((load_image(image_path)))
105+
... image = image.unsqueeze_(0).float() * 2 - 1 # [0,1] --> [-1,1]
106+
... if image.shape[1] != 3:
107+
... image = image.expand(-1, 3, -1, -1)
108+
... image = F.interpolate(image, (1024, 1024))
109+
... image = image.to(dtype).to(device)
110+
... return image
111+
112+
>>> def preprocess_mask(mask_path, device):
113+
... mask = to_tensor((load_image(mask_path, convert_method=lambda img: img.convert('L'))))
114+
... mask = mask.unsqueeze_(0).float() # 0 or 1
115+
... mask = F.interpolate(mask, (1024, 1024))
116+
... mask = gaussian_blur(mask, kernel_size=(77, 77))
117+
... mask[mask < 0.1] = 0
118+
... mask[mask >= 0.1] = 1
119+
... mask = mask.to(dtype).to(device)
120+
... return mask
121+
122+
>>> prompt = "" # Set prompt to null
123+
>>> seed=123
124+
>>> generator = torch.Generator(device=device).manual_seed(seed)
125+
>>> source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
126+
>>> mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
127+
>>> source_image = preprocess_image(source_image_path, device)
128+
>>> mask = preprocess_mask(mask_path, device)
129+
130+
>>> image = pipeline(
131+
... prompt=prompt,
132+
... image=source_image,
133+
... mask_image=mask,
134+
... height=1024,
135+
... width=1024,
136+
... AAS=True, # enable AAS
137+
... strength=0.8, # inpainting strength
138+
... rm_guidance_scale=9, # removal guidance scale
139+
... ss_steps = 9, # similarity suppression steps
140+
... ss_scale = 0.3, # similarity suppression scale
141+
... AAS_start_step=0, # AAS start step
142+
... AAS_start_layer=34, # AAS start layer
143+
... AAS_end_layer=70, # AAS end layer
144+
... num_inference_steps=50, # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
145+
... generator=generator,
146+
... guidance_scale=1,
104147
... ).images[0]
148+
>>> image.save('./removed_img.png')
149+
>>> print("Object removal completed")
105150
```
106151
"""
107152

@@ -174,9 +219,6 @@ def __init__(
174219
self.mask = mask # mask with shape (1, 1 ,h, w)
175220
self.ss_steps = ss_steps
176221
self.ss_scale = ss_scale
177-
print("AAS at denoising steps: ", self.step_idx)
178-
print("AAS at U-Net layers: ", self.layer_idx)
179-
print("start AAS")
180222
self.mask_16 = F.max_pool2d(mask, (1024 // 16, 1024 // 16)).round().squeeze().squeeze()
181223
self.mask_32 = F.max_pool2d(mask, (1024 // 32, 1024 // 32)).round().squeeze().squeeze()
182224
self.mask_64 = F.max_pool2d(mask, (1024 // 64, 1024 // 64)).round().squeeze().squeeze()
@@ -209,10 +251,7 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar
209251
Attention forward function
210252
"""
211253
if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
212-
return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
213-
# B = q.shape[0] // num_heads // 2
214254
H = int(np.sqrt(q.shape[1]))
215-
# H = W = int(np.sqrt(q.shape[1]))
216255
if H == 16:
217256
mask = self.mask_16.to(sim.device)
218257
elif H == 32:
@@ -317,13 +356,6 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
317356
dimensions: ``batch x channels x height x width``.
318357
"""
319358

320-
# checkpoint. TOD(Yiyi) - need to clean this up later
321-
deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
322-
deprecate(
323-
"prepare_mask_and_masked_image",
324-
"0.30.0",
325-
deprecation_message,
326-
)
327359
if image is None:
328360
raise ValueError("`image` input cannot be undefined.")
329361

0 commit comments

Comments
 (0)