@@ -421,26 +421,27 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
421421
422422        return  latents 
423423
424+     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_img2img.QwenImageImg2ImgPipeline._encode_vae_image 
424425    def  _encode_vae_image (self , image : torch .Tensor , generator : torch .Generator ):
425426        if  isinstance (generator , list ):
426427            image_latents  =  [
427-                 retrieve_latents (self .vae .encode (image [i  : i  +  1 ]), generator = generator [i ],  sample_mode = "argmax" )
428+                 retrieve_latents (self .vae .encode (image [i  : i  +  1 ]), generator = generator [i ])
428429                for  i  in  range (image .shape [0 ])
429430            ]
430431            image_latents  =  torch .cat (image_latents , dim = 0 )
431432        else :
432-             image_latents  =  retrieve_latents (self .vae .encode (image ), generator = generator , sample_mode = "argmax" )
433+             image_latents  =  retrieve_latents (self .vae .encode (image ), generator = generator )
434+ 
433435        latents_mean  =  (
434436            torch .tensor (self .vae .config .latents_mean )
435-             .view (1 , self .latent_channels , 1 , 1 , 1 )
437+             .view (1 , self .vae . config . z_dim , 1 , 1 , 1 )
436438            .to (image_latents .device , image_latents .dtype )
437439        )
438-         latents_std  =  (
439-             torch .tensor (self .vae .config .latents_std )
440-             .view (1 , self .latent_channels , 1 , 1 , 1 )
441-             .to (image_latents .device , image_latents .dtype )
440+         latents_std  =  1.0  /  torch .tensor (self .vae .config .latents_std ).view (1 , self .vae .config .z_dim , 1 , 1 , 1 ).to (
441+             image_latents .device , image_latents .dtype 
442442        )
443-         image_latents  =  (image_latents  -  latents_mean ) /  latents_std 
443+ 
444+         image_latents  =  (image_latents  -  latents_mean ) *  latents_std 
444445
445446        return  image_latents 
446447
@@ -485,6 +486,7 @@ def disable_vae_tiling(self):
485486        """ 
486487        self .vae .disable_tiling ()
487488
489+     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_inpaint.QwenImageInpaintPipeline.prepare_latents 
488490    def  prepare_latents (
489491        self ,
490492        image ,
@@ -510,25 +512,32 @@ def prepare_latents(
510512
511513        shape  =  (batch_size , 1 , num_channels_latents , height , width )
512514
513-         image_latents  =  None 
514-         if  image  is  not None :
515-             image  =  image .to (device = device , dtype = dtype )
516-             if  image .shape [1 ] !=  self .latent_channels :
517-                 image_latents  =  self ._encode_vae_image (image = image , generator = generator )
518-             else :
519-                 image_latents  =  image 
520-             if  batch_size  >  image_latents .shape [0 ] and  batch_size  %  image_latents .shape [0 ] ==  0 :
521-                 # expand init_latents for batch_size 
522-                 additional_image_per_prompt  =  batch_size  //  image_latents .shape [0 ]
523-                 image_latents  =  torch .cat ([image_latents ] *  additional_image_per_prompt , dim = 0 )
524-             elif  batch_size  >  image_latents .shape [0 ] and  batch_size  %  image_latents .shape [0 ] !=  0 :
525-                 raise  ValueError (
526-                     f"Cannot duplicate `image` of batch size { image_latents .shape [0 ]} { batch_size }  
527-                 )
528-             else :
529-                 image_latents  =  torch .cat ([image_latents ], dim = 0 )
515+         # If image is [B,C,H,W] -> add T=1. If it's already [B,C,T,H,W], leave it. 
516+         if  image .dim () ==  4 :
517+             image  =  image .unsqueeze (2 )
518+         elif  image .dim () !=  5 :
519+             raise  ValueError (f"Expected image dims 4 or 5, got { image .dim ()}  )
520+ 
521+         if  latents  is  not None :
522+             return  latents .to (device = device , dtype = dtype )
523+ 
524+         image  =  image .to (device = device , dtype = dtype )
525+         if  image .shape [1 ] !=  self .latent_channels :
526+             image_latents  =  self ._encode_vae_image (image = image , generator = generator )  # [B,z,1,H',W'] 
527+         else :
528+             image_latents  =  image 
529+         if  batch_size  >  image_latents .shape [0 ] and  batch_size  %  image_latents .shape [0 ] ==  0 :
530+             # expand init_latents for batch_size 
531+             additional_image_per_prompt  =  batch_size  //  image_latents .shape [0 ]
532+             image_latents  =  torch .cat ([image_latents ] *  additional_image_per_prompt , dim = 0 )
533+         elif  batch_size  >  image_latents .shape [0 ] and  batch_size  %  image_latents .shape [0 ] !=  0 :
534+             raise  ValueError (
535+                 f"Cannot duplicate `image` of batch size { image_latents .shape [0 ]} { batch_size }  
536+             )
537+         else :
538+             image_latents  =  torch .cat ([image_latents ], dim = 0 )
530539
531-              image_latents  =  image_latents .transpose (1 , 2 )  # [B,1,z,H',W'] 
540+         image_latents  =  image_latents .transpose (1 , 2 )  # [B,1,z,H',W'] 
532541
533542        if  latents  is  None :
534543            noise  =  randn_tensor (shape , generator = generator , device = device , dtype = dtype )
@@ -655,7 +664,7 @@ def __call__(
655664        strength : float  =  0.6 ,
656665        num_inference_steps : int  =  50 ,
657666        sigmas : Optional [List [float ]] =  None ,
658-         guidance_scale : float  =  1.0 ,
667+         guidance_scale : Optional [ float ]  =  None ,
659668        num_images_per_prompt : int  =  1 ,
660669        generator : Optional [Union [torch .Generator , List [torch .Generator ]]] =  None ,
661670        latents : Optional [torch .Tensor ] =  None ,
@@ -674,6 +683,12 @@ def __call__(
674683        Function invoked when calling the pipeline for generation. 
675684
676685        Args: 
686+             image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): 
687+                 `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both 
688+                 numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list 
689+                 or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a 
690+                 list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image 
691+                 latents as `image`, but if passing latents directly it is not encoded again. 
677692            prompt (`str` or `List[str]`, *optional*): 
678693                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 
679694                instead. 
@@ -682,7 +697,12 @@ def __call__(
682697                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is 
683698                not greater than `1`). 
684699            true_cfg_scale (`float`, *optional*, defaults to 1.0): 
685-                 When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. 
700+                 true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free 
701+                 Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of 
702+                 equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is 
703+                 enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale 
704+                 encourages to generate images that are closely linked to the text `prompt`, usually at the expense of 
705+                 lower image quality. 
686706            mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): 
687707                `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask 
688708                are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a 
@@ -717,17 +737,16 @@ def __call__(
717737                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in 
718738                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed 
719739                will be used. 
720-             guidance_scale (`float`, *optional*, defaults to 3.5): 
721-                 Guidance scale as defined in [Classifier-Free Diffusion 
722-                 Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. 
723-                 of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting 
724-                 `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to 
725-                 the text `prompt`, usually at the expense of lower image quality. 
726- 
727-                 This parameter in the pipeline is there to support future guidance-distilled models when they come up. 
728-                 Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance, 
729-                 please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should 
730-                 enable classifier-free guidance computations. 
740+             guidance_scale (`float`, *optional*, defaults to None): 
741+                 A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance 
742+                 where the guidance scale is applied during inference through noise prediction rescaling, guidance 
743+                 distilled models take the guidance scale directly as an input parameter during forward pass. Guidance 
744+                 scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images 
745+                 that are closely linked to the text `prompt`, usually at the expense of lower image quality. This 
746+                 parameter in the pipeline is there to support future guidance-distilled models when they come up. It is 
747+                 ignored when not using guidance distilled models. To enable traditional classifier-free guidance, 
748+                 please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should 
749+                 enable classifier-free guidance computations). 
731750            num_images_per_prompt (`int`, *optional*, defaults to 1): 
732751                The number of images to generate per prompt. 
733752            generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 
@@ -831,11 +850,20 @@ def __call__(
831850                image , height = calculated_height , width = calculated_width , crops_coords = crops_coords , resize_mode = resize_mode 
832851            )
833852            image  =  image .to (dtype = torch .float32 )
834-             image  =  image .unsqueeze (2 )
835853
836854        has_neg_prompt  =  negative_prompt  is  not None  or  (
837855            negative_prompt_embeds  is  not None  and  negative_prompt_embeds_mask  is  not None 
838856        )
857+ 
858+         if  true_cfg_scale  >  1  and  not  has_neg_prompt :
859+             logger .warning (
860+                 f"true_cfg_scale is passed as { true_cfg_scale }  
861+             )
862+         elif  true_cfg_scale  <=  1  and  has_neg_prompt :
863+             logger .warning (
864+                 " negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1" 
865+             )
866+ 
839867        do_true_cfg  =  true_cfg_scale  >  1  and  has_neg_prompt 
840868        prompt_embeds , prompt_embeds_mask  =  self .encode_prompt (
841869            image = prompt_image ,
@@ -932,10 +960,17 @@ def __call__(
932960        self ._num_timesteps  =  len (timesteps )
933961
934962        # handle guidance 
935-         if  self .transformer .config .guidance_embeds :
963+         if  self .transformer .config .guidance_embeds  and  guidance_scale  is  None :
964+             raise  ValueError ("guidance_scale is required for guidance-distilled model." )
965+         elif  self .transformer .config .guidance_embeds :
936966            guidance  =  torch .full ([1 ], guidance_scale , device = device , dtype = torch .float32 )
937967            guidance  =  guidance .expand (latents .shape [0 ])
938-         else :
968+         elif  not  self .transformer .config .guidance_embeds  and  guidance_scale  is  not None :
969+             logger .warning (
970+                 f"guidance_scale is passed as { guidance_scale }  
971+             )
972+             guidance  =  None 
973+         elif  not  self .transformer .config .guidance_embeds  and  guidance_scale  is  None :
939974            guidance  =  None 
940975
941976        if  self .attention_kwargs  is  None :
0 commit comments