@@ -14,31 +14,49 @@ class ExptData:
14
14
15
15
16
16
@triton.jit
def _matmul_metadata_memset(Hist, n_expts_tot, MDTokStarts, MDTileStarts, MDTileInfo, BLOCK: tl.constexpr,
                            TILE_DIM: tl.constexpr, extra_block: tl.constexpr):
    """Initialize the start-offset tables and the tile-info buffer.

    Program 0 walks ``Hist`` in chunks of ``BLOCK`` entries and writes two
    exclusive prefix sums: one over the histogram values themselves
    (``MDTokStarts``) and one over ``cdiv(value, TILE_DIM)``
    (``MDTileStarts``). Every other program ``pid >= 1`` fills the
    ``BLOCK``-wide slice of ``MDTileInfo`` starting at ``(pid - 1) * BLOCK``
    with ``0xffffffff``.

    All stores are unmasked and ``BLOCK`` elements wide, so each output
    buffer must have room for whole ``BLOCK``-sized chunks.

    Args:
        Hist: pointer to the per-expert histogram; ``n_expts_tot`` entries
            are consumed.
        n_expts_tot: number of histogram entries to scan.
        MDTokStarts: output pointer for the exclusive prefix sum of ``Hist``.
        MDTileStarts: output pointer for the exclusive prefix sum of
            ``cdiv(Hist, TILE_DIM)``.
        MDTileInfo: output pointer filled with ``0xffffffff`` by programs
            ``pid >= 1``.
        BLOCK: number of elements handled per chunk / per program.
        TILE_DIM: divisor used to turn a histogram count into a tile count.
        extra_block: when true, histogram loads are unmasked and one
            additional ``BLOCK``-wide chunk holding the grand totals is
            written after the scan. Per the inline comment below, this is
            meant for the case where ``n_expts_tot`` is an exact multiple
            of ``BLOCK``.
    """
    pid = tl.program_id(0)

    # Computed by every program but only dereferenced when pid >= 1; for
    # pid == 0 the offset is negative and the pointer is never used.
    TileInfoOut = MDTileInfo + (pid - 1) * BLOCK + tl.arange(0, BLOCK)

    # if pid == 0 - initialize cumsums
    if pid == 0:
        # Running totals carried from one chunk to the next. They are kept
        # as BLOCK-wide vectors so they can be added to a chunk's cumsum
        # and, in the extra_block case, stored directly.
        x_tok = tl.zeros([BLOCK], dtype=MDTokStarts.dtype.element_ty)
        x_tile = tl.zeros([BLOCK], dtype=MDTileStarts.dtype.element_ty)

        Tok_ptrs = MDTokStarts + tl.arange(0, BLOCK)
        Tile_ptrs = MDTileStarts + tl.arange(0, BLOCK)

        for i in range(0, n_expts_tot, BLOCK):
            offs_n = tl.arange(0, BLOCK) + i
            if extra_block:
                # we need an extra block at the end just to contain the final
                # sum; this only happens if our total number of experts is an
                # exact multiple of BLOCK, obviating the need for any masking
                hist_tok = tl.load(Hist + offs_n)
            else:
                # Lanes past the end read 0, so they contribute nothing to
                # the sums and their exclusive prefix equals the grand total.
                mask = offs_n < n_expts_tot
                hist_tok = tl.load(Hist + offs_n, mask=mask, other=0)
            hist_tile = tl.cdiv(hist_tok, TILE_DIM)
            tok_starts = tl.cumsum(hist_tok, 0) + x_tok
            x_tok += tl.sum(hist_tok, 0).to(MDTokStarts.dtype.element_ty)
            tile_starts = tl.cumsum(hist_tile, 0) + x_tile
            x_tile += tl.sum(hist_tile, 0).to(MDTileStarts.dtype.element_ty)

            # Inclusive cumsum minus the element itself == exclusive cumsum.
            tl.store(Tok_ptrs, tok_starts - hist_tok)
            tl.store(Tile_ptrs, tile_starts - hist_tile)

            Tok_ptrs += BLOCK
            Tile_ptrs += BLOCK

        if extra_block:
            # Trailing chunk: every lane holds the grand total.
            tl.store(Tok_ptrs, x_tok)
            tl.store(Tile_ptrs, x_tile)

    else:

        tl.store(TileInfoOut, 0xffffffff)
42
60
43
61
44
62
@triton .jit
@@ -60,7 +78,7 @@ def _matmul_metadata_compute(Hist, MDTileStarts, MDTileInfo, BLOCK: tl.constexpr
60
78
def compute_metadata (routing_data , n_rows , block_m ):
61
79
if routing_data .expt_hist is None :
62
80
return ExptData (None , None , None , None , None )
63
- MEMSET_BLOCK = 512
81
+ MEMSET_BLOCK = 128
64
82
HIST2_BLOCK_M = 512
65
83
device = routing_data .expt_hist .device
66
84
n_expts_tot = routing_data .n_expts_tot
@@ -69,21 +87,29 @@ def compute_metadata(routing_data, n_rows, block_m):
69
87
grid_m = n_rows
70
88
else :
71
89
grid_m = n_expts_tot - 1 - ((n_expts_tot - n_rows - 1 ) // block_m )
72
- metadata_size = 3 * n_expts_tot + 2 + grid_m
90
+
91
+ n_expts_pad = cdiv (n_expts_tot , MEMSET_BLOCK ) * MEMSET_BLOCK
92
+ pad2 = cdiv (n_expts_tot + 1 , MEMSET_BLOCK ) * MEMSET_BLOCK
93
+ extra_block = (n_expts_pad != pad2 )
94
+ pids = cdiv (grid_m , MEMSET_BLOCK ) + 1
95
+
96
+ metadata_size = n_expts_pad + 2 * pad2 + MEMSET_BLOCK * (pids - 1 )
97
+
73
98
metadata = torch .empty (metadata_size , dtype = torch .int32 , device = device )
74
- md_hist = metadata [:n_expts_tot ]
75
- md_offs = metadata [n_expts_tot :n_expts_tot * 2 + 1 ]
76
- md_offs_sum = metadata [3 * n_expts_tot + 2 - 1 ]
77
- md_tile_starts = metadata [n_expts_tot * 2 + 1 :n_expts_tot * 3 + 2 ]
78
- md_tile_infos = metadata [n_expts_tot * 3 + 2 :]
79
- _matmul_metadata_memset [(cdiv (metadata_size , MEMSET_BLOCK ), )](
80
- routing_data .expt_hist , n_expts_tot , md_hist , md_offs , md_tile_starts , md_tile_infos , md_tile_infos .shape [0 ],
99
+
100
+ md_hist = routing_data .expt_hist [:n_expts_tot ]
101
+ md_offs = metadata [:n_expts_tot + 1 ]
102
+ md_tile_starts = metadata [pad2 :][:n_expts_tot + 1 ]
103
+ md_offs_sum = md_tile_starts [- 1 ]
104
+ md_tile_infos = metadata [2 * pad2 :][:grid_m ]
105
+ _matmul_metadata_memset [(pids , )](
106
+ routing_data .expt_hist , n_expts_tot , md_offs , md_tile_starts , md_tile_infos ,
81
107
BLOCK = MEMSET_BLOCK , # optimization parameters
82
108
TILE_DIM = block_m , # constants
83
- )
109
+ extra_block = extra_block , num_warps = 1 )
84
110
_matmul_metadata_compute [(n_expts_tot , )](
85
111
routing_data .expt_hist , md_tile_starts , md_tile_infos , # outputs
86
112
BLOCK = HIST2_BLOCK_M , # optimization parameters
87
113
TILE_DIM = block_m , # constants
88
- )
114
+ num_warps = 4 )
89
115
return ExptData (md_hist , md_offs , md_offs_sum , md_tile_infos , metadata )
0 commit comments