Commit d978e80

Authored by Xiaowei Ren
Fix attention mask type for Flash Attention + CP + THD (NVIDIA#1354)
* always have padding mask type for both flash and fused attentions

  Signed-off-by: Xiaowei Ren <[email protected]>

* remove a redundant assert

  Signed-off-by: Xiaowei Ren <[email protected]>

---------

Signed-off-by: Xiaowei Ren <[email protected]>
1 parent 8c00424 commit d978e80

File tree

2 files changed, +4 -12 lines


tests/pytorch/fused_attn/run_fused_attn_with_cp.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def run_dpa_with_cp(
         "causal",
         "no_mask",
     ], f"{config.attn_mask_type} is an unsupported attention mask type!"
-    if kernel_backend == "FusedAttention" and qkv_format == "thd":
+    if qkv_format == "thd":
         if "causal" in config.attn_mask_type:
             config.attn_mask_type = "padding_causal"
         else:
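
With this change the test promotes the mask type for THD inputs regardless of backend, rather than only when the kernel backend is FusedAttention. A minimal sketch of the promotion rule follows; to_thd_mask_type is a hypothetical helper name and is not part of the test itself:

def to_thd_mask_type(attn_mask_type: str) -> str:
    # For THD (packed, variable-sequence-length) inputs under context parallelism,
    # both FlashAttention and FusedAttention need an explicit padding mask, so a
    # causal mask is promoted to "padding_causal" and anything else to "padding".
    if "causal" in attn_mask_type:
        return "padding_causal"
    return "padding"

# The two mask types accepted by the test above map as follows.
assert to_thd_mask_type("causal") == "padding_causal"
assert to_thd_mask_type("no_mask") == "padding"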

transformer_engine/pytorch/attention.py

Lines changed: 3 additions & 11 deletions
@@ -4309,14 +4309,6 @@ def attn_forward_func_with_cp(
     assert (
         qkv_format != "sbhd" or use_fused_attention
     ), "FlashAttention does not support sbhd format!"
-    assert (
-        qkv_format != "thd"
-        or not use_fused_attention
-        or attn_mask_type in ["padding", "padding_causal"]
-    ), (
-        f"Context parallelism is not supported for {attn_mask_type} mask type and "
-        f"{qkv_format} format with {'FusedAttention' if use_fused_attention else 'FlashAttention'}!"
-    )
     assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
         """Attention bias is only supported with FusedAttention and "causal" """
         """or "no_mask" mask types!"""
@@ -7878,6 +7870,9 @@ def forward(
         ), f"Values have head_dim = {value_layer.shape[-1]}, "
         "but expected head_dim = {self.hidden_size_per_attention_head_v}!"

+        if qkv_format is None:
+            qkv_format = self.qkv_format
+
         if attn_mask_type is None:
             attn_mask_type = self.attn_mask_type
         else:
@@ -7904,9 +7899,6 @@ def forward(
             graph_safe_rng_available()
         ), "Upgrade PyTorch version to get RNG manipulation support for cuda graph capture."

-        if qkv_format is None:
-            qkv_format = self.qkv_format
-
         if inference_params is not None:
             assert self.layer_number is not None, "Layer number must be set!"
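
Taken together, these hunks drop the assert that rejected padding mask types for FlashAttention with context parallelism and THD inputs, and resolve qkv_format from the module default before the mask-type handling in forward. A condensed sketch of the validation that remains around the removed assert is shown below; check_cp_inputs is a hypothetical wrapper, and the real attn_forward_func_with_cp takes more arguments:

def check_cp_inputs(qkv_format, use_fused_attention, attn_mask_type, attn_bias):
    # FlashAttention cannot consume sbhd-layout tensors under context parallelism.
    assert (
        qkv_format != "sbhd" or use_fused_attention
    ), "FlashAttention does not support sbhd format!"
    # The THD/padding-mask assert that used to sit here is gone, so "padding" and
    # "padding_causal" with THD are now accepted for both backends.
    assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), (
        'Attention bias is only supported with FusedAttention and "causal" '
        'or "no_mask" mask types!'
    )

# Example: FlashAttention + THD + padding_causal no longer trips an assertion.
check_cp_inputs("thd", False, "padding_causal", None)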

0 commit comments
