
Commit bb52aeb

Anonym0u3 and Benny079 committed
update_review_ms
Co-authored-by: Other Contributor <[email protected]>
1 parent e5fc2e4 commit bb52aeb


examples/community/pipeline_stable_diffusion_xl_attentive_eraser.py

Lines changed: 4 additions & 3 deletions
@@ -87,7 +87,7 @@
 >>> from torchvision.transforms.functional import to_tensor, gaussian_blur

 >>> dtype = torch.float16
->>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+>>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 >>> scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)

 >>> pipeline = DiffusionPipeline.from_pretrained(
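The removed and re-added lines in this hunk and the next render identically here, so those changes are presumably whitespace-only fixes. For context, the hunk cuts off after the opening parenthesis of the from_pretrained call; a minimal sketch of how such a community pipeline is typically loaded follows. The checkpoint id and the custom_pipeline name (the file name under examples/community, without .py) are assumptions, not part of this diff.

import torch
from diffusers import DDIMScheduler, DiffusionPipeline

dtype = torch.float16
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Scheduler configured exactly as in the docstring example above.
scheduler = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
)

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # assumed SDXL base checkpoint
    custom_pipeline="pipeline_stable_diffusion_xl_attentive_eraser",  # assumed id
    scheduler=scheduler,
    torch_dtype=dtype,
).to(device)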
@@ -120,15 +120,15 @@
 ... return mask

 >>> prompt = "" # Set prompt to null
->>> seed=123
+>>> seed=123
 >>> generator = torch.Generator(device=device).manual_seed(seed)
 >>> source_image_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024.png"
 >>> mask_path = "https://raw.githubusercontent.com/Anonym0u3/Images/refs/heads/main/an1024_mask.png"
 >>> source_image = preprocess_image(source_image_path, device)
 >>> mask = preprocess_mask(mask_path, device)

 >>> image = pipeline(
-... prompt=prompt,
+... prompt=prompt,
 ... image=source_image,
 ... mask_image=mask,
 ... height=1024,
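The example calls two helpers whose bodies fall mostly outside this hunk (only the trailing "... return mask" is visible). A hypothetical reconstruction, assuming the helpers simply fetch the URLs and tensorize the images, might look like this; the bodies below are assumptions, not the file's actual code.

import requests
import torch
from io import BytesIO
from PIL import Image
from torchvision.transforms.functional import to_tensor

def preprocess_image(image_path, device):
    # Download the source image and convert it to a [1, 3, H, W] tensor
    # scaled to [-1, 1], the range SDXL-based pipelines typically expect.
    image = Image.open(BytesIO(requests.get(image_path).content)).convert("RGB")
    image = to_tensor(image).unsqueeze(0)
    return (image * 2.0 - 1.0).to(device, dtype=torch.float16)

def preprocess_mask(mask_path, device):
    # Download the mask, binarize it, and return a [1, 1, H, W] tensor of 0/1.
    mask = Image.open(BytesIO(requests.get(mask_path).content)).convert("L")
    mask = to_tensor(mask).unsqueeze(0)
    return (mask > 0.5).to(device, dtype=torch.float16)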
@@ -251,6 +251,7 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
         Attention forward function
         """
         if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
+            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
         H = int(np.sqrt(q.shape[1]))
         if H == 16:
             mask = self.mask_16.to(sim.device)
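The added early return is the one substantive change: when the call is cross-attention, or the current step/layer falls outside the configured step_idx/layer_idx sets, the processor now hands the call back to the default attention instead of falling through into the masked-attention path. A minimal sketch of that guard pattern, with an assumed base class standing in for the real attention controller:

import numpy as np
import torch

class BaseAttention:
    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        # Default path: apply the precomputed attention weights to the values.
        return torch.bmm(attn, v)

class MaskedSelfAttention(BaseAttention):
    def __init__(self, step_idx, layer_idx, mask_16):
        self.step_idx = step_idx      # diffusion steps where the edit is active
        self.layer_idx = layer_idx    # attention layers where the edit is active
        self.mask_16 = mask_16        # low-resolution mask used when H == 16
        self.cur_step = 0
        self.cur_att_layer = 0

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        # The added guard: every cross-attention call, and any step/layer outside
        # the configured sets, falls through to vanilla attention.
        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
        H = int(np.sqrt(q.shape[1]))  # spatial size of the square attention map
        if H == 16:
            mask = self.mask_16.to(sim.device)
        # ... the masked self-attention edit continues here in the real pipeline
        return torch.bmm(attn, v)

Dividing cur_att_layer by 2 matches the diff: attention layers are counted in (cross, self) pairs, so the integer division maps the running layer counter onto per-block indices.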
