@@ -575,33 +575,16 @@ def prepare_image(
         image,
         width,
         height,
-        batch_size,
-        num_images_per_prompt,
         device,
         dtype,
-        do_classifier_free_guidance=False,
-        guess_mode=False,
     ):
         if isinstance(image, torch.Tensor):
             pass
         else:
             image = self.image_processor.preprocess(image, height=height, width=width)

-        image_batch_size = image.shape[0]
-
-        if image_batch_size == 1:
-            repeat_by = batch_size
-        else:
-            # image batch size is the same as prompt batch size
-            repeat_by = num_images_per_prompt
-
-        image = image.repeat_interleave(repeat_by, dim=0)
-
         image = image.to(device=device, dtype=dtype)

-        if do_classifier_free_guidance and not guess_mode:
-            image = torch.cat([image] * 2)
-
         return image

     # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents
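After this hunk, prepare_image reduces to a plain preprocess-and-cast helper. For reference, this is how the method reads once the patch is applied (reconstructed from the kept lines above; the def line and self parameter sit just outside the hunk and are assumed):

    def prepare_image(
        self,
        image,
        width,
        height,
        device,
        dtype,
    ):
        # Tensors are assumed to be preprocessed already; anything else
        # (e.g. a PIL image) goes through the pipeline's image processor.
        if isinstance(image, torch.Tensor):
            pass
        else:
            image = self.image_processor.preprocess(image, height=height, width=width)

        # Move to the target device/dtype. The batch expansion and
        # classifier-free-guidance duplication deleted above are now the
        # caller's responsibility.
        image = image.to(device=device, dtype=dtype)

        return image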
@@ -626,12 +609,6 @@ def prepare_latents(self,
             int(width) // self.vae_scale_factor,
         )

-        image = image.to(device=device, dtype=dtype)
-        if isinstance(image, torch.Tensor):
-            pass
-        else:
-            image = self.image_processor.preprocess(image, height=height, width=width)
-            image = image.to(device=device, dtype=self.vae.dtype)

         if image.shape[1] != num_channels_latents:
             image = self.vae.encode(image).latent
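The surviving channel-count check is what makes the deleted preprocessing safe to drop: prepare_latents can distinguish a pixel-space tensor from a latent one by its channel dimension alone. A minimal sketch of the post-patch flow, assuming the caller passes a tensor already preprocessed and cast by prepare_image:

    # `image` arrives preprocessed and on the right device/dtype. If it
    # is still in pixel space (e.g. 3 RGB channels rather than
    # num_channels_latents), encode it with the VAE; a latent tensor
    # passes through unchanged.
    if image.shape[1] != num_channels_latents:
        image = self.vae.encode(image).latent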
@@ -840,8 +817,7 @@ def __call__(
         lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None

         # 2. Preprocess image
-        init_image = self.image_processor.preprocess(image, height=height, width=width)
-        init_image = init_image.to(dtype=torch.float32)
+        init_image = self.prepare_image(image, width, height, device, self.vae.dtype)

         # 3. Encode input prompt
         (
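Taken together, the call site trades two steps for one. A before/after sketch of the preprocessing step in __call__, grounded in the hunks above (everything else in the method is untouched by this patch):

    # Before: preprocess here, cast to float32, and rely on
    # prepare_latents to preprocess a second time and cast to the
    # VAE dtype.
    init_image = self.image_processor.preprocess(image, height=height, width=width)
    init_image = init_image.to(dtype=torch.float32)

    # After: a single helper preprocesses and casts straight to the VAE
    # dtype on the target device, so prepare_latents can drop its own
    # preprocessing block.
    init_image = self.prepare_image(image, width, height, device, self.vae.dtype)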