 from triton.experimental.gluon.language.nvidia.blackwell import (
     TensorMemoryLayout,
     allocate_tensor_memory,
+    get_tmem_32x32b_reg_layout,
     tensor_memory_descriptor,
     tma,
     mbarrier,
 # ===-----------------------------------------------------------------------===#


-@gl.constexpr_function
-def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
-    assert len(shape) == 2, "expected a 2D tensor"
-    assert num_warps in [4, 8], "expected 4 or 8 warps"
-    M, N, _ = instr_shape
-
-    blocks_per_tile = [shape[0] // M, shape[1] // N]
-    num_blocks = blocks_per_tile[0] * blocks_per_tile[1]
-
-    num_warp_groups = num_warps // 4
-    if M == 64:
-        threads_per_warp = [16, 2]
-        if num_blocks == 1:
-            size_per_thread = [1, N // (num_warp_groups * 2)]
-            warps_per_cta = [4, num_warp_groups]
-        else:
-            size_per_thread = [1, N // 2]
-            warps_per_cta = [4 * min(blocks_per_tile[0], num_warp_groups)]
-            warps_per_cta.append(triton.cdiv(num_warp_groups, warps_per_cta[0] // 4))
-    else:
-        if shape[0] > 128:
-            size_per_thread = [1, N]
-            threads_per_warp = [32, 1]
-            warps_per_cta = [4 * num_warp_groups, 1]
-        else:
-            size_per_thread = [1, N // num_warp_groups]
-            threads_per_warp = [32, 1]
-            warps_per_cta = [4, num_warp_groups]
-    return gl.BlockedLayout(
-        size_per_thread=size_per_thread,
-        threads_per_warp=threads_per_warp,
-        warps_per_cta=warps_per_cta,
-        order=[0, 1],
-    )
-
-
 @gl.constexpr_function
 def get_mma_instr_shape(shape, element_ty):
     m = 128 if shape[0] >= 128 else 64
@@ -71,7 +36,7 @@ def get_mma_instr_shape(shape, element_ty):
 @gl.constexpr_function
 def get_mma_reg_layout(shape, num_warps, dtype=gl.float32):
     instr_shape = get_mma_instr_shape(shape, dtype)
-    return get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps)
+    return get_tmem_32x32b_reg_layout(*instr_shape[:2], shape, num_warps)


 # ===-----------------------------------------------------------------------===#
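# A minimal sketch of the updated calling convention: the library helper imported above
# is assumed, based on the new call sites in this diff, to take the MMA instruction M and N
# as separate leading arguments instead of the full (M, N, K) tuple. The (128, 64) tile
# shape and num_warps=4 below are illustrative values only.
instr_shape = get_mma_instr_shape((128, 64), gl.float32)   # -> (M, N, K)
layout = get_tmem_32x32b_reg_layout(instr_shape[0], instr_shape[1], (128, 64), 4)
# Equivalent to the unpacked form used in get_mma_reg_layout:
#     get_tmem_32x32b_reg_layout(*instr_shape[:2], (128, 64), 4)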
@@ -288,10 +253,12 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE
         self.o_tmem_layout = gl.constexpr(TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), unpacked=True))
         self.p_tmem_layout = gl.constexpr(TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), unpacked=False))

-        self.qk_layout = gl.constexpr(get_tmem_32x32b_reg_layout(qk_instr_shape, self.qk_shape, self.num_warps))
-        self.o_layout = gl.constexpr(get_tmem_32x32b_reg_layout(o_instr_shape, self.o_shape, self.num_warps))
+        self.qk_layout = gl.constexpr(
+            get_tmem_32x32b_reg_layout(qk_instr_shape[0], qk_instr_shape[1], self.qk_shape, self.num_warps))
+        self.o_layout = gl.constexpr(
+            get_tmem_32x32b_reg_layout(o_instr_shape[0], o_instr_shape[1], self.o_shape, self.num_warps))
         self.o_splitn_layout = gl.constexpr(
-            get_tmem_32x32b_reg_layout((o_instr_shape[0], o_instr_shape[1] // self.SPLIT_D_FACTOR, o_instr_shape[2]),
+            get_tmem_32x32b_reg_layout(o_instr_shape[0], o_instr_shape[1] // self.SPLIT_D_FACTOR,
                                        (self.o_shape[0], self.o_shape[1] // self.SPLIT_D_FACTOR), self.num_warps))
         self.alpha_2d_layout = gl.constexpr(gl.BlockedLayout([1, 1], [32, 1], [self.num_warps, 1], [0, 1]))

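# A minimal sketch of how the config derives its register layouts under the new
# signature, using illustrative values (BLOCK_M=128, HEAD_DIM=64, SPLIT_D_FACTOR=2,
# num_warps=8); the real values come from the tutorial's kernel configuration.
o_shape = (128, 64)
SPLIT_D_FACTOR = 2
o_instr_shape = get_mma_instr_shape(o_shape, gl.float32)   # (M, N, K)
o_layout = get_tmem_32x32b_reg_layout(o_instr_shape[0], o_instr_shape[1], o_shape, 8)
# The split-N variant halves both the instruction N and the tile N before requesting
# a layout, mirroring o_splitn_layout above.
o_splitn_layout = get_tmem_32x32b_reg_layout(
    o_instr_shape[0], o_instr_shape[1] // SPLIT_D_FACTOR,
    (o_shape[0], o_shape[1] // SPLIT_D_FACTOR), 8)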