@@ -3579,9 +3579,16 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin):
35793579 key = torch .cat ([key , encoder_key ], dim = 2 )
35803580 value = torch .cat ([value , encoder_value ], dim = 2 )
35813581
3582+ # Zero out tokens based on the attention mask
3583+ query = query * attention_mask [:, None , :, None ]
3584+ key = key * attention_mask [:, None , :, None ]
3585+ value = value * attention_mask [:, None , :, None ]
3586+
35823587 hidden_states = F .scaled_dot_product_attention (query , key , value , dropout_p = 0.0 , is_causal = False )
3588+
35833589 hidden_states = hidden_states .transpose (1 , 2 ).flatten (2 , 3 )
3584- hidden_states = hidden_states .to (query .dtype )
3590+ # Zero out tokens based on attention mask
3591+ hidden_states = hidden_states * attention_mask [:, :, None ]
35853592
35863593 hidden_states , encoder_hidden_states = hidden_states .split_with_sizes (
35873594 (sequence_length , encoder_sequence_length ), dim = 1
@@ -5053,46 +5060,19 @@ def __init__(self):
50535060
50545061AttentionProcessor = Union [
50555062 AttnProcessor ,
5056- CustomDiffusionAttnProcessor ,
5063+ AttnProcessor2_0 ,
5064+ FusedAttnProcessor2_0 ,
5065+ XFormersAttnProcessor ,
5066+ SlicedAttnProcessor ,
50575067 AttnAddedKVProcessor ,
5068+ SlicedAttnAddedKVProcessor ,
50585069 AttnAddedKVProcessor2_0 ,
5059- JointAttnProcessor2_0 ,
5060- PAGJointAttnProcessor2_0 ,
5061- PAGCFGJointAttnProcessor2_0 ,
5062- FusedJointAttnProcessor2_0 ,
5063- AllegroAttnProcessor2_0 ,
5064- AuraFlowAttnProcessor2_0 ,
5065- FusedAuraFlowAttnProcessor2_0 ,
5066- FluxAttnProcessor2_0 ,
5067- FluxAttnProcessor2_0_NPU ,
5068- FusedFluxAttnProcessor2_0 ,
5069- FusedFluxAttnProcessor2_0_NPU ,
5070- CogVideoXAttnProcessor2_0 ,
5071- FusedCogVideoXAttnProcessor2_0 ,
50725070 XFormersAttnAddedKVProcessor ,
5073- XFormersAttnProcessor ,
5074- AttnProcessorNPU ,
5075- AttnProcessor2_0 ,
5076- MochiVaeAttnProcessor2_0 ,
5077- StableAudioAttnProcessor2_0 ,
5078- HunyuanAttnProcessor2_0 ,
5079- FusedHunyuanAttnProcessor2_0 ,
5080- PAGHunyuanAttnProcessor2_0 ,
5081- PAGCFGHunyuanAttnProcessor2_0 ,
5082- LuminaAttnProcessor2_0 ,
5083- MochiAttnProcessor2_0 ,
5084- FusedAttnProcessor2_0 ,
5071+ CustomDiffusionAttnProcessor ,
50855072 CustomDiffusionXFormersAttnProcessor ,
50865073 CustomDiffusionAttnProcessor2_0 ,
5087- SlicedAttnProcessor ,
5088- SlicedAttnAddedKVProcessor ,
5089- IPAdapterAttnProcessor ,
5090- IPAdapterAttnProcessor2_0 ,
5091- IPAdapterXFormersAttnProcessor ,
5092- PAGIdentitySelfAttnProcessor2_0 ,
50935074 PAGCFGIdentitySelfAttnProcessor2_0 ,
5094- LoRAAttnProcessor ,
5095- LoRAAttnProcessor2_0 ,
5096- LoRAXFormersAttnProcessor ,
5097- LoRAAttnAddedKVProcessor ,
5075+ PAGIdentitySelfAttnProcessor2_0 ,
5076+ PAGCFGHunyuanAttnProcessor2_0 ,
5077+ PAGHunyuanAttnProcessor2_0 ,
50985078]
0 commit comments