
Commit 2f86879

split attention
1 parent 4700b7f commit 2f86879

File tree: 3 files changed, +119 / -37 lines changed

src/diffusers/models/attention_dispatch.py

Lines changed: 41 additions & 1 deletion
@@ -186,6 +186,7 @@ class AttentionBackendName(str, Enum):
     _NATIVE_MATH = "_native_math"
     _NATIVE_NPU = "_native_npu"
     _NATIVE_XLA = "_native_xla"
+    SPLIT = "split"

     # `sageattention`
     SAGE = "sage"
@@ -503,7 +504,7 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
     cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
     cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
     cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
-    max_seqlen_q = seqlens_q.max().item()
+    max_seqlen_q = seqlens_q.max().item()  # TODO: .item() is inefficient and breaks torch.compile graphs. Use the `seq_len` parameter instead (see the split attention backend).
     max_seqlen_k = seqlens_k.max().item()
     return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
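The TODO above is about graph breaks. A minimal sketch (not part of the commit; the function names are illustrative) of why reading a value with `.item()` is costly under `torch.compile`, and what passing plain Python lengths, as the split backend's `seq_len` parameter does, avoids:

import torch

def max_len_via_item(seqlens: torch.Tensor) -> int:
    # .item() pulls the value to the host: a device sync at runtime and,
    # under torch.compile, typically a graph break at this point.
    return int(seqlens.max().item())

def max_len_via_list(seq_len: list) -> int:
    # Plain Python ints never enter the compiled graph, so there is
    # nothing to sync on and nothing to break.
    return max(seq_len)

print(max_len_via_item(torch.tensor([7, 12, 5])))  # 12
print(max_len_via_list([7, 12, 5]))                # 12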

@@ -1975,6 +1976,45 @@ def _native_attention(

     return out

+@_AttentionBackendRegistry.register(
+    AttentionBackendName.SPLIT,
+    constraints=[_check_device, _check_shape],
+    supports_context_parallel=True,
+)
+def _split_attention(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    seq_len: Optional[torch.Tensor] = None,  # attn_mask is ignored if seq_len is passed
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    if seq_len is None:
+        return _native_attention(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse, _parallel_config)
+
+    batch_size, batch_seq_len = query.shape[:2]
+    if any(sample_seq_len > batch_seq_len for sample_seq_len in seq_len):
+        raise ValueError("Attention sequence lengths cannot be longer than the maximum sequence length")
+    if len(seq_len) != batch_size:
+        raise ValueError("Attention sequence lengths must match the batch size")
+
+    result = []
+    for index, sample_seq_len in enumerate(seq_len):
+        sliced_query = query[index, :sample_seq_len, :, :].unsqueeze(0)
+        sliced_key = key[index, :sample_seq_len, :, :].unsqueeze(0)
+        sliced_value = value[index, :sample_seq_len, :, :].unsqueeze(0)
+        sliced_result = _native_attention(sliced_query, sliced_key, sliced_value, None, dropout_p, is_causal, scale, enable_gqa, return_lse, _parallel_config)
+
+        padding = torch.zeros((1, batch_seq_len - sample_seq_len) + sliced_result.shape[2:], device=sliced_result.device, dtype=sliced_result.dtype)
+        padded_result = torch.cat([sliced_result, padding], dim=1)
+        result.append(padded_result)
+    return torch.cat(result, dim=0)
+

 @_AttentionBackendRegistry.register(
     AttentionBackendName._NATIVE_CUDNN,
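To make the new backend concrete, here is a self-contained sketch (not part of the commit; `split_sdpa` and the toy shapes are assumptions) of the slice, attend, and re-pad pattern that `_split_attention` applies per sample:

import torch
import torch.nn.functional as F

def split_sdpa(query, key, value, seq_len):
    # query/key/value: (batch, seq, heads, head_dim); seq_len: valid tokens per sample
    batch_size, batch_seq_len = query.shape[:2]
    outputs = []
    for i, n in enumerate(seq_len):
        q = query[i, :n].unsqueeze(0).transpose(1, 2)  # (1, heads, n, head_dim)
        k = key[i, :n].unsqueeze(0).transpose(1, 2)
        v = value[i, :n].unsqueeze(0).transpose(1, 2)
        o = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)  # back to (1, n, heads, head_dim)
        pad = torch.zeros((1, batch_seq_len - n) + o.shape[2:], dtype=o.dtype, device=o.device)
        outputs.append(torch.cat([o, pad], dim=1))  # re-pad to the batch length
    return torch.cat(outputs, dim=0)

q = torch.randn(2, 8, 4, 16)
k = torch.randn(2, 8, 4, 16)
v = torch.randn(2, 8, 4, 16)
print(split_sdpa(q, k, v, seq_len=[8, 5]).shape)  # torch.Size([2, 8, 4, 16])

Because each sample attends only over its own valid prefix, padded positions never influence the result and no attention mask has to be materialised; the cost is one SDPA call per sample instead of one per batch.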

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 20 additions & 10 deletions
@@ -150,7 +150,7 @@ def compute_text_seq_len_from_mask(
     """
     batch_size, text_seq_len = encoder_hidden_states.shape[:2]
     if encoder_hidden_states_mask is None:
-        return text_seq_len, None, None
+        return text_seq_len, [text_seq_len] * batch_size, None

     if encoder_hidden_states_mask.shape[:2] != (batch_size, text_seq_len):
         raise ValueError(
@@ -165,7 +165,7 @@ def compute_text_seq_len_from_mask(
     active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
     has_active = encoder_hidden_states_mask.any(dim=1)
     per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
-    return text_seq_len, per_sample_len, encoder_hidden_states_mask
+    return text_seq_len, per_sample_len.tolist(), encoder_hidden_states_mask


 class QwenTimestepProjEmbeddings(nn.Module):
@@ -492,6 +492,7 @@ def __call__(
         encoder_hidden_states_mask: torch.FloatTensor = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        encoder_hidden_states_len: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         if encoder_hidden_states is None:
             raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
@@ -537,16 +538,17 @@ def __call__(

         # Concatenate for joint attention
         # Order: [text, image]
-        joint_query = torch.cat([txt_query, img_query], dim=1)
-        joint_key = torch.cat([txt_key, img_key], dim=1)
-        joint_value = torch.cat([txt_value, img_value], dim=1)
+        joint_query = torch.cat([img_query, txt_query], dim=1)
+        joint_key = torch.cat([img_key, txt_key], dim=1)
+        joint_value = torch.cat([img_value, txt_value], dim=1)

         # If an encoder_hidden_states_mask is provided, create a joint attention mask.
         # The encoder_hidden_states_mask is expected to have 1.0 for valid tokens and 0.0 for padding.
         # We convert it to a boolean mask where True means "attend" and False means "mask out" (don't attend).
         # Only create the mask if there's actual padding, otherwise keep attention_mask=None for better SDPA performance.
+        batch_size, image_seq_len = hidden_states.shape[:2]
+        attention_kwargs = {}
         if encoder_hidden_states_mask is not None and attention_mask is None:
-            batch_size, image_seq_len = hidden_states.shape[:2]
             text_seq_len = encoder_hidden_states.shape[1]
@@ -568,7 +570,8 @@ def __call__(
             )
             # Create 2D joint mask [batch_size, text_seq_len + image_seq_len]
             # The attention dispatch will normalize this and extract sequence lengths
-            attention_mask = torch.cat([text_attention_mask, image_attention_mask], dim=1)
+            attention_mask = torch.cat([image_attention_mask, text_attention_mask], dim=1)
+            attention_kwargs['seq_len'] = [text_sample_len + image_seq_len for text_sample_len in encoder_hidden_states_len]

         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
@@ -580,15 +583,16 @@ def __call__(
             is_causal=False,
             backend=self._attention_backend,
             parallel_config=self._parallel_config,
+            attention_kwargs=attention_kwargs,
         )

         # Reshape back
         joint_hidden_states = joint_hidden_states.flatten(2, 3)
         joint_hidden_states = joint_hidden_states.to(joint_query.dtype)

         # Split attention outputs back
-        txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
-        img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part
+        img_attn_output = joint_hidden_states[:, :image_seq_len, :]  # Image part
+        txt_attn_output = joint_hidden_states[:, image_seq_len:, :]  # Text part

         # Apply output projections
         img_attn_output = attn.to_out[0](img_attn_output)
@@ -694,6 +698,7 @@ def forward(
         encoder_hidden_states_mask: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        encoder_hidden_states_len: Optional[torch.Tensor] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         modulate_index: Optional[List[int]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -728,6 +733,7 @@ def forward(
             encoder_hidden_states=txt_modulated,  # Text stream (will be processed as "context")
             encoder_hidden_states_mask=encoder_hidden_states_mask,
             image_rotary_emb=image_rotary_emb,
+            encoder_hidden_states_len=encoder_hidden_states_len,
             **joint_attention_kwargs,
         )

@@ -947,7 +953,9 @@ def forward(
         encoder_hidden_states = self.txt_in(encoder_hidden_states)

         # Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
-        text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
+        if torch.all(encoder_hidden_states_mask):
+            encoder_hidden_states_mask = None
+        text_seq_len, text_seq_len_per_sample, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
             encoder_hidden_states, encoder_hidden_states_mask
         )

@@ -971,6 +979,7 @@ def forward(
                     encoder_hidden_states_mask,
                     temb,
                     image_rotary_emb,
+                    text_seq_len_per_sample,
                     attention_kwargs,
                     modulate_index,
                 )
@@ -982,6 +991,7 @@ def forward(
                     encoder_hidden_states_mask=encoder_hidden_states_mask,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    encoder_hidden_states_len=text_seq_len_per_sample,
                     joint_attention_kwargs=attention_kwargs,
                     modulate_index=modulate_index,
                 )
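As a worked example of how `attention_kwargs['seq_len']` is assembled for the joint [image, text] sequence (the numbers below are hypothetical):

image_seq_len = 4096                  # packed latent tokens, identical for every sample
encoder_hidden_states_len = [77, 23]  # valid text tokens per sample (the rest is padding)

seq_len = [text_len + image_seq_len for text_len in encoder_hidden_states_len]
print(seq_len)  # [4173, 4119] -> each sample attends only over this leading prefix

Because image tokens come first and text padding sits at the end, the valid tokens of every sample form a contiguous prefix, which is exactly what the split backend slices on.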

src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py

Lines changed: 58 additions & 26 deletions
@@ -473,6 +473,7 @@ def __call__(
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
+        batch_negative: bool = False,  # TODO: remove, only for testing
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -603,23 +604,35 @@ def __call__(
         )

         do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
-        prompt_embeds, prompt_embeds_mask = self.encode_prompt(
-            prompt=prompt,
-            prompt_embeds=prompt_embeds,
-            prompt_embeds_mask=prompt_embeds_mask,
-            device=device,
-            num_images_per_prompt=num_images_per_prompt,
-            max_sequence_length=max_sequence_length,
-        )
-        if do_true_cfg:
-            negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
-                prompt=negative_prompt,
-                prompt_embeds=negative_prompt_embeds,
-                prompt_embeds_mask=negative_prompt_embeds_mask,
+        if do_true_cfg and batch_negative:
+            combined_prompt_embeds, combined_prompt_embeds_mask = self.encode_prompt(
+                prompt=[prompt, negative_prompt],
+                # prompt_embeds=prompt_embeds,
+                # prompt_embeds_mask=prompt_embeds_mask,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+            )
+            dtype = combined_prompt_embeds.dtype
+        else:
+            prompt_embeds, prompt_embeds_mask = self.encode_prompt(
+                prompt=prompt,
+                prompt_embeds=prompt_embeds,
+                prompt_embeds_mask=prompt_embeds_mask,
                 device=device,
                 num_images_per_prompt=num_images_per_prompt,
                 max_sequence_length=max_sequence_length,
             )
+            dtype = prompt_embeds.dtype
+            if do_true_cfg:
+                negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
+                    prompt=negative_prompt,
+                    prompt_embeds=negative_prompt_embeds,
+                    prompt_embeds_mask=negative_prompt_embeds_mask,
+                    device=device,
+                    num_images_per_prompt=num_images_per_prompt,
+                    max_sequence_length=max_sequence_length,
+                )

         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
@@ -628,7 +641,7 @@ def __call__(
             num_channels_latents,
             height,
             width,
-            prompt_embeds.dtype,
+            dtype,
             device,
             generator,
             latents,
@@ -682,31 +695,50 @@ def __call__(
                 self._current_timestep = t
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
-                with self.transformer.cache_context("cond"):
+                if do_true_cfg and batch_negative:
                     noise_pred = self.transformer(
-                        hidden_states=latents,
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        encoder_hidden_states_mask=prompt_embeds_mask,
-                        encoder_hidden_states=prompt_embeds,
+                        hidden_states=torch.cat([latents] * 2, dim=0),
+                        timestep=torch.cat([timestep] * 2, dim=0) / 1000,
+                        guidance=torch.cat([guidance] * 2, dim=0) if guidance is not None else None,
+                        encoder_hidden_states_mask=combined_prompt_embeds_mask,
+                        encoder_hidden_states=combined_prompt_embeds,
                         img_shapes=img_shapes,
                         attention_kwargs=self.attention_kwargs,
                         return_dict=False,
                     )[0]
+                    noise_pred, neg_noise_pred = torch.chunk(noise_pred, 2, dim=0)
+
+                    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

-                if do_true_cfg:
-                    with self.transformer.cache_context("uncond"):
-                        neg_noise_pred = self.transformer(
+                    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
+                    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
+                    noise_pred = comb_pred * (cond_norm / noise_norm)
+                else:
+                    with self.transformer.cache_context("cond"):
+                        noise_pred = self.transformer(
                             hidden_states=latents,
                             timestep=timestep / 1000,
                             guidance=guidance,
-                            encoder_hidden_states_mask=negative_prompt_embeds_mask,
-                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_hidden_states_mask=prompt_embeds_mask,
+                            encoder_hidden_states=prompt_embeds,
                             img_shapes=img_shapes,
                             attention_kwargs=self.attention_kwargs,
                             return_dict=False,
                         )[0]
-                    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
+                    if do_true_cfg:
+                        with self.transformer.cache_context("uncond"):
+                            neg_noise_pred = self.transformer(
+                                hidden_states=latents,
+                                timestep=timestep / 1000,
+                                guidance=guidance,
+                                encoder_hidden_states_mask=negative_prompt_embeds_mask,
+                                encoder_hidden_states=negative_prompt_embeds,
+                                img_shapes=img_shapes,
+                                attention_kwargs=self.attention_kwargs,
+                                return_dict=False,
+                            )[0]
+                        comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                     cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
                     noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
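A compact sketch (not part of the diff; shapes and names are made up) of what the `batch_negative` path computes after its single batched forward pass: chunk the stacked prediction into conditional and negative halves, apply true CFG, then rescale to the conditional norm.

import torch

def combine_batched_cfg(noise_pred_batched: torch.Tensor, true_cfg_scale: float) -> torch.Tensor:
    # noise_pred_batched stacks [conditional, negative] along the batch dimension
    noise_pred, neg_noise_pred = torch.chunk(noise_pred_batched, 2, dim=0)
    comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
    # Rescale so the guided prediction keeps the norm of the conditional prediction
    cond_norm = torch.norm(noise_pred, dim=-1, keepdim=True)
    noise_norm = torch.norm(comb_pred, dim=-1, keepdim=True)
    return comb_pred * (cond_norm / noise_norm)

pred = torch.randn(2, 4096, 64)  # [cond, neg] stacked on dim 0
print(combine_batched_cfg(pred, true_cfg_scale=4.0).shape)  # torch.Size([1, 4096, 64])

This trades the two sequential transformer calls of the unbatched path for one call at twice the batch size, which is what the temporary `batch_negative` flag is meant to exercise.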
