
Commit ade3d49

[Gluon][Tutorial] Persistent attention (#7298)
Rewrite the attention kernel to be persistent. This gives better performance at low context lengths. However, fp16 at large context has regressed somewhat due to a ptxas instruction-scheduling issue in the softmax partition. fp8 is ~100 TFLOPS faster when the kernel name has "cutlass" in it.

```
Attention Z=4 H=32 D=64 causal=False:
     N_CTX  triton-fp16  triton-fp8
0   1024.0   359.574448  370.119987
1   2048.0   612.103928  641.204555
2   4096.0   653.868402  682.337948
3   8192.0   692.102228  721.555690
4  16384.0   696.972041  726.190035
5  32768.0   698.723685  727.983456
6  65536.0   699.865817  728.558321

Attention Z=4 H=32 D=64 causal=True:
     N_CTX  triton-fp16  triton-fp8
0   1024.0   181.879039  177.982453
1   2048.0   441.315463  454.310072
2   4096.0   532.170527  539.995252
3   8192.0   633.620646  638.544937
4  16384.0   667.687180  670.681255
5  32768.0   684.276329  688.571907
6  65536.0   692.953202  694.648353

Attention Z=4 H=32 D=128 causal=False:
     N_CTX  triton-fp16   triton-fp8
0   1024.0   718.580015   709.863720
1   2048.0  1133.490258  1222.548477
2   4096.0  1247.605551  1369.800195
3   8192.0  1243.482713  1406.799697
4  16384.0  1125.744367  1514.857403
5  32768.0  1124.116305  1521.267973
6  65536.0  1064.588719  1518.738037

Attention Z=4 H=32 D=128 causal=True:
     N_CTX  triton-fp16   triton-fp8
0   1024.0   355.642522   351.161232
1   2048.0   846.404095   854.547917
2   4096.0  1013.840017  1021.676435
3   8192.0  1176.258395  1152.844234
4  16384.0  1190.290681  1325.786204
5  32768.0  1063.658200  1394.413325
6  65536.0   970.531569  1413.282610
```
1 parent 915f62d commit ade3d49
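For context, "persistent" here means the kernel is launched with roughly one program per SM and each program loops over many tiles, instead of launching one program per tile. Below is a minimal sketch of that pattern in plain Triton, not the Gluon attention kernel from this commit; all names are illustrative.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def persistent_add_kernel(x_ptr, y_ptr, out_ptr, n_elements,
                          NUM_SMS: tl.constexpr, BLOCK: tl.constexpr):
    start_pid = tl.program_id(0)
    num_tiles = tl.cdiv(n_elements, BLOCK)
    # Persistent loop: each program strides over the tile space until every
    # tile has been processed, rather than owning exactly one tile.
    for tile in range(start_pid, num_tiles, NUM_SMS):
        offs = tile * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n_elements
        x = tl.load(x_ptr + offs, mask=mask)
        y = tl.load(y_ptr + offs, mask=mask)
        tl.store(out_ptr + offs, x + y, mask=mask)


def persistent_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    BLOCK = 1024
    # Launch only as many programs as the GPU has SMs; the loop in the kernel
    # covers the remaining tiles.
    num_sms = torch.cuda.get_device_properties(x.device).multi_processor_count
    grid = (min(num_sms, triton.cdiv(n, BLOCK)),)
    persistent_add_kernel[grid](x, y, out, n, NUM_SMS=num_sms, BLOCK=BLOCK)
    return out
```

Keeping programs resident amortizes launch and prologue costs across tiles, which is presumably where the low-context gains quoted above come from.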

5 files changed: +466, -512 lines

lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp

Lines changed: 2 additions & 2 deletions
@@ -128,8 +128,8 @@ SmallVector<int64_t> getTMABlockShape(ArrayRef<int64_t> shapePerCTA,
   if (swizzleBytes != 0) {
     auto contigDimSize = (8 * swizzleBytes) / elementBitWidth;
     if (blockShape[contigDim] < contigDimSize) {
-      llvm::reportFatalUsageError("Block shape is too small for the swizzle "
-                                  "byte size in NVMMA Shared Layout.");
+      llvm::report_fatal_error("Block shape is too small for the swizzle byte "
+                               "size in NVMMA Shared Layout.");
     }
     blockShape[contigDim] = contigDimSize;
   }

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 9 additions & 3 deletions
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from typing import List, Optional
-from triton.language.core import _unwrap_if_constexpr, _unwrap_shape
+from triton.language.core import _unwrap_if_constexpr, _unwrap_shape, constexpr_type
 
 __all__ = [
     "BlockedLayout",
@@ -25,7 +25,10 @@ class DistributedLayout:
     """
     Base class for distributed memory layouts in Gluon IR.
     """
-    pass
+
+    @property
+    def type(self):
+        return constexpr_type(self)
 
 
 @dataclass(frozen=True)
@@ -213,7 +216,10 @@ class SharedLayout:
     """
     Base class for shared memory layouts in Gluon IR.
     """
-    pass
+
+    @property
+    def type(self):
+        return constexpr_type(self)
 
 
 @dataclass(frozen=True)
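
The new `type` property lets a layout instance describe itself as a compile-time constant. A minimal usage sketch follows; the public import path and the `BlockedLayout` constructor fields are assumptions based on the Gluon layouts module, not taken from this diff.

```python
# Sketch only: import path and constructor arguments are assumptions.
from triton.experimental.gluon import language as ttgl

layout = ttgl.BlockedLayout(
    size_per_thread=[1, 8],
    threads_per_warp=[4, 8],
    warps_per_cta=[4, 1],
    order=[1, 0],
)

# With this change, `layout.type` evaluates to `constexpr_type(layout)`, so the
# frontend can treat the layout object itself as a constexpr value, e.g. when it
# is passed into or captured by a Gluon kernel.
print(layout.type)
```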

python/triton/experimental/gluon/language/_standard.py

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 from . import _core as ttgl
 
 _IMPORT_FROM_TRITON = [
+    "cdiv",
     "sum",
     "max",
     "min",

python/triton/experimental/gluon/language/nvidia/hopper/tma.py

Lines changed: 3 additions & 2 deletions
@@ -1,6 +1,7 @@
 from __future__ import annotations
 from typing import List, Tuple, TYPE_CHECKING
 from dataclasses import dataclass
+from triton.language.core import base_type, base_value
 import triton.experimental.gluon.language._core as ttgl
 from triton.experimental.gluon.language._layouts import NVMMASharedLayout
 from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
@@ -12,7 +13,7 @@
 
 
 @dataclass(eq=True)
-class tensor_descriptor_type:
+class tensor_descriptor_type(base_type):
     block_type: ttgl.block_type
     shape_type: ttgl.tuple_type
     strides_type: ttgl.tuple_type
@@ -44,7 +45,7 @@ def mangle(self) -> str:
         return f"TD{self.block_type.mangle}_{self.layout.mangle()}TD"
 
 
-class tensor_descriptor:
+class tensor_descriptor(base_value):
 
     def __init__(self, handle, shape: List[ttgl.tensor], strides: List[ttgl.tensor], block_type: ttgl.block_type,
                  layout: NVMMASharedLayout):
