Commit 5a311bb

[Gluon][Tutorial] Merge d64 and d128 attn kernels (#7226)
This PR deletes the d64 kernel and only uses the d128 kernel by enabling BLOCK_N=128 for d64 and increasing the number of load buffers, as well as reordering code in the softmax partition to improve performance. This increases peak performance for DHEAD=64 to about 770 TFLOPS (780 for non-causal).

However, this makes performance for smaller N_CTX (16k and below) worse than before because the scheduling granularity is worse. This should get fixed when the kernel is made persistent, and it means I don't have to optimize two separate kernels.

DHEAD=128 performance is unchanged.

```
Attention Z=4 H=32 D=64 causal=False:
     N_CTX  triton-fp16   cudnn-fp16
0   1024.0   381.658388   480.100511
1   2048.0   569.136004   804.057444
2   4096.0   663.284730   892.582773
3   8192.0   737.043055   939.454979
4  16384.0   764.685874   952.196135
5  32768.0   778.020934   931.158252
6  65536.0   785.854946   912.826103
Attention Z=4 H=32 D=64 causal=True:
     N_CTX  triton-fp16   cudnn-fp16
0   1024.0   167.438219   351.715759
1   2048.0   295.866693   619.581918
2   4096.0   452.670120   754.473029
3   8192.0   586.122336   833.803614
4  16384.0   659.646168   831.281960
5  32768.0   743.090000   910.116397
6  65536.0   771.990185   918.770994
Attention Z=4 H=32 D=128 causal=False:
     N_CTX  triton-fp16   cudnn-fp16
0   1024.0   629.445279   915.936310
1   2048.0   930.445757  1222.156019
2   4096.0   984.129646  1306.298210
3   8192.0  1107.853967  1382.667255
4  16384.0  1175.943239  1265.863845
5  32768.0  1187.172469  1248.898664
6  65536.0  1174.481087  1268.167581
Attention Z=4 H=32 D=128 causal=True:
     N_CTX  triton-fp16   cudnn-fp16
0   1024.0   265.095023   547.764710
1   2048.0   472.862887   852.318204
2   4096.0   690.310855  1085.992183
3   8192.0   882.526991  1247.025302
4  16384.0  1062.367442  1292.686167
5  32768.0  1131.631485  1196.174915
6  65536.0  1122.032238  1248.103686
```
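For readers who want to relate the TFLOPS columns to wall-clock time, here is a minimal sketch of the usual FLOP accounting for an attention forward pass. It assumes the benchmark follows the convention of the Triton fused-attention tutorial (two matmuls per head, halved for causal masking); the commit itself does not show the benchmark code, and the function names below are hypothetical helpers, not part of the kernel.

```python
def attention_fwd_flops(Z, H, N_CTX, D, causal):
    """FLOP count for one attention forward pass (assumed convention)."""
    # Q @ K^T and P @ V each cost 2 * N_CTX * N_CTX * D FLOPs
    # per (batch, head) pair.
    flops_per_matmul = 2.0 * Z * H * N_CTX * N_CTX * D
    total_flops = 2 * flops_per_matmul
    if causal:
        # Causal masking skips roughly half of the score matrix.
        total_flops *= 0.5
    return total_flops


def attention_fwd_tflops(Z, H, N_CTX, D, causal, ms):
    """Convert a measured time in milliseconds into TFLOPS."""
    return attention_fwd_flops(Z, H, N_CTX, D, causal) * 1e-12 / (ms * 1e-3)


def implied_ms(Z, H, N_CTX, D, causal, tflops):
    """Invert the formula: the time in milliseconds implied by a TFLOPS figure."""
    return attention_fwd_flops(Z, H, N_CTX, D, causal) / (tflops * 1e12) * 1e3


if __name__ == "__main__":
    # Time implied by the D=64, non-causal, N_CTX=65536 row of the table.
    print(implied_ms(4, 32, 65536, 64, False, 785.854946))
```

Under this accounting, a causal run is credited with half the FLOPs of a non-causal run at the same shape, so causal and non-causal TFLOPS at equal N_CTX are comparable as throughput but not as elapsed time.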
1 parent 0b9853e commit 5a311bb

1 file changed: +52 -231 lines changed