Skip to content

Commit 5aeb51a

Browse files
committed
Update pipeline_pag_controlnet_sd_img2img.py
1 parent f366d09 commit 5aeb51a

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_img2img.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,6 +1234,13 @@ def __call__(
12341234
device,
12351235
generator,
12361236
)
1237+
1238+
timestep_cond = None
1239+
if self.unet.config.time_cond_proj_dim is not None:
1240+
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1241+
timestep_cond = self.get_guidance_scale_embedding(
1242+
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1243+
).to(device=device, dtype=latents.dtype)
12371244

12381245
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
12391246
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -1317,6 +1324,7 @@ def __call__(
13171324
latent_model_input,
13181325
t,
13191326
encoder_hidden_states=prompt_embeds,
1327+
timestep_cond=timestep_cond,
13201328
cross_attention_kwargs=self.cross_attention_kwargs,
13211329
down_block_additional_residuals=down_block_res_samples,
13221330
mid_block_additional_residual=mid_block_res_sample,
@@ -1344,6 +1352,7 @@ def __call__(
13441352

13451353
latents = callback_outputs.pop("latents", latents)
13461354
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1355+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
13471356
# call the callback, if provided
13481357
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
13491358
progress_bar.update()

0 commit comments

Comments
 (0)