Skip to content

Commit 235d34c

Browse files
Check for latents, before calling prepare_latents - sdxlImg2Img (#7582)
* Check for latents, before calling prepare_latents - sdxlImg2Img
* Added latents check for all the img2img pipelines
* Fixed silly mistake while checking latents as None
1 parent 5029673 commit 235d34c

10 files changed

+99
-83
lines changed

examples/community/clip_guided_stable_diffusion_img2img.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -359,9 +359,16 @@ def __call__(
359359

360360
# Preprocess image
361361
image = preprocess(image, width, height)
362-
latents = self.prepare_latents(
363-
image, latent_timestep, batch_size, num_images_per_prompt, text_embeddings.dtype, self.device, generator
364-
)
362+
if latents is None:
363+
latents = self.prepare_latents(
364+
image,
365+
latent_timestep,
366+
batch_size,
367+
num_images_per_prompt,
368+
text_embeddings.dtype,
369+
self.device,
370+
generator,
371+
)
365372

366373
if clip_guidance_scale > 0:
367374
if clip_prompt is not None:

examples/community/latent_consistency_img2img.py

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -335,17 +335,18 @@ def __call__(
335335

336336
# 5. Prepare latent variable
337337
num_channels_latents = self.unet.config.in_channels
338-
latents = self.prepare_latents(
339-
image,
340-
latent_timestep,
341-
batch_size * num_images_per_prompt,
342-
num_channels_latents,
343-
height,
344-
width,
345-
prompt_embeds.dtype,
346-
device,
347-
latents,
348-
)
338+
if latents is None:
339+
latents = self.prepare_latents(
340+
image,
341+
latent_timestep,
342+
batch_size * num_images_per_prompt,
343+
num_channels_latents,
344+
height,
345+
width,
346+
prompt_embeds.dtype,
347+
device,
348+
latents,
349+
)
349350
bs = batch_size * num_images_per_prompt
350351

351352
# 6. Get Guidance Scale Embedding

examples/community/stable_diffusion_controlnet_img2img.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -802,15 +802,16 @@ def __call__(
802802
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
803803

804804
# 6. Prepare latent variables
805-
latents = self.prepare_latents(
806-
image,
807-
latent_timestep,
808-
batch_size,
809-
num_images_per_prompt,
810-
prompt_embeds.dtype,
811-
device,
812-
generator,
813-
)
805+
if latents is None:
806+
latents = self.prepare_latents(
807+
image,
808+
latent_timestep,
809+
batch_size,
810+
num_images_per_prompt,
811+
prompt_embeds.dtype,
812+
device,
813+
generator,
814+
)
814815

815816
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
816817
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

examples/community/stable_diffusion_controlnet_inpaint_img2img.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -907,15 +907,16 @@ def __call__(
907907
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
908908

909909
# 6. Prepare latent variables
910-
latents = self.prepare_latents(
911-
image,
912-
latent_timestep,
913-
batch_size,
914-
num_images_per_prompt,
915-
prompt_embeds.dtype,
916-
device,
917-
generator,
918-
)
910+
if latents is None:
911+
latents = self.prepare_latents(
912+
image,
913+
latent_timestep,
914+
batch_size,
915+
num_images_per_prompt,
916+
prompt_embeds.dtype,
917+
device,
918+
generator,
919+
)
919920

920921
mask_image_latents = self.prepare_mask_latents(
921922
mask_image,

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1169,15 +1169,16 @@ def __call__(
11691169
self._num_timesteps = len(timesteps)
11701170

11711171
# 6. Prepare latent variables
1172-
latents = self.prepare_latents(
1173-
image,
1174-
latent_timestep,
1175-
batch_size,
1176-
num_images_per_prompt,
1177-
prompt_embeds.dtype,
1178-
device,
1179-
generator,
1180-
)
1172+
if latents is None:
1173+
latents = self.prepare_latents(
1174+
image,
1175+
latent_timestep,
1176+
batch_size,
1177+
num_images_per_prompt,
1178+
prompt_embeds.dtype,
1179+
device,
1180+
generator,
1181+
)
11811182

11821183
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
11831184
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1429,16 +1429,17 @@ def __call__(
14291429
self._num_timesteps = len(timesteps)
14301430

14311431
# 6. Prepare latent variables
1432-
latents = self.prepare_latents(
1433-
image,
1434-
latent_timestep,
1435-
batch_size,
1436-
num_images_per_prompt,
1437-
prompt_embeds.dtype,
1438-
device,
1439-
generator,
1440-
True,
1441-
)
1432+
if latents is None:
1433+
latents = self.prepare_latents(
1434+
image,
1435+
latent_timestep,
1436+
batch_size,
1437+
num_images_per_prompt,
1438+
prompt_embeds.dtype,
1439+
device,
1440+
generator,
1441+
True,
1442+
)
14421443

14431444
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
14441445
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -872,9 +872,10 @@ def __call__(
872872
else self.scheduler.config.original_inference_steps
873873
)
874874
latent_timestep = timesteps[:1]
875-
latents = self.prepare_latents(
876-
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
877-
)
875+
if latents is None:
876+
latents = self.prepare_latents(
877+
image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
878+
)
878879
bs = batch_size * num_images_per_prompt
879880

880881
# 6. Get Guidance Scale Embedding

src/diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -239,15 +239,15 @@ def __call__(
239239

240240
num_embeddings = self.prior.config.num_embeddings
241241
embedding_dim = self.prior.config.embedding_dim
242-
243-
latents = self.prepare_latents(
244-
(batch_size, num_embeddings * embedding_dim),
245-
image_embeds.dtype,
246-
device,
247-
generator,
248-
latents,
249-
self.scheduler,
250-
)
242+
if latents is None:
243+
latents = self.prepare_latents(
244+
(batch_size, num_embeddings * embedding_dim),
245+
image_embeds.dtype,
246+
device,
247+
generator,
248+
latents,
249+
self.scheduler,
250+
)
251251

252252
# YiYi notes: for testing only to match ldm, we can directly create a latents with desired shape: batch_size, num_embeddings, embedding_dim
253253
latents = latents.reshape(latents.shape[0], num_embeddings, embedding_dim)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -786,16 +786,17 @@ def __call__(
786786

787787
# 6. Prepare latent variables
788788
num_channels_latents = self.unet.config.in_channels
789-
latents = self.prepare_latents(
790-
batch_size=batch_size,
791-
num_channels_latents=num_channels_latents,
792-
height=height,
793-
width=width,
794-
dtype=prompt_embeds.dtype,
795-
device=device,
796-
generator=generator,
797-
latents=latents,
798-
)
789+
if latents is None:
790+
latents = self.prepare_latents(
791+
batch_size=batch_size,
792+
num_channels_latents=num_channels_latents,
793+
height=height,
794+
width=width,
795+
dtype=prompt_embeds.dtype,
796+
device=device,
797+
generator=generator,
798+
latents=latents,
799+
)
799800

800801
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
801802
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1247,17 +1247,19 @@ def denoising_value_valid(dnv):
12471247
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
12481248

12491249
add_noise = True if self.denoising_start is None else False
1250+
12501251
# 6. Prepare latent variables
1251-
latents = self.prepare_latents(
1252-
image,
1253-
latent_timestep,
1254-
batch_size,
1255-
num_images_per_prompt,
1256-
prompt_embeds.dtype,
1257-
device,
1258-
generator,
1259-
add_noise,
1260-
)
1252+
if latents is None:
1253+
latents = self.prepare_latents(
1254+
image,
1255+
latent_timestep,
1256+
batch_size,
1257+
num_images_per_prompt,
1258+
prompt_embeds.dtype,
1259+
device,
1260+
generator,
1261+
add_noise,
1262+
)
12611263
# 7. Prepare extra step kwargs.
12621264
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
12631265

0 commit comments

Comments
 (0)