Skip to content

Commit aac7aba

Browse files
committed
update
1 parent ff13dee commit aac7aba

File tree

1 file changed

+17
-37
lines changed

1 file changed

+17
-37
lines changed

src/diffusers/models/attention_processor.py

Lines changed: 17 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3579,9 +3579,16 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin):
35793579
key = torch.cat([key, encoder_key], dim=2)
35803580
value = torch.cat([value, encoder_value], dim=2)
35813581

3582+
# Zero out tokens based on the attention mask
3583+
query = query * attention_mask[:, None, :, None]
3584+
key = key * attention_mask[:, None, :, None]
3585+
value = value * attention_mask[:, None, :, None]
3586+
35823587
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
3588+
35833589
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
3584-
hidden_states = hidden_states.to(query.dtype)
3590+
# Zero out tokens based on attention mask
3591+
hidden_states = hidden_states * attention_mask[:, :, None]
35853592

35863593
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
35873594
(sequence_length, encoder_sequence_length), dim=1
@@ -5053,46 +5060,19 @@ def __init__(self):
50535060

50545061
AttentionProcessor = Union[
50555062
AttnProcessor,
5056-
CustomDiffusionAttnProcessor,
5063+
AttnProcessor2_0,
5064+
FusedAttnProcessor2_0,
5065+
XFormersAttnProcessor,
5066+
SlicedAttnProcessor,
50575067
AttnAddedKVProcessor,
5068+
SlicedAttnAddedKVProcessor,
50585069
AttnAddedKVProcessor2_0,
5059-
JointAttnProcessor2_0,
5060-
PAGJointAttnProcessor2_0,
5061-
PAGCFGJointAttnProcessor2_0,
5062-
FusedJointAttnProcessor2_0,
5063-
AllegroAttnProcessor2_0,
5064-
AuraFlowAttnProcessor2_0,
5065-
FusedAuraFlowAttnProcessor2_0,
5066-
FluxAttnProcessor2_0,
5067-
FluxAttnProcessor2_0_NPU,
5068-
FusedFluxAttnProcessor2_0,
5069-
FusedFluxAttnProcessor2_0_NPU,
5070-
CogVideoXAttnProcessor2_0,
5071-
FusedCogVideoXAttnProcessor2_0,
50725070
XFormersAttnAddedKVProcessor,
5073-
XFormersAttnProcessor,
5074-
AttnProcessorNPU,
5075-
AttnProcessor2_0,
5076-
MochiVaeAttnProcessor2_0,
5077-
StableAudioAttnProcessor2_0,
5078-
HunyuanAttnProcessor2_0,
5079-
FusedHunyuanAttnProcessor2_0,
5080-
PAGHunyuanAttnProcessor2_0,
5081-
PAGCFGHunyuanAttnProcessor2_0,
5082-
LuminaAttnProcessor2_0,
5083-
MochiAttnProcessor2_0,
5084-
FusedAttnProcessor2_0,
5071+
CustomDiffusionAttnProcessor,
50855072
CustomDiffusionXFormersAttnProcessor,
50865073
CustomDiffusionAttnProcessor2_0,
5087-
SlicedAttnProcessor,
5088-
SlicedAttnAddedKVProcessor,
5089-
IPAdapterAttnProcessor,
5090-
IPAdapterAttnProcessor2_0,
5091-
IPAdapterXFormersAttnProcessor,
5092-
PAGIdentitySelfAttnProcessor2_0,
50935074
PAGCFGIdentitySelfAttnProcessor2_0,
5094-
LoRAAttnProcessor,
5095-
LoRAAttnProcessor2_0,
5096-
LoRAXFormersAttnProcessor,
5097-
LoRAAttnAddedKVProcessor,
5075+
PAGIdentitySelfAttnProcessor2_0,
5076+
PAGCFGHunyuanAttnProcessor2_0,
5077+
PAGHunyuanAttnProcessor2_0,
50985078
]

0 commit comments

Comments
 (0)