Skip to content

Commit 1808130

Browse files
committed
Batched classifier-free guidance (CFG) implementation for QwenImage Edit.
1 parent 4acbfbf commit 1808130

File tree

1 file changed

+59
-19
lines changed

1 file changed

+59
-19
lines changed

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,7 @@ def __call__(
546546
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
547547
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
548548
max_sequence_length: int = 512,
549+
batch_cfg: bool = False,
549550
):
550551
r"""
551552
Function invoked when calling the pipeline for generation.
@@ -712,6 +713,14 @@ def __call__(
712713
num_images_per_prompt=num_images_per_prompt,
713714
max_sequence_length=max_sequence_length,
714715
)
716+
if batch_cfg:
717+
target_len = max(negative_prompt_embeds.size(1), prompt_embeds.size(1))
718+
negative_prompt_embeds = self._pad_to_len(negative_prompt_embeds, target_len, pad_value=0.0)
719+
prompt_embeds = self._pad_to_len(prompt_embeds, target_len, pad_value=0.0)
720+
negative_prompt_embeds_mask = self._pad_to_len(negative_prompt_embeds_mask, target_len, pad_value=0)
721+
prompt_embeds_mask = self._pad_to_len(prompt_embeds_mask, target_len, pad_value=0)
722+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
723+
prompt_embeds_mask = torch.cat([negative_prompt_embeds_mask, prompt_embeds_mask], dim=0)
715724

716725
# 4. Prepare latent variables
717726
num_channels_latents = self.transformer.config.in_channels // 4
@@ -732,7 +741,9 @@ def __call__(
732741
(1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
733742
]
734743
] * batch_size
735-
744+
if batch_cfg:
745+
img_shapes = img_shapes * 2
746+
736747
# 5. Prepare timesteps
737748
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
738749
image_seq_len = latents.shape[1]
@@ -771,9 +782,10 @@ def __call__(
771782
self._attention_kwargs = {}
772783

773784
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist() if prompt_embeds_mask is not None else None
774-
negative_txt_seq_lens = (
775-
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
776-
)
785+
if not batch_cfg:
786+
negative_txt_seq_lens = (
787+
negative_prompt_embeds_mask.sum(dim=1).tolist() if negative_prompt_embeds_mask is not None else None
788+
)
777789

778790
# 6. Denoising loop
779791
self.scheduler.set_begin_index(0)
@@ -787,9 +799,14 @@ def __call__(
787799
latent_model_input = latents
788800
if image_latents is not None:
789801
latent_model_input = torch.cat([latents, image_latents], dim=1)
802+
if batch_cfg:
803+
latent_model_input = torch.cat([latent_model_input] * 2)
790804

791805
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
792-
timestep = t.expand(latents.shape[0]).to(latents.dtype)
806+
if batch_cfg:
807+
timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)
808+
else:
809+
timestep = t.expand(latents.shape[0]).to(latents.dtype)
793810
with self.transformer.cache_context("cond"):
794811
noise_pred = self.transformer(
795812
hidden_states=latent_model_input,
@@ -802,22 +819,25 @@ def __call__(
802819
attention_kwargs=self.attention_kwargs,
803820
return_dict=False,
804821
)[0]
805-
noise_pred = noise_pred[:, : latents.size(1)]
822+
noise_pred = noise_pred[:, : latents.size(1)]
806823

807824
if do_true_cfg:
808-
with self.transformer.cache_context("uncond"):
809-
neg_noise_pred = self.transformer(
810-
hidden_states=latent_model_input,
811-
timestep=timestep / 1000,
812-
guidance=guidance,
813-
encoder_hidden_states_mask=negative_prompt_embeds_mask,
814-
encoder_hidden_states=negative_prompt_embeds,
815-
img_shapes=img_shapes,
816-
txt_seq_lens=negative_txt_seq_lens,
817-
attention_kwargs=self.attention_kwargs,
818-
return_dict=False,
819-
)[0]
820-
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
825+
if not batch_cfg:
826+
with self.transformer.cache_context("uncond"):
827+
neg_noise_pred = self.transformer(
828+
hidden_states=latent_model_input,
829+
timestep=timestep / 1000,
830+
guidance=guidance,
831+
encoder_hidden_states_mask=negative_prompt_embeds_mask,
832+
encoder_hidden_states=negative_prompt_embeds,
833+
img_shapes=img_shapes,
834+
txt_seq_lens=negative_txt_seq_lens,
835+
attention_kwargs=self.attention_kwargs,
836+
return_dict=False,
837+
)[0]
838+
neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
839+
else:
840+
neg_noise_pred, noise_pred = noise_pred.chunk(2, dim=0)
821841
comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
822842

823843
cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
@@ -874,3 +894,23 @@ def __call__(
874894
return (image,)
875895

876896
return QwenImagePipelineOutput(images=image)
897+
898+
@staticmethod
def _pad_to_len(x, target_len, pad_value=0.0):
    """Right-pad ``x`` along dim 1 (the sequence axis) up to ``target_len``.

    Used to bring the positive and negative prompt embeddings (and their
    masks) to a common sequence length so they can be concatenated along
    the batch dimension.

    Args:
        x: Tensor of rank 3 (indexed as ``[B, S, D]``, embeddings) or
            rank 2 (indexed as ``[B, S]``, masks).
        target_len: Required size of dim 1 after padding. Must be greater
            than or equal to the current size of dim 1.
        pad_value: Value written into the newly added positions. It is
            honoured for both ranks and is cast to the dtype of ``x``.

    Returns:
        ``x`` itself (no copy) when dim 1 already equals ``target_len``;
        otherwise a new tensor created via ``x.new_full`` whose leading
        ``S`` positions along dim 1 hold ``x`` and whose remaining
        positions hold ``pad_value``.

    Raises:
        ValueError: If ``x`` is not rank 2 or 3, or if its sequence length
            exceeds ``target_len`` (padding can never shorten a tensor).
    """
    if x.dim() not in (2, 3):
        raise ValueError("Unexpected tensor rank")

    seq_len = x.shape[1]
    if seq_len == target_len:
        return x
    if seq_len > target_len:
        # Fail early with a readable message instead of the shape-mismatch
        # error the slice assignment below would otherwise produce.
        raise ValueError(
            f"Cannot pad a sequence of length {seq_len} to the shorter target length {target_len}"
        )

    # Same shape as `x` except for dim 1; works for both [B, S] and [B, S, D].
    padded_shape = (x.shape[0], target_len, *x.shape[2:])
    out = x.new_full(padded_shape, pad_value)
    out[:, :seq_len] = x
    return out

0 commit comments

Comments
 (0)