Skip to content

Commit a6e1e80

Browse files
committed
Update pipeline_pag_controlnet_sd_img2img.py
1 parent d89e785 commit a6e1e80

File tree

1 file changed

+35
-18
lines changed

1 file changed

+35
-18
lines changed

src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_img2img.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,21 +1253,21 @@ def __call__(
12531253
]
12541254
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
12551255

1256-
control_images = control_image if isinstance(control_image, list) else [control_image]
1257-
for i, single_image in enumerate(control_images):
1258-
if self.do_classifier_free_guidance:
1259-
single_image = single_image.chunk(2)[0]
1260-
1261-
if self.do_perturbed_attention_guidance:
1262-
single_image = self._prepare_perturbed_attention_guidance(
1263-
single_image, single_image, self.do_classifier_free_guidance
1264-
)
1265-
elif self.do_classifier_free_guidance:
1266-
single_image = torch.cat([single_image] * 2)
1267-
single_image = single_image.to(device)
1268-
control_images[i] = single_image
1269-
1270-
control_image = control_images if isinstance(control_image, list) else control_images[0]
1256+
# control_images = control_image if isinstance(control_image, list) else [control_image]
1257+
# for i, single_image in enumerate(control_images):
1258+
# if self.do_classifier_free_guidance:
1259+
# single_image = single_image.chunk(2)[0]
1260+
1261+
# if self.do_perturbed_attention_guidance:
1262+
# single_image = self._prepare_perturbed_attention_guidance(
1263+
# single_image, single_image, self.do_classifier_free_guidance
1264+
# )
1265+
# elif self.do_classifier_free_guidance:
1266+
# single_image = torch.cat([single_image] * 2)
1267+
# single_image = single_image.to(device)
1268+
# control_images[i] = single_image
1269+
1270+
#control_image = control_images if isinstance(control_image, list) else control_images[0]
12711271

12721272
prompt_embeds = prompt_embeds.to(device)
12731273

@@ -1285,12 +1285,22 @@ def __call__(
12851285

12861286
with self.progress_bar(total=num_inference_steps) as progress_bar:
12871287
for i, t in enumerate(timesteps):
1288+
if self.interrupt:
1289+
continue
1290+
12881291
# expand the latents if we are doing classifier free guidance
1289-
latent_model_input = torch.cat([latents] * (prompt_embeds.shape[0] // latents.shape[0]))
1292+
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
12901293
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
12911294

12921295
# controlnet(s) inference
1293-
control_model_input = latent_model_input
1296+
if guess_mode and self.do_classifier_free_guidance:
1297+
# Infer ControlNet only for the conditional batch.
1298+
control_model_input = latents
1299+
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1300+
controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1301+
else:
1302+
control_model_input = latent_model_input
1303+
controlnet_prompt_embeds = prompt_embeds
12941304

12951305
if isinstance(controlnet_keep[i], list):
12961306
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
@@ -1299,16 +1309,23 @@ def __call__(
12991309
if isinstance(controlnet_cond_scale, list):
13001310
controlnet_cond_scale = controlnet_cond_scale[0]
13011311
cond_scale = controlnet_cond_scale * controlnet_keep[i]
1312+
13021313
down_block_res_samples, mid_block_res_sample = self.controlnet(
13031314
control_model_input,
13041315
t,
13051316
encoder_hidden_states=controlnet_prompt_embeds,
13061317
controlnet_cond=control_image,
13071318
conditioning_scale=cond_scale,
1308-
guess_mode=False,
1319+
guess_mode=guess_mode,
13091320
return_dict=False,
13101321
)
13111322

1323+
if guess_mode and self.do_classifier_free_guidance:
1324+
# Inferred ControlNet only for the conditional batch.
1325+
# To apply the output of ControlNet to both the unconditional and conditional batches,
1326+
# add 0 to the unconditional batch to keep it unchanged.
1327+
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1328+
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
13121329

13131330
# predict the noise residual
13141331
noise_pred = self.unet(

0 commit comments

Comments (0)