Skip to content

Commit 5efc32b

Browse files
Mogball and whitneywhtsang
authored and committed
[Tutorial] Fix attention tutorial and enable pytests for DHEAD=128 (#7037)
1 parent 1312e26 commit 5efc32b

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

python/tutorials/06-fused-attention.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ def supports_host_descriptor():
3636
return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
3737

3838

39+
def is_blackwell():
40+
return is_cuda() and torch.cuda.get_device_capability()[0] == 10
41+
42+
3943
@triton.jit
4044
def _attn_fwd_inner(acc, l_i, m_i, q, #
4145
desc_k, desc_v, #
@@ -115,7 +119,7 @@ def _host_descriptor_pre_hook(nargs):
115119
if "PYTEST_VERSION" in os.environ:
116120
# Use a single config in testing for reproducibility
117121
configs = [
118-
triton.Config(dict(BLOCK_M=64, BLOCK_N=64), num_stages=4, num_warps=4, pre_hook=_host_descriptor_pre_hook),
122+
triton.Config(dict(BLOCK_M=64, BLOCK_N=64), num_stages=2, num_warps=4, pre_hook=_host_descriptor_pre_hook),
119123
]
120124

121125

@@ -484,10 +488,10 @@ def forward(ctx, q, k, v, causal, sm_scale, warp_specialize=True):
484488
y_dim = q.shape[0] * q.shape[1] * q.shape[2]
485489

486490
dummy_block = [1, 1]
487-
desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM, 1], block_shape=dummy_block)
488-
desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM, 1], block_shape=dummy_block)
489-
desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM, 1], block_shape=dummy_block)
490-
desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM, 1], block_shape=dummy_block)
491+
desc_q = TensorDescriptor(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
492+
desc_v = TensorDescriptor(v, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
493+
desc_k = TensorDescriptor(k, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
494+
desc_o = TensorDescriptor(o, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=dummy_block)
491495
else:
492496
desc_q = q
493497
desc_v = v
@@ -510,7 +514,7 @@ def grid(META):
510514
q.shape[0], q.shape[1], #
511515
desc_q, desc_k, desc_v, desc_o, #
512516
N_CTX=q.shape[2], #
513-
HEAD_DIM=HEAD_DIM, #
517+
HEAD_DIM=HEAD_DIM_K, #
514518
FP8_OUTPUT=q.dtype == torch.float8_e5m2, #
515519
STAGE=stage, #
516520
warp_specialize=warp_specialize, #
@@ -568,17 +572,12 @@ def backward(ctx, do):
568572
attention = _attention.apply
569573

570574

571-
@pytest.mark.parametrize('Z, H, N_CTX, HEAD_DIM', [
572-
(1, 2, 1024, 64),
573-
(4, 48, 128, 64),
574-
(4, 48, 256, 64),
575-
(4, 48, 512, 64),
576-
(4, 48, 1024, 64),
577-
(4, 48, 2048, 64),
578-
(4, 48, 4096, 64),
579-
])
580-
@pytest.mark.parametrize("causal", [True])
581-
@pytest.mark.parametrize("warp_specialize", [False, True])
575+
@pytest.mark.parametrize("Z", [1, 4])
576+
@pytest.mark.parametrize("H", [2, 48])
577+
@pytest.mark.parametrize("N_CTX", [128, 1024, (2 if is_hip() else 4) * 1024])
578+
@pytest.mark.parametrize("HEAD_DIM", [64, 128])
579+
@pytest.mark.parametrize("causal", [True]) # FIXME: Non-causal tests do not pass at the moment.
580+
@pytest.mark.parametrize("warp_specialize", [False, True] if is_blackwell() else [False])
582581
def test_op(Z, H, N_CTX, HEAD_DIM, causal, warp_specialize, dtype=torch.float16):
583582
torch.manual_seed(20)
584583
q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device=DEVICE).normal_(mean=0.0, std=0.5).requires_grad_())

0 commit comments

Comments (0)