Commit a545ebf

fix some more filter issues, address feedback
Signed-off-by: Sudhakar Singh <[email protected]>
1 parent 93548fc commit a545ebf

File tree

4 files changed, +16 -15 lines


tests/pytorch/utils.py

Lines changed: 7 additions & 7 deletions
@@ -353,11 +353,11 @@ def test():
     backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
     if AttentionLogging._is_logging_setup is False:
         AttentionLogging.setup_logging()
-    with logging_context(highest_level=AttentionLogging._log_level):
-        for i in range(3):
-            os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
-            _attention_backends["backend_selection_requires_update"] = True
-            available_backends, flash_attention_backend, fused_attention_backend = test()
-            if fused_attention_backend == FusedAttnBackend[backends[i]]:
-                fused_attn_backends.append(fused_attention_backend)
+
+    for i in range(3):
+        os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
+        _attention_backends["backend_selection_requires_update"] = True
+        available_backends, flash_attention_backend, fused_attention_backend = test()
+        if fused_attention_backend == FusedAttnBackend[backends[i]]:
+            fused_attn_backends.append(fused_attention_backend)
     return available_backends, flash_attention_backend, fused_attn_backends
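
The only change here is dropping the `logging_context` wrapper; the probing loop itself is untouched. For reference, a minimal standalone sketch of the same pattern, with `run_selection` as a hypothetical stand-in for the module's `test()` closure:

import os

# Hypothetical stand-in for the test() closure in tests/pytorch/utils.py,
# which re-runs backend selection and reports what is available.
def run_selection():
    return os.environ.get("NVTE_FUSED_ATTN_BACKEND")

backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"}
probed = []
for i in range(3):
    # Request backend i; the real test also sets
    # _attention_backends["backend_selection_requires_update"] = True
    # so the cached selection is recomputed on the next call.
    os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i)
    probed.append((backends[i], run_selection()))
print(probed)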

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 4 additions & 2 deletions
@@ -347,6 +347,8 @@ def forward(
                 attention_mask=attention_mask,
                 window_size=window_size,
                 attention_type=self.attention_type,
+                bottom_right_alignment=(attn_mask_type not in ["causal", "padding_causal"]
+                if bottom_right_diagonal is None else bottom_right_diagonal)
             )
         )

@@ -450,8 +452,8 @@ def forward(
             actual_seqlens_q=actual_seqlens_q if "padding" in attn_mask_type else None,
             actual_seqlens_kv=actual_seqlens_kv if "padding" in attn_mask_type else None,
             alibi_slopes=alibi_slopes,
-            # (This should be replaced with `bottom_right_diagonal` which is passed from the arguments)
-            bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
+            bottom_right_alignment=(attn_mask_type not in ["causal", "padding_causal"]
+            if bottom_right_diagonal is None else bottom_right_diagonal)
         )
         matmul_result = torch.baddbmm(
             matmul_result,
transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py

Lines changed: 3 additions & 4 deletions
@@ -1280,7 +1280,6 @@ def forward(
             if self.layer_number == 1:
                 _alibi_cache["_alibi_slopes_require_update"] = True
                 _alibi_cache["_alibi_bias_require_update"] = True
-        bottom_right_alignment = (attn_mask_type not in ["causal", "padding_causal"],)
         if core_attention_bias_type == "alibi":
             assert (
                 core_attention_bias is None

@@ -1289,7 +1288,7 @@ def forward(
                 _alibi_cache["_num_heads"] != query_layer.shape[-2]
                 or _alibi_cache["_max_seqlen_q"] != max_seqlen_q
                 or _alibi_cache["_max_seqlen_kv"] != max_seqlen_kv
-                or _alibi_cache["_bottom_right_alignment"] != bottom_right_alignment
+                or _alibi_cache["_bottom_right_alignment"] != bottom_right_diagonal
                 or _alibi_cache["_alibi_slopes"] is None
             ):
                 _alibi_cache["_alibi_slopes_require_update"] = True

@@ -1471,7 +1470,7 @@ def forward(
             fu_core_attention_bias_type = core_attention_bias_type
             fu_core_attention_bias = core_attention_bias
             if core_attention_bias_type == "alibi" and (
-                alibi_slopes is not None or max_seqlen_q != max_seqlen_kv
+                alibi_slopes is not None
             ):
                 fu_core_attention_bias_type = "post_scale_bias"
                 _, fu_core_attention_bias = dpa_utils.get_alibi(

@@ -1481,7 +1480,7 @@ def forward(
                     max_seqlen_kv,
                     alibi_slopes=alibi_slopes,
                     bias_dtype=query_layer.dtype,
-                    bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"],
+                    bottom_right_alignment=bottom_right_diagonal,
                 )
             if checkpoint_core_attention:
                 return self._checkpointed_attention_forward(
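
The deleted assignment had a trailing comma inside the parentheses, so `bottom_right_alignment` was a one-element tuple rather than a bool; the cache check now compares the cached value against the `bottom_right_diagonal` argument directly. A minimal demonstration of that trailing-comma pitfall (standalone, not the module's code):

attn_mask_type = "padding"

buggy = (attn_mask_type not in ["causal", "padding_causal"],)  # 1-tuple: (True,)
fixed = attn_mask_type not in ["causal", "padding_causal"]     # bool: True

print(type(buggy).__name__, buggy)  # tuple (True,)
print(type(fixed).__name__, fixed)  # bool True
assert buggy != fixed  # a tuple never compares equal to a bool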

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 2 additions & 2 deletions
@@ -200,7 +200,7 @@ class AttentionParams:
         `causal_bottom_right`, `padding_causal_bottom_right`, `arbitrary`}
     window_size : Tuple[int, int], default = None
         Sliding window attention size.
-    bottom_right_diagonal: bool, default = `True`
+    bottom_right_diagonal: bool, default = `None`
         Whether to align sliding window and ALiBi diagonal to the bottom right corner
         of the softmax matrix.
     alibi_slopes_shape : Optional[Union[torch.Size, List]], default = None

@@ -962,7 +962,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
     if (
         use_fused_attention
         and core_attention_bias_type == "alibi"
-        and (alibi_slopes_shape is not None or max_seqlen_q != max_seqlen_kv)
+        and (alibi_slopes_shape is not None)
     ):
         fu_core_attention_bias_type = "post_scale_bias"
         fu_core_attention_bias_requires_grad = False
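
The documented default changes from `True` to `None`, i.e. "not specified by the caller"; downstream code then falls back to the mask-type-based default shown above. A simplified sketch of how that reads at the parameter level (a hypothetical, cut-down stand-in for `AttentionParams`, not the real dataclass):

from dataclasses import dataclass
from typing import Optional

@dataclass
class AttentionParamsSketch:
    attn_mask_type: str = "no_mask"
    # None means "unspecified": consumers derive the alignment from the
    # mask type instead of assuming bottom-right.
    bottom_right_diagonal: Optional[bool] = None

p = AttentionParamsSketch(attn_mask_type="causal")
effective = (p.attn_mask_type not in ["causal", "padding_causal"]
             if p.bottom_right_diagonal is None
             else p.bottom_right_diagonal)
print(effective)  # False: top-left alignment for plain causal when unspecified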
