Commit 9fcb4b9

[Gluon][Tutorial] Tweak GROUP_SIZE_N (#7448)
fp16+d128 performance was falling off at 32K+ N_CTX due to pid scheduling issues. Tweak `GROUP_SIZE_N` to get better perf at longer contexts.

```
Attention Z=4 H=32 D=128 causal=False:
      N_CTX  triton-fp16
0    4096.0  1278.753121
1    8192.0  1270.658425
2   16384.0  1263.709358
3   32768.0  1241.498713
4   65536.0  1216.785238

Attention Z=4 H=32 D=128 causal=True:
      N_CTX  triton-fp16
0    4096.0  1008.655006
1    8192.0  1143.577101
2   16384.0  1136.524452
3   32768.0  1166.343494
4   65536.0  1068.798635
```

This doesn't seem to significantly affect fp8 or d64 perf.
1 parent: 973461c

File tree

1 file changed (+1, -1 lines)

python/tutorials/gluon/01-attention-forward.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -894,7 +894,7 @@ def attention_forward(q, k, v, causal, sm_scale):
     BLOCK_M = 256
     BLOCK_N = 128
     SPLIT_M = BLOCK_M // 2
-    GROUP_SIZE_N = 8
+    GROUP_SIZE_N = 4 if causal else 1
     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
 
     desc_q = make_tensor_desc(q, shape=[y_dim, HEAD_DIM_K], strides=[HEAD_DIM_K, 1], block_shape=[SPLIT_M, HEAD_DIM_K])
```
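For context on why `GROUP_SIZE_N` matters: grouped pid scheduling reorders the flat program id so that nearby programs land on nearby tiles, which typically helps cache reuse and load balance across SMs. The snippet below is a minimal, hypothetical Python sketch of one such grouped pid-to-tile mapping; it is not the tutorial's actual scheduling code, and the names `grouped_tile`, `num_tiles_m`, `num_tiles_n`, and `group_size_n` are illustrative only.

```python
# Hypothetical sketch (not the tutorial's kernel code): map a flat program id
# to a (tile_m, tile_n) pair, visiting tiles in groups of `group_size_n`
# columns so that consecutive pids stay within a narrow band of columns.
def grouped_tile(pid, num_tiles_m, num_tiles_n, group_size_n):
    tiles_per_group = group_size_n * num_tiles_m
    group_id = pid // tiles_per_group
    first_n = group_id * group_size_n
    # The last group can be narrower if num_tiles_n isn't a multiple of group_size_n.
    width = min(num_tiles_n - first_n, group_size_n)
    tile_n = first_n + (pid % tiles_per_group) % width
    tile_m = (pid % tiles_per_group) // width
    return tile_m, tile_n


# group_size_n=2: pids 0..7 sweep columns {0, 1} across all four rows before
# the next group starts. group_size_n=1 reduces to a plain column-by-column
# walk, which is the value this commit picks for the non-causal case.
for pid in range(8):
    print(pid, grouped_tile(pid, num_tiles_m=4, num_tiles_n=4, group_size_n=2))
```

In this sketch, a smaller group (down to `group_size_n = 1`) makes the walk more sequential; the commit message suggests that a smaller `GROUP_SIZE_N` scheduled better at very long contexts for fp16 with head dim 128.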
