Skip to content

Commit c009c20

Browse files
authored
[SDXL] Fix uncaught error with image to image (#8856)
* initial commit
* apply suggestion to sdxl pipelines
* apply fix to sd pipelines
1 parent 3f14117 commit c009c20

File tree

8 files changed: 56 additions & 0 deletions

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -824,6 +824,13 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
824824
)
825825

826826
elif isinstance(generator, list):
827+
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
828+
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
829+
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
830+
raise ValueError(
831+
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
832+
)
833+
827834
init_latents = [
828835
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
829836
for i in range(batch_size)

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -930,6 +930,13 @@ def prepare_latents(
930930
)
931931

932932
elif isinstance(generator, list):
933+
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
934+
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
935+
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
936+
raise ValueError(
937+
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
938+
)
939+
933940
init_latents = [
934941
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
935942
for i in range(batch_size)

src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -528,6 +528,13 @@ def prepare_latents(
528528
)
529529

530530
elif isinstance(generator, list):
531+
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
532+
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
533+
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
534+
raise ValueError(
535+
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
536+
)
537+
531538
init_latents = [
532539
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
533540
for i in range(batch_size)

src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -520,6 +520,13 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
520520
)
521521

522522
elif isinstance(generator, list):
523+
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
524+
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
525+
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
526+
raise ValueError(
527+
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
528+
)
529+
523530
init_latents = [
524531
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
525532
for i in range(batch_size)

src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -719,6 +719,13 @@ def prepare_latents(
719719
)
720720

721721
elif isinstance(generator, list):
722+
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
723+
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
724+
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
725+
raise ValueError(
726+
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
727+
)
728+
722729
init_latents = [
723730
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
724731
for i in range(batch_size)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -494,6 +494,13 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
494494
)
495495

496496
elif isinstance(generator, list):
497+
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
498+
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
499+
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
500+
raise ValueError(
501+
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
502+
)
503+
497504
init_latents = [
498505
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
499506
for i in range(batch_size)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -740,6 +740,13 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
740740
)
741741

742742
elif isinstance(generator, list):
743+
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
744+
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
745+
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
746+
raise ValueError(
747+
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
748+
)
749+
743750
init_latents = [
744751
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
745752
for i in range(batch_size)

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -710,6 +710,13 @@ def prepare_latents(
710710
)
711711

712712
elif isinstance(generator, list):
713+
if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
714+
image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
715+
elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
716+
raise ValueError(
717+
f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
718+
)
719+
713720
init_latents = [
714721
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
715722
for i in range(batch_size)

0 commit comments

Comments
 (0)