@@ -14,41 +14,35 @@ class ExptData:
 
 
 @triton.jit
-def _memset_metadata(Metadata, metadata_size, BLOCK: tl.constexpr):
+def _matmul_metadata_memset(Hist, n_expts_tot, MDHist, MDTokStarts, MDTileStarts, MDTileInfo, md_n_tiles,
+                            BLOCK: tl.constexpr, TILE_DIM: tl.constexpr):
     pid = tl.program_id(0)
+    # if pid == 0 - initialize cumsums
+    if pid == 0:
+        x_tok = tl.zeros([BLOCK], dtype=MDTokStarts.dtype.element_ty)
+        x_tile = tl.zeros([BLOCK], dtype=MDTileStarts.dtype.element_ty)
+        tl.store(MDTokStarts, 0)
+        tl.store(MDTileStarts, 0)
+        for i in range(0, n_expts_tot, BLOCK):
+            offs_n = tl.arange(0, BLOCK) + i
+            mask = offs_n < n_expts_tot
+            hist_tok = tl.load(Hist + offs_n, mask=mask)
+            hist_tile = tl.cdiv(hist_tok, TILE_DIM)
+            tok_starts = tl.cumsum(hist_tok, 0) + x_tok
+            x_tok += tl.sum(hist_tok, 0).to(MDTokStarts.dtype.element_ty)
+            tile_starts = tl.cumsum(hist_tile, 0) + x_tile
+            x_tile += tl.sum(hist_tile, 0).to(MDTileStarts.dtype.element_ty)
+            tl.store(MDHist + offs_n, hist_tok, mask=mask)
+            tl.store(MDTokStarts + 1 + offs_n, tok_starts, mask=mask)
+            tl.store(MDTileStarts + 1 + offs_n, tile_starts, mask=mask)
+
+    # initialize block data
     offs = pid * BLOCK + tl.arange(0, BLOCK)
-    tl.store(Metadata + offs, 0xffffffff, mask=offs < metadata_size)
+    tl.store(MDTileInfo + offs, 0xffffffff, mask=offs < md_n_tiles)
 
 
 @triton.jit
-def _compute_metadata_1(Hist, n_expts_tot, MDHist, MDTokStarts, MDTileStarts, MDTileInfo, N_EXPTS_PAD: tl.constexpr,
-                        BLOCK: tl.constexpr, TILE_DIM: tl.constexpr):
-
-    BLOCK_N: tl.constexpr = 1024
-
-    x_tok = tl.zeros([BLOCK_N], dtype=MDTokStarts.dtype.element_ty)
-    x_tile = tl.zeros([BLOCK_N], dtype=MDTileStarts.dtype.element_ty)
-
-    tl.store(MDTokStarts, 0)
-    tl.store(MDTileStarts, 0)
-
-    for i in range(0, n_expts_tot, BLOCK_N):
-        offs_n = tl.arange(0, BLOCK_N) + i
-        mask = offs_n < n_expts_tot
-        hist_tok = tl.load(Hist + offs_n, mask=mask)
-        hist_tile = tl.cdiv(hist_tok, TILE_DIM)
-        tok_starts = tl.cumsum(hist_tok, 0) + x_tok
-        x_tok += tl.sum(hist_tok, 0)
-        tile_starts = tl.cumsum(hist_tile, 0) + x_tile
-        x_tile += tl.sum(hist_tile, 0)
-        tl.store(MDHist + offs_n, hist_tok, mask=mask)
-        tl.store(MDTokStarts + 1 + offs_n, tok_starts, mask=mask)
-        tl.store(MDTileStarts + 1 + offs_n, tile_starts, mask=mask)
-
-
-@triton.jit
-def _compute_metadata_2(Hist, n_expts_tot, MDHist, MDTokStarts, MDTileStarts, MDTileInfo, N_EXPTS_PAD: tl.constexpr,
-                        BLOCK: tl.constexpr, TILE_DIM: tl.constexpr):
+def _matmul_metadata_compute(Hist, MDTileStarts, MDTileInfo, BLOCK: tl.constexpr, TILE_DIM: tl.constexpr):
 
     expt_id = tl.program_id(0)
     n_tokens = tl.load(Hist + expt_id)
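
For reference: the fused `_matmul_metadata_memset` kernel above now does two things. Program 0 walks the expert histogram in chunks of `BLOCK`, writing the cumulative token starts and tile starts (each with a leading zero, which is why the stores go to `MDTokStarts + 1 + offs_n` and `MDTileStarts + 1 + offs_n`), while every program, including program 0, fills its slice of the tile-info array with the sentinel `0xffffffff`. Below is a minimal host-side sketch of the same offset arithmetic, assuming a 1-D integer `expt_hist` tensor and the `block_m` tile size used by `compute_metadata`; the helper name is hypothetical and not part of the patch.

import torch

def reference_offsets(expt_hist: torch.Tensor, block_m: int):
    # hypothetical host-side mirror of program 0 of _matmul_metadata_memset
    # token starts: [0, cumsum(hist)] -- n_expts_tot + 1 entries
    cum_tok = torch.cumsum(expt_hist, 0)
    tok_starts = torch.cat([cum_tok.new_zeros(1), cum_tok])
    # tile starts: the same prefix sum over ceil(hist / block_m)
    tiles_per_expt = (expt_hist + block_m - 1) // block_m
    cum_tile = torch.cumsum(tiles_per_expt, 0)
    tile_starts = torch.cat([cum_tile.new_zeros(1), cum_tile])
    return tok_starts, tile_starts
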
@@ -75,26 +69,21 @@ def compute_metadata(routing_data, n_rows, block_m):
         grid_m = n_rows
     else:
         grid_m = n_expts_tot - 1 - ((n_expts_tot - n_rows - 1) // block_m)
-    n_expts_pad = cdiv(n_expts_tot, 128) * 128
     metadata_size = 3 * n_expts_tot + 2 + grid_m
     metadata = torch.empty(metadata_size, dtype=torch.int32, device=device)
     md_hist = metadata[:n_expts_tot]
-    md_tok_starts = metadata[n_expts_tot:n_expts_tot * 2 + 1]
+    md_offs = metadata[n_expts_tot:n_expts_tot * 2 + 1]
+    md_offs_sum = metadata[3 * n_expts_tot + 2 - 1]
     md_tile_starts = metadata[n_expts_tot * 2 + 1:n_expts_tot * 3 + 2]
     md_tile_infos = metadata[n_expts_tot * 3 + 2:]
-    _memset_metadata[(cdiv(metadata_size, MEMSET_BLOCK), )](
-        metadata, metadata_size,  # inputs
-        BLOCK=MEMSET_BLOCK  # optimization parameters
+    _matmul_metadata_memset[(cdiv(metadata_size, MEMSET_BLOCK), )](
+        routing_data.expt_hist, n_expts_tot, md_hist, md_offs, md_tile_starts, md_tile_infos, md_tile_infos.shape[0],
+        BLOCK=MEMSET_BLOCK,  # optimization parameters
+        TILE_DIM=block_m,  # constants
+    )
+    _matmul_metadata_compute[(n_expts_tot, )](
+        routing_data.expt_hist, md_tile_starts, md_tile_infos,  # outputs
+        BLOCK=HIST2_BLOCK_M,  # optimization parameters
+        TILE_DIM=block_m,  # constants
     )
-    for kernel, num_blocks in [(_compute_metadata_1, 1), (_compute_metadata_2, n_expts_tot)]:
-        kernel[(num_blocks, )](
-            routing_data.expt_hist, n_expts_tot,  # inputs
-            md_hist, md_tok_starts, md_tile_starts, md_tile_infos,  # outputs
-            BLOCK=HIST2_BLOCK_M,  # optimization parameters
-            N_EXPTS_PAD=n_expts_pad, TILE_DIM=block_m,  # constants
-        )
-    hist = metadata[:n_expts_tot]
-    offs = metadata[n_expts_tot:2 * n_expts_tot + 1]
-    offs_sum = metadata[3 * n_expts_tot + 2 - 1]
-    blocks = metadata[n_expts_tot + 2 * (n_expts_tot + 1):]
-    return ExptData(hist, offs, offs_sum, blocks, metadata)
+    return ExptData(md_hist, md_offs, md_offs_sum, md_tile_infos, metadata)
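
After the change, everything the matmul needs is carried in a single flat int32 buffer, and the fields of `ExptData` are views into it. A hedged worked example of the layout, using illustrative sizes that are not taken from the patch (n_expts_tot = 4, grid_m = 6):

# metadata_size  = 3 * n_expts_tot + 2 + grid_m = 12 + 2 + 6 = 20
# md_hist        = metadata[0:4]    per-expert token counts copied from expt_hist
# md_offs        = metadata[4:9]    token starts: n_expts_tot + 1 entries, [0, cumsum]
# md_tile_starts = metadata[9:14]   tile starts:  n_expts_tot + 1 entries, [0, cumsum]
# md_tile_infos  = metadata[14:20]  one int32 slot per output tile, memset to 0xffffffff
# md_offs_sum    = metadata[13]     a 0-d view of the last tile-start entry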