 from triton.experimental.gluon.language.nvidia.blackwell import (
     TensorMemoryLayout,
     allocate_tensor_memory,
+    get_tmem_32x32b_reg_layout,
     tensor_memory_descriptor,
     tma,
     mbarrier,
 # ===-----------------------------------------------------------------------===#
 
 
-@gl.constexpr_function
-def get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps):
-    assert len(shape) == 2, "expected a 2D tensor"
-    assert num_warps in [4, 8], "expected 4 or 8 warps"
-    M, N, _ = instr_shape
-
-    blocks_per_tile = [shape[0] // M, shape[1] // N]
-    num_blocks = blocks_per_tile[0] * blocks_per_tile[1]
-
-    num_warp_groups = num_warps // 4
-    if M == 64:
-        threads_per_warp = [16, 2]
-        if num_blocks == 1:
-            size_per_thread = [1, N // (num_warp_groups * 2)]
-            warps_per_cta = [4, num_warp_groups]
-        else:
-            size_per_thread = [1, N // 2]
-            warps_per_cta = [4 * min(blocks_per_tile[0], num_warp_groups)]
-            warps_per_cta.append(triton.cdiv(num_warp_groups, warps_per_cta[0] // 4))
-    else:
-        if shape[0] > 128:
-            size_per_thread = [1, N]
-            threads_per_warp = [32, 1]
-            warps_per_cta = [4 * num_warp_groups, 1]
-        else:
-            size_per_thread = [1, N // num_warp_groups]
-            threads_per_warp = [32, 1]
-            warps_per_cta = [4, num_warp_groups]
-    return gl.BlockedLayout(
-        size_per_thread=size_per_thread,
-        threads_per_warp=threads_per_warp,
-        warps_per_cta=warps_per_cta,
-        order=[0, 1],
-    )
-
-
 @gl.constexpr_function
 def get_mma_instr_shape(shape, element_ty):
     m = 128 if shape[0] >= 128 else 64
@@ -71,7 +36,7 @@ def get_mma_instr_shape(shape, element_ty):
 @gl.constexpr_function
 def get_mma_reg_layout(shape, num_warps, dtype=gl.float32):
     instr_shape = get_mma_instr_shape(shape, dtype)
-    return get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps)
+    return get_tmem_32x32b_reg_layout(*instr_shape[:2], shape, num_warps)
 
 
 # ===-----------------------------------------------------------------------===#
@@ -288,10 +253,12 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE
         self.o_tmem_layout = gl.constexpr(TensorMemoryLayout((o_instr_shape[0], o_instr_shape[1]), unpacked=True))
         self.p_tmem_layout = gl.constexpr(TensorMemoryLayout((qk_instr_shape[0], qk_instr_shape[1]), unpacked=False))
 
-        self.qk_layout = gl.constexpr(get_tmem_32x32b_reg_layout(qk_instr_shape, self.qk_shape, self.num_warps))
-        self.o_layout = gl.constexpr(get_tmem_32x32b_reg_layout(o_instr_shape, self.o_shape, self.num_warps))
+        self.qk_layout = gl.constexpr(
+            get_tmem_32x32b_reg_layout(qk_instr_shape[0], qk_instr_shape[1], self.qk_shape, self.num_warps))
+        self.o_layout = gl.constexpr(
+            get_tmem_32x32b_reg_layout(o_instr_shape[0], o_instr_shape[1], self.o_shape, self.num_warps))
         self.o_splitn_layout = gl.constexpr(
-            get_tmem_32x32b_reg_layout((o_instr_shape[0], o_instr_shape[1] // self.SPLIT_D_FACTOR, o_instr_shape[2]),
+            get_tmem_32x32b_reg_layout(o_instr_shape[0], o_instr_shape[1] // self.SPLIT_D_FACTOR,
                                        (self.o_shape[0], self.o_shape[1] // self.SPLIT_D_FACTOR), self.num_warps))
         self.alpha_2d_layout = gl.constexpr(gl.BlockedLayout([1, 1], [32, 1], [self.num_warps, 1], [0, 1]))
 
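
For reference, the helper deleted above is now imported from the blackwell namespace, and its signature takes the instruction M and N as two separate arguments instead of a single (M, N, K) instr_shape tuple. Below is a minimal sketch of the new calling convention, assuming the tutorial's get_mma_instr_shape helper shown in the diff; the tile shape, dtype, and warp count are illustrative values, not taken from this commit:

from triton.experimental.gluon import language as gl
from triton.experimental.gluon.language.nvidia.blackwell import get_tmem_32x32b_reg_layout

shape = (256, 64)  # illustrative 2D tile shape
num_warps = 4      # illustrative warp count
instr_shape = get_mma_instr_shape(shape, gl.float16)  # returns an (M, N, K) instruction shape

# The removed local helper took the full tuple:
#   get_tmem_32x32b_reg_layout(instr_shape, shape, num_warps)
# The exported helper takes M and N separately; the K element is dropped,
# which is why the callsites above unpack instr_shape[:2]:
layout = get_tmem_32x32b_reg_layout(instr_shape[0], instr_shape[1], shape, num_warps)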