
Commit 04fb006

fix
Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
1 parent 19a657b

File tree

  • nemo_automodel/components/models/llama_bidirectional

1 file changed: +4 additions, −0 deletions

nemo_automodel/components/models/llama_bidirectional/model.py

Lines changed: 4 additions & 0 deletions

@@ -106,6 +106,10 @@ def _update_causal_mask(
     ):
         if attention_mask is None:
             return None
+        if getattr(self.config, "_attn_implementation", None) == "flash_attention_2":
+            # Flash Attention handles padding from the raw 2D mask;
+            # bidirectional attention is ensured by is_causal=False on all layers.
+            return attention_mask
         dtype = input_tensor.dtype if input_tensor is not None else torch.float32
         return _prepare_4d_attention_mask(attention_mask, dtype)
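The fix short-circuits mask expansion when the model runs under Flash Attention 2, which consumes the raw 2D padding mask directly instead of the additive 4D mask that eager/SDPA attention paths expect. The following is a minimal, self-contained sketch of that branching logic; the `MaskDemo` class, its `Cfg` stand-in, and this simplified `_prepare_4d_attention_mask` are hypothetical illustrations, not the actual NeMo Automodel implementation.

```python
import torch

def _prepare_4d_attention_mask(mask_2d, dtype):
    # Expand a (batch, seq_len) padding mask into the additive 4D form
    # (batch, 1, seq_len, seq_len): 0.0 where attention is allowed,
    # a large negative value where it is masked out.
    batch, seq_len = mask_2d.shape
    expanded = mask_2d[:, None, None, :].expand(batch, 1, seq_len, seq_len).to(dtype)
    return (1.0 - expanded) * torch.finfo(dtype).min

class MaskDemo:
    """Hypothetical stand-in for the model; only the attributes the
    patched method reads are reproduced here."""
    def __init__(self, attn_implementation):
        self.config = type("Cfg", (), {"_attn_implementation": attn_implementation})()

    def _update_causal_mask(self, attention_mask, input_tensor=None):
        if attention_mask is None:
            return None
        if getattr(self.config, "_attn_implementation", None) == "flash_attention_2":
            # Flash Attention handles padding from the raw 2D mask;
            # bidirectional attention is ensured by is_causal=False on all layers.
            return attention_mask
        dtype = input_tensor.dtype if input_tensor is not None else torch.float32
        return _prepare_4d_attention_mask(attention_mask, dtype)

mask = torch.ones(2, 4)  # (batch, seq_len) padding mask, all tokens valid
fa2 = MaskDemo("flash_attention_2")._update_causal_mask(mask)
eager = MaskDemo("eager")._update_causal_mask(mask)
print(fa2.shape)    # 2D mask passed through unchanged: torch.Size([2, 4])
print(eager.shape)  # expanded to 4D: torch.Size([2, 1, 4, 4])
```

Returning the mask unexpanded also avoids materializing a `seq_len × seq_len` tensor per batch element, which Flash Attention would ignore anyway.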
