
Commit 1793a04

[Gluon] Move get_tmem_32x32b_reg_layout into blackwell API (#7535)
1 parent 896cbdb commit 1793a04

2 files changed: +52 -41 lines

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 45 additions & 1 deletion
@@ -2,8 +2,10 @@
 from typing import Optional, Tuple, List, TYPE_CHECKING
 
 from dataclasses import dataclass
+import triton
 from triton.experimental.gluon.language import _core as ttgl
-from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr
+from triton.experimental.gluon.language._core import builtin, base_type, base_value, _unwrap_if_constexpr, constexpr_function
+from triton.experimental.gluon.language._layouts import BlockedLayout, _get_shape_per_cta
 from triton.experimental.gluon.language._semantic import _check
 
 from . import tma
@@ -19,6 +21,7 @@
     "allocate_tensor_memory",
     "async_copy",
     "fence_async_shared",
+    "get_tmem_32x32b_reg_layout",
     "mbarrier",
     "tensor_memory_descriptor",
     "TensorMemoryLayout",
@@ -59,6 +62,47 @@ def mangle(self) -> str:
         return f"TL{block_str}{unpacked_str}{cta_split_str}TL"
 
 
+@constexpr_function
+def get_tmem_32x32b_reg_layout(M, N, shape, num_warps, ctas_per_cga=None, cta_split_num=None, cta_order=None):
+    """Returns a BlockedLayout compatible with load/store on tensor memory with the 32x32b instruction variant.
+    """
+    assert len(shape) == 2, "expected a 2D tensor"
+    assert num_warps in [4, 8], "expected 4 or 8 warps"
+
+    shape_per_cta = _get_shape_per_cta(shape, cta_split_num)
+    blocks_per_tile = [shape_per_cta[0] // M, shape_per_cta[1] // N]
+    num_blocks = blocks_per_tile[0] * blocks_per_tile[1]
+
+    num_warp_groups = num_warps // 4
+    if M == 64:
+        threads_per_warp = [16, 2]
+        if num_blocks == 1:
+            size_per_thread = [1, N // (num_warp_groups * 2)]
+            warps_per_cta = [4, num_warp_groups]
+        else:
+            size_per_thread = [1, N // 2]
+            warps_per_cta = [4 * min(blocks_per_tile[0], num_warp_groups)]
+            warps_per_cta.append(triton.cdiv(num_warp_groups, warps_per_cta[0] // 4))
+    else:
+        if shape[0] > 128:
+            size_per_thread = [1, N]
+            threads_per_warp = [32, 1]
+            warps_per_cta = [4 * num_warp_groups, 1]
+        else:
+            size_per_thread = [1, N // num_warp_groups]
+            threads_per_warp = [32, 1]
+            warps_per_cta = [4, num_warp_groups]
+    return BlockedLayout(
+        size_per_thread=size_per_thread,
+        threads_per_warp=threads_per_warp,
+        warps_per_cta=warps_per_cta,
+        order=[0, 1],
+        ctas_per_cga=ctas_per_cga,
+        cta_split_num=cta_split_num,
+        cta_order=cta_order,
+    )
+
+
 class tensor_memory_descriptor_type(base_type):
 
     def __init__(self, element_ty, shape, layout, alloc_shape):
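
A rough usage sketch of the relocated helper, called from host code the same way the tutorial below does; the 128x128 tile shape and num_warps=4 are illustrative values, not taken from this commit:

    from triton.experimental.gluon.language.nvidia.blackwell import get_tmem_32x32b_reg_layout

    # Register layout for loading/storing a 128x128 tensor-memory tile with a
    # single warp group (4 warps); the CTA-related arguments keep their None defaults.
    layout = get_tmem_32x32b_reg_layout(128, 128, (128, 128), num_warps=4)
    # Since M != 64 and shape[0] <= 128, the helper returns
    # BlockedLayout(size_per_thread=[1, 128], threads_per_warp=[32, 1],
    #               warps_per_cta=[4, 1], order=[0, 1]).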

python/tutorials/gluon/01-attention-forward.py

Lines changed: 7 additions & 40 deletions
@@ -12,6 +12,7 @@
 from triton.experimental.gluon.language.nvidia.blackwell import (
     TensorMemoryLayout,
     allocate_tensor_memory,
+    get_tmem_32x32b_reg_layout,
     tensor_memory_descriptor,
     tma,
     mbarrier,
@@ -24,42 +25,6 @@
 # ===-----------------------------------------------------------------------===#
 
 
-@gl.constexpr_function
-def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
-    assert len(shape) == 2, "expected a 2D tensor"
-    assert num_warps in [4, 8], "expected 4 or 8 warps"
-    M, N, _ = instr_shape
-
-    blocks_per_tile = [shape[0] // M, shape[1] // N]
-    num_blocks = blocks_per_tile[0] * blocks_per_tile[1]
-
-    num_warp_groups = num_warps // 4
-    if M == 64:
-        threads_per_warp = [16, 2]
-        if num_blocks == 1:
-            size_per_thread = [1, N // (num_warp_groups * 2)]
-            warps_per_cta = [4, num_warp_groups]
-        else:
-            size_per_thread = [1, N // 2]
-            warps_per_cta = [4 * min(blocks_per_tile[0], num_warp_groups)]
-            warps_per_cta.append(triton.cdiv(num_warp_groups, warps_per_cta[0] // 4))
-    else:
-        if shape[0] > 128:
-            size_per_thread = [1, N]
-            threads_per_warp = [32, 1]
-            warps_per_cta = [4 * num_warp_groups, 1]
-        else:
-            size_per_thread = [1, N // num_warp_groups]
-            threads_per_warp = [32, 1]
-            warps_per_cta = [4, num_warp_groups]
-    return gl.BlockedLayout(
-        size_per_thread=size_per_thread,
-        threads_per_warp=threads_per_warp,
-        warps_per_cta=warps_per_cta,
-        order=[0, 1],
-    )
-
-
 @gl.constexpr_function
 def get_mma_instr_shape(shape, element_ty):
     m = 128 if shape[0] >= 128 else 64
@@ -71,7 +36,7 @@ def get_mma_instr_shape(shape, element_ty):
 @gl.constexpr_function
 def get_mma_reg_layout(shape, num_warps, dtype=gl.float32):
     instr_shape = get_mma_instr_shape(shape, dtype)
-    return get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps)
+    return get_tmem_32x32b_reg_layout(*instr_shape[:2], shape, num_warps)
 
 
 # ===-----------------------------------------------------------------------===#
@@ -288,10 +253,12 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE
         self.o_tmem_layout = gl.constexpr(TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), unpacked=True))
         self.p_tmem_layout = gl.constexpr(TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), unpacked=False))
 
-        self.qk_layout = gl.constexpr(get_tmem_32x32b_reg_layout(qk_instr_shape, self.qk_shape, self.num_warps))
-        self.o_layout = gl.constexpr(get_tmem_32x32b_reg_layout(o_instr_shape, self.o_shape, self.num_warps))
+        self.qk_layout = gl.constexpr(
+            get_tmem_32x32b_reg_layout(qk_instr_shape[0], qk_instr_shape[1], self.qk_shape, self.num_warps))
+        self.o_layout = gl.constexpr(
+            get_tmem_32x32b_reg_layout(o_instr_shape[0], o_instr_shape[1], self.o_shape, self.num_warps))
         self.o_splitn_layout = gl.constexpr(
-            get_tmem_32x32b_reg_layout((o_instr_shape[0], o_instr_shape[1] // self.SPLIT_D_FACTOR, o_instr_shape[2]),
+            get_tmem_32x32b_reg_layout(o_instr_shape[0], o_instr_shape[1] // self.SPLIT_D_FACTOR,
                                        (self.o_shape[0], self.o_shape[1] // self.SPLIT_D_FACTOR), self.num_warps))
         self.alpha_2d_layout = gl.constexpr(gl.BlockedLayout([1, 1], [32, 1], [self.num_warps, 1], [0, 1]))
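
A minimal before/after sketch of the call-site change in this last hunk; qk_instr_shape and qk_shape are the tutorial's names, and the concrete values below are assumptions for illustration only:

    from triton.experimental.gluon.language.nvidia.blackwell import get_tmem_32x32b_reg_layout

    qk_instr_shape = (128, 128, 16)   # (M, N, K), e.g. as produced by get_mma_instr_shape
    qk_shape = (128, 128)             # full QK tile shape

    # Before: the tutorial-local helper took the whole instruction-shape tuple.
    #   get_tmem_32x32b_reg_layout(qk_instr_shape, qk_shape, num_warps)
    # After: the blackwell API takes M and N explicitly, so callers unpack the tuple.
    layout = get_tmem_32x32b_reg_layout(qk_instr_shape[0], qk_instr_shape[1], qk_shape, num_warps=4)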
