@@ -546,6 +546,7 @@ def __call__(
546546 callback_on_step_end : Optional [Callable [[int , int , Dict ], None ]] = None ,
547547 callback_on_step_end_tensor_inputs : List [str ] = ["latents" ],
548548 max_sequence_length : int = 512 ,
549+ batch_cfg : bool = False ,
549550 ):
550551 r"""
551552 Function invoked when calling the pipeline for generation.
@@ -712,6 +713,14 @@ def __call__(
712713 num_images_per_prompt = num_images_per_prompt ,
713714 max_sequence_length = max_sequence_length ,
714715 )
716+ if batch_cfg :
717+ target_len = max (negative_prompt_embeds .size (1 ), prompt_embeds .size (1 ))
718+ negative_prompt_embeds = self ._pad_to_len (negative_prompt_embeds , target_len , pad_value = 0.0 )
719+ prompt_embeds = self ._pad_to_len (prompt_embeds , target_len , pad_value = 0.0 )
720+ negative_prompt_embeds_mask = self ._pad_to_len (negative_prompt_embeds_mask , target_len , pad_value = 0 )
721+ prompt_embeds_mask = self ._pad_to_len (prompt_embeds_mask , target_len , pad_value = 0 )
722+ prompt_embeds = torch .cat ([negative_prompt_embeds , prompt_embeds ], dim = 0 )
723+ prompt_embeds_mask = torch .cat ([negative_prompt_embeds_mask , prompt_embeds_mask ], dim = 0 )
715724
716725 # 4. Prepare latent variables
717726 num_channels_latents = self .transformer .config .in_channels // 4
@@ -732,7 +741,9 @@ def __call__(
732741 (1 , calculated_height // self .vae_scale_factor // 2 , calculated_width // self .vae_scale_factor // 2 ),
733742 ]
734743 ] * batch_size
735-
744+ if batch_cfg :
745+ img_shapes = img_shapes * 2
746+
736747 # 5. Prepare timesteps
737748 sigmas = np .linspace (1.0 , 1 / num_inference_steps , num_inference_steps ) if sigmas is None else sigmas
738749 image_seq_len = latents .shape [1 ]
@@ -771,9 +782,10 @@ def __call__(
771782 self ._attention_kwargs = {}
772783
773784 txt_seq_lens = prompt_embeds_mask .sum (dim = 1 ).tolist () if prompt_embeds_mask is not None else None
774- negative_txt_seq_lens = (
775- negative_prompt_embeds_mask .sum (dim = 1 ).tolist () if negative_prompt_embeds_mask is not None else None
776- )
785+ if not batch_cfg :
786+ negative_txt_seq_lens = (
787+ negative_prompt_embeds_mask .sum (dim = 1 ).tolist () if negative_prompt_embeds_mask is not None else None
788+ )
777789
778790 # 6. Denoising loop
779791 self .scheduler .set_begin_index (0 )
@@ -787,9 +799,14 @@ def __call__(
787799 latent_model_input = latents
788800 if image_latents is not None :
789801 latent_model_input = torch .cat ([latents , image_latents ], dim = 1 )
802+ if batch_cfg :
803+ latent_model_input = torch .cat ([latent_model_input ] * 2 )
790804
791805 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
792- timestep = t .expand (latents .shape [0 ]).to (latents .dtype )
806+ if batch_cfg :
807+ timestep = t .expand (latent_model_input .shape [0 ]).to (latent_model_input .dtype )
808+ else :
809+ timestep = t .expand (latents .shape [0 ]).to (latents .dtype )
793810 with self .transformer .cache_context ("cond" ):
794811 noise_pred = self .transformer (
795812 hidden_states = latent_model_input ,
@@ -802,22 +819,25 @@ def __call__(
802819 attention_kwargs = self .attention_kwargs ,
803820 return_dict = False ,
804821 )[0 ]
805- noise_pred = noise_pred [:, : latents .size (1 )]
822+ noise_pred = noise_pred [:, : latents .size (1 )]
806823
807824 if do_true_cfg :
808- with self .transformer .cache_context ("uncond" ):
809- neg_noise_pred = self .transformer (
810- hidden_states = latent_model_input ,
811- timestep = timestep / 1000 ,
812- guidance = guidance ,
813- encoder_hidden_states_mask = negative_prompt_embeds_mask ,
814- encoder_hidden_states = negative_prompt_embeds ,
815- img_shapes = img_shapes ,
816- txt_seq_lens = negative_txt_seq_lens ,
817- attention_kwargs = self .attention_kwargs ,
818- return_dict = False ,
819- )[0 ]
820- neg_noise_pred = neg_noise_pred [:, : latents .size (1 )]
825+ if not batch_cfg :
826+ with self .transformer .cache_context ("uncond" ):
827+ neg_noise_pred = self .transformer (
828+ hidden_states = latent_model_input ,
829+ timestep = timestep / 1000 ,
830+ guidance = guidance ,
831+ encoder_hidden_states_mask = negative_prompt_embeds_mask ,
832+ encoder_hidden_states = negative_prompt_embeds ,
833+ img_shapes = img_shapes ,
834+ txt_seq_lens = negative_txt_seq_lens ,
835+ attention_kwargs = self .attention_kwargs ,
836+ return_dict = False ,
837+ )[0 ]
838+ neg_noise_pred = neg_noise_pred [:, : latents .size (1 )]
839+ else :
840+ neg_noise_pred , noise_pred = noise_pred .chunk (2 , dim = 0 )
821841 comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred )
822842
823843 cond_norm = torch .norm (noise_pred , dim = - 1 , keepdim = True )
@@ -874,3 +894,23 @@ def __call__(
874894 return (image ,)
875895
876896 return QwenImagePipelineOutput (images = image )
897+
898+ @staticmethod
899+ def _pad_to_len (x , target_len , pad_value = 0.0 ):
900+ # x: [B, S, D] or [B, S]
901+ if x .dim () == 3 : # embeds
902+ B , S , D = x .shape
903+ if S == target_len :
904+ return x
905+ out = x .new_full ((B , target_len , D ), pad_value )
906+ out [:, :S , :] = x
907+ return out
908+ elif x .dim () == 2 : # mask
909+ B , S = x .shape
910+ if S == target_len :
911+ return x
912+ out = x .new_zeros ((B , target_len ), dtype = x .dtype )
913+ out [:, :S ] = x
914+ return out
915+ else :
916+ raise ValueError ("Unexpected tensor rank" )
0 commit comments