1 parent 19a657b · commit 04fb006
nemo_automodel/components/models/llama_bidirectional/model.py
@@ -106,6 +106,10 @@ def _update_causal_mask(
     ):
         if attention_mask is None:
            return None
+        if getattr(self.config, "_attn_implementation", None) == "flash_attention_2":
+            # Flash Attention handles padding from the raw 2D mask;
+            # bidirectional attention is ensured by is_causal=False on all layers.
+            return attention_mask
        dtype = input_tensor.dtype if input_tensor is not None else torch.float32
        return _prepare_4d_attention_mask(attention_mask, dtype)
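
For context, here is a minimal sketch of why the two paths differ. The non-flash path expands the 2D padding mask into the kind of 4D additive mask that _prepare_4d_attention_mask produces, while flash_attention_2 consumes the raw 2D mask directly, so the new branch returns it unchanged. The helper name expand_2d_to_4d_bidirectional below is hypothetical and only illustrates the expansion; it is not the library implementation.

import torch

def expand_2d_to_4d_bidirectional(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # attention_mask: [batch, seq_len] with 1 = real token, 0 = padding.
    # Returns an additive mask of shape [batch, 1, seq_len, seq_len]:
    # 0.0 where attention is allowed, dtype-min where the key position is padding.
    # No causal triangle is applied, so attention stays bidirectional.
    bsz, src_len = attention_mask.shape
    expanded = attention_mask[:, None, None, :].expand(bsz, 1, src_len, src_len).to(dtype)
    return (1.0 - expanded) * torch.finfo(dtype).min

mask = torch.tensor([[1, 1, 1, 0]])  # one padded position at the end
print(expand_2d_to_4d_bidirectional(mask, torch.float32).shape)  # torch.Size([1, 1, 4, 4])

With flash_attention_2 this expansion is unnecessary: the kernel handles padding from the [batch, seq_len] mask itself, and bidirectionality depends only on is_causal=False being set on every layer, which is exactly what the added early return relies on.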