@@ -13,10 +13,9 @@ def _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, # histog
         offs_n = i * BLOCK_N + tl.arange(0, BLOCK_N)
         mask_n = offs_n < hist_size
         hist2 = tl.load(ExpertHist + offs_n, mask=mask_n)
-        tok_starts = tl.cumsum(hist2, 0) + x
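+        # "- hist2" turns the inclusive cumsum into an exclusive one: entry e now
+        # holds the start offset of expert e itself, so the explicit leading-zero
+        # store is no longer needed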
+        tok_starts = tl.cumsum(hist2, 0) - hist2 + x
         x += tl.sum(hist2, 0)
-        tl.store(FinalExpertOffs, 0)
-        tl.store(FinalExpertOffs + 1 + offs_n, tok_starts, mask=mask_n)
+        tl.store(FinalExpertOffs + offs_n, tok_starts, mask=mask_n)
         offs_n += BLOCK_N
 
 
@@ -52,51 +51,33 @@ def _keyed_add(x, y):
 
 
 @triton.jit
-def _count_previous(x):
-    """
-    Input  x : uint16[..., N]
-    Output y : uint32[..., N]
-    semantics : y[..., i] = sum_j((x[..., j] == x[..., i]) & (j < i))
-    credits: @apgoucher
-    """
+def _routing_compute_indx(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, n_gates,
+                          BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr):
 
-    BLOCK_N: tl.constexpr = x.shape[-1]  # summation axis
-    BATCHES: tl.constexpr = x.numel // BLOCK_N  # number of batches
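+    # each program owns BLOCK_M rows of ExptIndx, i.e. BLOCK_M * N_EXPTS_ACT gate slots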
+    pid_m = tl.program_id(0)
 
-    # reduce to two-dimensional case:
-    y = tl.reshape(x, [BATCHES, BLOCK_N]).to(tl.uint32)
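+    # local offsets are packed into the low 16 bits of a sort key below; the 2**15
+    # bound (rather than 2**16) presumably leaves headroom for the signed sort/scan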
+    tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768)
 
-    tl.static_assert(BLOCK_N <= 32768, "compute_run_lengths requires axis to have length <= 32768")
+    local_offs = tl.arange(0, N_EXPTS_ACT * BLOCK_M)
+    offs = pid_m * BLOCK_M * N_EXPTS_ACT + local_offs
+    expert = tl.load(ExptIndx + offs, mask=(offs < n_gates), other=-1).to(tl.uint32)
 
-    # sort (expert, position) ordered pairs to perform an argsort:
-    kv_pairs = ((y << 16) | tl.arange(0, BLOCK_N)[None, :]).to(tl.uint32)
-    sorted_kv_pairs = tl.sort(kv_pairs, 1)
+    # stable-sort by expert ID:
+    kv_pairs = ((expert << 16) | local_offs).to(tl.uint32)
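+    # ties in the expert bits are broken by the original position in the low bits,
+    # so the plain sort below behaves as a stable sort by expert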
+    kv_pairs = tl.sort(kv_pairs, 0)
+    expert = kv_pairs >> 16
+    offs = pid_m * BLOCK_M * N_EXPTS_ACT + (kv_pairs & 0xffff)
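+    # slots loaded with other=-1 (past n_gates) now carry expert ID 0xffff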
+    mask = expert != 0xffff
+    gate_scal = tl.load(ExptScal + offs, mask=mask)
 
     # compute run lengths in expert-sorted order:
-    x = (sorted_kv_pairs & 0xffff0000 | 0x00000001)
-    expts_and_inclusive_run_lengths = tl.associative_scan(x, 1, _keyed_add)
+    x = (kv_pairs & 0xffff0000 | 0x00000001)
+    expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
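+    # the scan is keyed on the expert bits, so subtracting 1 below yields each
+    # gate's 0-based rank within its expert run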
     exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xffff
 
-    # undo permutation by doing another sort
-    # TODO rewrite this when tl.scatter becomes available
-    kv_pairs = ((sorted_kv_pairs << 16) | exclusive_run_lengths).to(tl.uint32)
-    unsorted_run_lengths = tl.sort(kv_pairs) & 0xffff
-
-    res = tl.reshape(unsorted_run_lengths, x.shape)
-    return res
-
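+    # destination = per-(block, expert) base offset + rank within the expert run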
+    gates = tl.load(PartialOffs + pid_m * stride_pm + expert, mask=(expert != 0xffff))
+    gates += exclusive_run_lengths
 
-@triton.jit
-def _routing_compute_indx(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, n_gates,
-                          BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr):
-    pid_m = tl.program_id(0)
-    offs = pid_m * BLOCK_M * N_EXPTS_ACT + tl.arange(0, N_EXPTS_ACT * BLOCK_M)
-    mask = offs < n_gates
-    indx = tl.load(ExptIndx + offs, mask=mask)
-    mask = mask & (indx != -1)
-    gates = tl.load(PartialOffs + pid_m * stride_pm + indx, mask=mask)
-    gates += tl.reshape(_count_previous(indx), [BLOCK_M * N_EXPTS_ACT])
-    gate_scal = tl.load(ExptScal + offs, mask=mask)
     tl.store(ScatterIndx + offs, gates, mask=mask)
     tl.store(GatherIndx + gates, offs, mask=mask)
     tl.store(GateScal + gates, gate_scal, mask=mask)
@@ -117,15 +98,16 @@ def _routing_clear_bitmatrix(Bitmatrix, stride_bm, shape_bn, cutoff, BLOCK_N: tl
 
 
 @triton.jit
-def _routing_memset_indx(Indx0, Indx1, size, sentinel, BLOCK: tl.constexpr):
+def _routing_memset_indx(Indx, size, sentinel, BLOCK: tl.constexpr, ExpertHist, FinalExpertOffs, hist_size,
+                         BLOCK_N: tl.constexpr):
     pid = tl.program_id(0)
-    buf = tl.program_id(1)
-    offs = pid * BLOCK + tl.arange(0, BLOCK)
-    mask = offs < size
-    if buf == 0:
-        tl.store(Indx0 + offs, sentinel, mask=mask)
-    if buf == 1:
-        tl.store(Indx1 + offs, sentinel, mask=mask)
+
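+    # fused launch: program 0 computes the cumulative expert offsets while the
+    # remaining programs memset slices of the index buffer to the sentinel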
+    if pid == 0:
+        _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, BLOCK_N)
+    else:
+        offs = (pid - 1) * BLOCK + tl.arange(0, BLOCK)
+        mask = offs < size
+        tl.store(Indx + offs, sentinel, mask=mask)
 
 
 @dataclass
@@ -204,22 +186,15 @@ def routing(logits, n_expts_act, expt_indx=None, simulated_ep=1):
     # perform compaction to update expt_scal / expt_indx
     hist, partial_hist = sum(bitmatrix, partials_block_size=HIST_BLOCK_M, dim=0)
     # scratchpad
-    expt_offs = torch.empty(n_expts_tot + 1, dtype=torch.int32, device=device)
+    expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
     indx_offs = torch.empty((cdiv(n_tokens, HIST_BLOCK_M), n_expts_tot), dtype=torch.int32, device=device)
+    combined_indx = torch.empty(n_gates * 2, dtype=torch.int32, device=device)
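+    # single allocation backing topk_indx and gate_indx so one kernel can memset both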
     # output
-    topk_indx = torch.empty(n_gates, dtype=torch.int32, device=device)
-    gate_indx = torch.empty(n_gates, dtype=torch.int32, device=device)
+    topk_indx = combined_indx[:n_gates]
+    gate_indx = combined_indx[n_gates:]
     gate_scal = torch.empty(n_gates, dtype=logits.dtype, device=device)
-    _routing_memset_indx[(cdiv(n_gates, MEMSET_BLOCK), 2)](
-        topk_indx,
-        gate_indx,
-        n_gates,
-        -1,
-        BLOCK=MEMSET_BLOCK,
-    )
-    _routing_compute_expt_offs[(1, )](
-        hist, expt_offs, hist.shape[0], BLOCK_N=512  # tunable parameters
-    )
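+    # grid has one extra program: pid 0 computes expt_offs, the rest clear combined_indx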
+    _routing_memset_indx[(cdiv(n_gates * 2, MEMSET_BLOCK) + 1, )](combined_indx, n_gates * 2, -1, MEMSET_BLOCK, hist,
+                                                                  expt_offs, hist.shape[0], BLOCK_N=512)
     _routing_compute_indx_offs[(n_expts_tot, )](
         expt_offs, partial_hist,  # inputs
         indx_offs, partial_hist.shape[0], partial_hist.stride(0),  # outputs