Fix: audio encoder uses full attention instead of windowed on non-FA2 backends #103

Open
Dvad wants to merge 1 commit into QwenLM:main from Dvad:main

Conversation

Dvad commented Feb 28, 2026

Bug: Qwen3ASRAudioEncoder.forward() passes cu_seqlens to attention layers but never builds a 4D attention mask. Only Flash Attention 2 (CUDA) interprets cu_seqlens for windowed attention boundaries.

On SDPA/eager backends (MPS, CPU), cu_seqlens is ignored and the encoder performs full global self-attention over all tokens instead of the trained windowed attention pattern.

This causes significant quality degradation on non-CUDA hardware (~340 words transcribed vs ~555 expected on a 5-minute test clip).

Fix: call the existing _prepare_attention_mask() method (which already returns None for FA2) and pass the resulting block-diagonal mask to each encoder layer.

```diff
         cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32)
+        attention_mask = self._prepare_attention_mask(hidden_states, cu_seqlens)
         for encoder_layer in self.layers:
             layer_outputs = encoder_layer(
                 hidden_states,
                 cu_seqlens,
+                attention_mask=attention_mask,
             )
```
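For context, this is a minimal sketch of how such a block-diagonal additive mask can be built from cu_seqlens for SDPA/eager backends. It is not the repo's actual _prepare_attention_mask implementation; the function name, mask shape, and the convention that cu_seqlens holds cumulative chunk end offsets are assumptions for illustration:

```python
import torch

def block_diagonal_mask(seq_len: int, cu_seqlens: torch.Tensor,
                        dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Build an additive (1, 1, seq_len, seq_len) attention mask from
    cumulative chunk boundaries. Positions outside a token's own chunk
    get -inf, so SDPA/eager attention stays within chunk windows.
    FA2 consumes cu_seqlens directly and needs no mask (mask is None)."""
    mask = torch.full((1, 1, seq_len, seq_len), float("-inf"), dtype=dtype)
    start = 0
    for end in cu_seqlens.tolist():
        # Tokens in [start, end) may attend only to each other.
        mask[..., start:end, start:end] = 0.0
        start = end
    return mask

# Example: two chunks of lengths 3 and 2 over a 5-token sequence.
cu = torch.tensor([3, 5], dtype=torch.int32)
m = block_diagonal_mask(5, cu)
```

Without this mask on SDPA/eager, every token attends to all others, which is exactly the full-attention mismatch described above.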

Verified on: MPS (Apple Silicon), CPU.

FA2 path is unchanged (_prepare_attention_mask returns None for flash_attention_2).
