@@ -513,7 +513,7 @@ def prepare_latents(
513513        shape  =  (batch_size , num_channels_latents , height , width )
514514
515515        if  latents  is  not None :
516-             latent_image_ids  =  self ._prepare_latent_image_ids (batch_size , height   //   2 , width   //   2 , device , dtype )
516+             latent_image_ids  =  self ._prepare_latent_image_ids (batch_size , height , width , device , dtype )
517517            return  latents .to (device = device , dtype = dtype ), latent_image_ids 
518518
519519        if  isinstance (generator , list ) and  len (generator ) !=  batch_size :
@@ -567,7 +567,6 @@ def __call__(
567567        callback_on_step_end : Optional [Callable [[int , int , Dict ], None ]] =  None ,
568568        callback_on_step_end_tensor_inputs : List [str ] =  ["latents" ],
569569        max_sequence_length : int  =  512 ,
570-         img_cond : Optional [torch .FloatTensor ] =  None ,
571570    ):
572571        r""" 
573572        Function invoked when calling the pipeline for generation. 
@@ -687,7 +686,7 @@ def __call__(
687686        )
688687
689688        # 4. Prepare latent variables 
690-         num_channels_latents  =  self .transformer .config .out_channels  //  4 
689+         num_channels_latents  =  self .transformer .config .in_channels  //  4 
691690        latents , latent_image_ids  =  self .prepare_latents (
692691            batch_size  *  num_images_per_prompt ,
693692            num_channels_latents ,
@@ -699,8 +698,6 @@ def __call__(
699698            latents ,
700699        )
701700
702-         img_cond  =  img_cond .to (latents .device ) if  img_cond  is  not None  else  None 
703- 
704701        # 5. Prepare timesteps 
705702        sigmas  =  np .linspace (1.0 , 1  /  num_inference_steps , num_inference_steps )
706703        image_seq_len  =  latents .shape [1 ]
@@ -739,7 +736,7 @@ def __call__(
739736                timestep  =  t .expand (latents .shape [0 ]).to (latents .dtype )
740737
741738                noise_pred  =  self .transformer (
742-                     hidden_states = torch . cat (( latents ,  img_cond ),  dim = 2 )  if   img_cond   is   not   None   else   latents ,
739+                     hidden_states = latents ,
743740                    timestep = timestep  /  1000 ,
744741                    guidance = guidance ,
745742                    pooled_projections = pooled_prompt_embeds ,
0 commit comments