Skip to content

Commit a70f29d

Browse files
committed
fix image encoding
1 parent 43711dd commit a70f29d

File tree

1 file changed

+1
-25
lines changed

1 file changed

+1
-25
lines changed

src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -575,33 +575,16 @@ def prepare_image(
575575
image,
576576
width,
577577
height,
578-
batch_size,
579-
num_images_per_prompt,
580578
device,
581579
dtype,
582-
do_classifier_free_guidance=False,
583-
guess_mode=False,
584580
):
585581
if isinstance(image, torch.Tensor):
586582
pass
587583
else:
588584
image = self.image_processor.preprocess(image, height=height, width=width)
589585

590-
image_batch_size = image.shape[0]
591-
592-
if image_batch_size == 1:
593-
repeat_by = batch_size
594-
else:
595-
# image batch size is the same as prompt batch size
596-
repeat_by = num_images_per_prompt
597-
598-
image = image.repeat_interleave(repeat_by, dim=0)
599-
600586
image = image.to(device=device, dtype=dtype)
601587

602-
if do_classifier_free_guidance and not guess_mode:
603-
image = torch.cat([image] * 2)
604-
605588
return image
606589

607590
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents
@@ -626,12 +609,6 @@ def prepare_latents(self,
626609
int(width) // self.vae_scale_factor,
627610
)
628611

629-
image = image.to(device=device, dtype=dtype)
630-
if isinstance(image, torch.Tensor):
631-
pass
632-
else:
633-
image = self.image_processor.preprocess(image, height=height, width=width)
634-
image = image.to(device=device, dtype=self.vae.dtype)
635612

636613
if image.shape[1] != num_channels_latents:
637614
image = self.vae.encode(image).latent
@@ -840,8 +817,7 @@ def __call__(
840817
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
841818

842819
# 2. Preprocess image
843-
init_image = self.image_processor.preprocess(image, height=height, width=width)
844-
init_image = init_image.to(dtype=torch.float32)
820+
init_image = self.prepare_image(image, width, height, device, self.vae.dtype)
845821

846822
# 3. Encode input prompt
847823
(

0 commit comments

Comments (0)