Commit 0b1cf48

[BENCH] Routing improvements (#7369)
This improves end-to-end routing performance by about 30%:

- fp32 logits: 22.8us --> 18.0us
- fp16 logits: 17.3us --> 12.2us

via a combination of several optimisations, including liberally fusing kernels in order to reduce the total number of launches from 7 to 4. There are now only four obligatory kernel launches:

- `_topk_forward`
- `_sum_bitmatrix_rows`
- `_combined_routing_memset`
- `_combined_routing_compute`

although in an expert-sharding world, extra kernels are inserted between `_topk_forward` and `_sum_bitmatrix_rows` to mutate the bitmatrix and update the other data structures produced by `_topk_forward`.
1 parent 9af26ee commit 0b1cf48
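
The fused launches all follow the same "trampoline" pattern: a single 1-D grid covers several small workloads, and each program branches on `tl.program_id(0)` to decide which job it performs. Below is a minimal, self-contained sketch of that pattern; the kernel and variable names are hypothetical and not the ones added by this commit.

```python
# Minimal sketch of the pid-trampoline pattern used by the combined kernels.
# All names here are illustrative, not the ones added in this commit.
import torch
import triton
import triton.language as tl


@triton.jit
def _fill_sentinel(Out, n, sentinel, BLOCK: tl.constexpr, pid):
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    tl.store(Out + offs, sentinel, mask=offs < n)


@triton.jit
def _iota(Out, n, BLOCK: tl.constexpr, pid):
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    tl.store(Out + offs, offs, mask=offs < n)


@triton.jit
def _combined_trampoline(A, n_a, B, n_b, blocks_a, sentinel, BLOCK: tl.constexpr):
    # One launch, two workloads: low pids handle A, the remaining pids handle B.
    pid = tl.program_id(0)
    if pid < blocks_a:
        _fill_sentinel(A, n_a, sentinel, BLOCK, pid)
    else:
        _iota(B, n_b, BLOCK, pid - blocks_a)


a = torch.empty(1000, dtype=torch.int32, device="cuda")
b = torch.empty(3000, dtype=torch.int32, device="cuda")
BLOCK = 512
blocks_a = triton.cdiv(a.numel(), BLOCK)
blocks_b = triton.cdiv(b.numel(), BLOCK)
_combined_trampoline[(blocks_a + blocks_b, )](a, a.numel(), b, b.numel(), blocks_a, -1, BLOCK=BLOCK)
```

The combined kernels in this commit dispatch over more sub-tasks than this, but the structure is identical: the host computes how many programs each sub-task needs and passes those block counts (`blocks1a`, `blocks2a`, ...) into the kernel so it can slice up the pid range.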

4 files changed: +157 / -79 lines


python/triton_kernels/tests/test_routing.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -5,7 +5,7 @@
 from triton_kernels.testing import assert_equal
 
 
-def init_data(n_tokens, n_expts_tot, dtype=torch.float32, device="cuda"):
+def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"):
     logits = torch.randn((n_tokens, n_expts_tot), dtype=dtype, device=device, requires_grad=True)
     return logits
 
@@ -26,7 +26,7 @@ def test_op(n_tokens_pad, n_tokens_raw, n_expts_tot, n_expts_act, sm_first, use_
     else:
         n_routing_rows = torch.tensor([n_tokens_raw], dtype=torch.int32, device=device)
     n_gates_raw = n_tokens_raw * n_expts_act
-    tri_logits = init_data(n_tokens_pad, n_expts_tot, device=device).detach()
+    tri_logits = init_data(n_tokens_pad, n_expts_tot, device=device, dtype=torch.float32).detach()
     tri_logits[n_tokens_raw:, :] = float("inf")  # should not be used
     tri_logits = tri_logits.requires_grad_(True)
     ref_logits = tri_logits.clone().detach().requires_grad_(True)
```

python/triton_kernels/triton_kernels/routing.py

Lines changed: 80 additions & 37 deletions

```diff
@@ -1,9 +1,8 @@
 import torch
 import triton
 from dataclasses import dataclass, field
-from .routing_details._routing_compute import _routing_memset_indx
-from .routing_details._routing_compute import _routing_compute_indx_offs
-from .routing_details._routing_compute import _routing_compute_indx
+from .routing_details._routing_compute import _combined_routing_compute
+from .routing_details._routing_compute import _combined_routing_memset
 from .routing_details._routing_compute import _routing_clear_bitmatrix
 from .routing_details._expt_data import _expt_data_memset
 from .routing_details._expt_data import _expt_data_compute
@@ -115,32 +114,42 @@ def forward(ctx, expt_scal, expt_indx, bitmatrix):
         topk_indx = combined_indx[:n_gates_pad]
         gate_indx = combined_indx[n_gates_pad:]
         gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)
-        _routing_memset_indx[(cdiv(n_gates_pad * 2, MEMSET_BLOCK) + 1, )](
+
+        token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1a, blocks2a, MEMSET_BLOCK_A, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
+            hist, n_expts_tot, n_gates_pad)
+
+        blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
+        blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M)
+
+        _combined_routing_memset[(blocks1a + blocks1b, )](
             combined_indx, n_gates_pad * 2, -1, MEMSET_BLOCK, hist,  #
-            expt_offs, hist.shape[0], BLOCK_N=512  #
-        )
-        _routing_compute_indx_offs[(n_expts_tot, )](
-            expt_offs, partial_hist,  # inputs
+            expt_offs, hist.shape[0], n_expts_tot, partial_hist,  # inputs
             partial_hist.shape[0], partial_hist.stride(0), partial_hist.stride(1),  # outputs
-            BLOCK_M=INDX_OFFS_BLOCK_M,  # tunable parameters
+            token_offs_combined, token_offs_combined.stride(0),  #
+            blocks1a, block_pid_map,  #
+            block_m_log2_start, SIZES=block_m_num, BLOCK_A=MEMSET_BLOCK_A,  # optimization parameters
+            BLOCK_N=512, BLOCK_M=INDX_OFFS_BLOCK_M,  # tunable parameters
         )
+
         indx_offs = partial_hist
-        _routing_compute_indx[(cdiv(n_tokens_pad, HIST_BLOCK_M), )](
+
+        _combined_routing_compute[(blocks2a + blocks2b, )](
             topk_indx, gate_indx, gate_scal,  # outputs
             expt_scal, expt_indx, indx_offs, indx_offs.stride(0), indx_offs.stride(1),  # inputs
-            n_tokens_pad, n_tokens_raw,  # input shape
-            BLOCK_M=HIST_BLOCK_M,  # tunable parameters
-            N_EXPTS_ACT=n_expts_act,  # constants
-            num_warps=1 if HIST_BLOCK_M * n_expts_act // 32 < 4 else 4  #
+            expt_offs, n_tokens_pad, n_tokens_raw,  # input shape
+            HIST_BLOCK_M, n_expts_act,  # constants
+            hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0),  # outputs
+            block_m_log2_start, block_m_num, HIST2_BLOCK_M, blocks2a,  # etc.
         )
+
         ctx.n_tokens_raw = n_tokens_raw
         ctx.n_tokens_pad = n_tokens_pad
         ctx.n_expts_act = n_expts_act
         ctx.save_for_backward(gate_indx)
-        return hist, topk_indx, gate_indx, gate_scal
+        return hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map
 
     @staticmethod
-    def backward(ctx, _0, _1, _2, dgate_scal):
+    def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5):
         (gate_indx, ) = ctx.saved_tensors
         dgate_scal = dgate_scal[gate_indx]
         dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act)
@@ -193,16 +202,17 @@ def log2_power_of_two(x):
     return x.bit_length() - 1
 
 
-def compute_expt_data(expt_hist, n_expts_tot, n_gates):
-    if expt_hist is None:
-        return ExptData(None, None, None, None)
-    MEMSET_BLOCK = 128
+block_m_log2_start = 4
+
+
+def _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates):
+
+    MEMSET_BLOCK = 512
     HIST2_BLOCK_M = 512
     device = expt_hist.device
     n_expts_tot = n_expts_tot
     cdiv = triton.cdiv
     # block_ms are all powers-of-two between 16 and 128 (inclusive)
-    block_m_log2_start = 4
     block_m_log2_end = 9 if is_hip() else 8
     block_m_num = block_m_log2_end - block_m_log2_start
     if n_gates <= n_expts_tot:
@@ -212,26 +222,53 @@ def compute_expt_data(expt_hist, n_expts_tot, n_gates):
     # allocate memory
     pad = lambda x: cdiv(x, MEMSET_BLOCK) * MEMSET_BLOCK
     dtype = torch.int32
-    token_offs_raw = torch.empty((n_expts_tot + 1, ), dtype=dtype, device=device)
-    token_offs_pad = torch.empty((block_m_num, pad(n_expts_tot + 1)), dtype=dtype, device=device)
+
+    token_offs_combined = torch.empty((block_m_num + 1, pad(n_expts_tot + 1)), dtype=dtype, device=device)
+
+    token_offs_raw = token_offs_combined[0][:n_expts_tot + 1]
+    token_offs_pad = token_offs_combined[1:]
+
     block_pid_map = torch.empty((block_m_num, pad(max_n_tiles)), dtype=dtype, device=device)
+    memset_grid = torch.numel(block_pid_map) // MEMSET_BLOCK  # exact division
     # compute outputs
     token_offs_pad = token_offs_pad[:, :n_expts_tot + 1]
     block_pid_map = block_pid_map[:, :max_n_tiles]
-    memset_grid = cdiv(block_pid_map.shape[1], MEMSET_BLOCK) + 1
-    _expt_data_memset[(memset_grid, block_m_num)](
-        expt_hist, n_expts_tot, token_offs_raw,  #
-        token_offs_pad, token_offs_pad.stride(0),  #
-        block_pid_map, block_pid_map.stride(0),  #
-        block_m_log2_start, BLOCK=MEMSET_BLOCK,  # optimization parameters
-        num_warps=1)
-    _expt_data_compute[(n_expts_tot, block_m_num)](
+
+    blocks1 = memset_grid + block_m_num + 1
+    blocks2 = n_expts_tot * block_m_num
+
+    return token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num
+
+
+def _unpack_into_dict(x):
+
+    block_m_log2_end = block_m_log2_start + x.shape[0]
+    x = {2**j: x[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}
+    return x
+
+
+def compute_expt_data(expt_hist, n_expts_tot, n_gates):
+
+    if expt_hist is None:
+        return ExptData(None, None, None, None)
+
+    # this just computes the kernel arguments:
+    token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal(
+        expt_hist, n_expts_tot, n_gates)
+
+    _expt_data_memset[(blocks1, )](
+        expt_hist, n_expts_tot,  #
+        token_offs_combined, token_offs_combined.stride(0),  #
+        block_pid_map,  #
+        block_m_log2_start, SIZES=block_m_num, BLOCK=MEMSET_BLOCK,  # optimization parameters
+        num_warps=4)
+    _expt_data_compute[(blocks2, )](
         expt_hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0),  # outputs
-        block_m_log2_start, BLOCK=HIST2_BLOCK_M,  # optimization parameters
+        block_m_log2_start, SIZES=block_m_num, BLOCK=HIST2_BLOCK_M,  # optimization parameters
         num_warps=4)
-    # unpack into datastructure
-    token_offs_pad = {2**j: token_offs_pad[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}
-    block_pid_map = {2**j: block_pid_map[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}
+
+    token_offs_pad = _unpack_into_dict(token_offs_pad)
+    block_pid_map = _unpack_into_dict(block_pid_map)
     return ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map)
 
 
@@ -249,12 +286,18 @@ def routing(logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1,
     # mutate bitmatrix
     if simulated_ep > 1:
         expt_scal, expt_indx, bitmatrix = prune_routing(expt_scal, expt_indx, bitmatrix, simulated_ep)
-    hist, topk_indx, gate_indx, gate_scal = sort_tokens(expt_scal, expt_indx, bitmatrix)
+    hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map = sort_tokens(
+        expt_scal, expt_indx, bitmatrix)
+
+    token_offs_pad = _unpack_into_dict(token_offs_pad)
+    block_pid_map = _unpack_into_dict(block_pid_map)
+    expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map)
+
     # pack the matmul data structure
     n_expts_tot = logits.shape[-1] // simulated_ep
     gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx)
     scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx)
-    expt_data = compute_expt_data(hist, n_expts_tot, topk_indx.numel())
+
    return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx
 
 
```
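
As an aside, the `_unpack_into_dict` helper introduced above simply re-keys the rows of the stacked `(block_m_num, ...)` buffers by their block_m value. A standalone illustration of the keying scheme (the tensor contents here are made up; only the dictionary construction mirrors routing.py):

```python
# Illustration of the _unpack_into_dict keying scheme from routing.py.
import torch

block_m_log2_start = 4                         # block_ms start at 2**4 == 16

x = torch.zeros((4, 8), dtype=torch.int32)     # e.g. block_m_num == 4 stacked rows
block_m_log2_end = block_m_log2_start + x.shape[0]
unpacked = {2**j: x[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))}

print(sorted(unpacked.keys()))                 # [16, 32, 64, 128] -- one row per block_m
```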

python/triton_kernels/triton_kernels/routing_details/_expt_data.py

Lines changed: 22 additions & 27 deletions

```diff
@@ -8,50 +8,45 @@ def _cdiv_pow2(n, log2_k):
 
 
 @triton.jit
-def _expt_data_memset(Hist, n_expts_tot, MDTokStarts, MDTileStarts, tile_starts_stridem, MDTileInfo, tile_infos_stridem,
-                      first_tile_dim_log2, BLOCK: tl.constexpr):
-    pid_n = tl.program_id(0)
-    pid_m = tl.program_id(1)
+def _expt_data_memset(Hist, n_expts_tot, MDStarts, tile_starts_stridem, MDTileInfo, first_tile_dim_log2,
+                      SIZES: tl.constexpr, BLOCK: tl.constexpr):
 
-    tile_dim_log2 = first_tile_dim_log2 + pid_m
-    # if pid == 0 - initialize cumsums
-    if pid_n == 0:
-        MDTileStarts += pid_m * tile_starts_stridem
+    pid = tl.program_id(0)
 
-        x_tok = tl.zeros([BLOCK], dtype=MDTokStarts.dtype.element_ty)
-        x_tile = tl.zeros([BLOCK], dtype=MDTileStarts.dtype.element_ty)
+    if pid <= SIZES:
 
-        Tok_ptrs = MDTokStarts + tl.arange(0, BLOCK)
-        Tile_ptrs = MDTileStarts + tl.arange(0, BLOCK)
+        MDStarts += pid * tile_starts_stridem
+        x_tile = tl.zeros([BLOCK], dtype=MDStarts.dtype.element_ty)
+        Tile_ptrs = MDStarts + tl.arange(0, BLOCK)
+        tile_dim_log2 = tl.where(pid == 0, 0, pid + first_tile_dim_log2 - 1)
 
         for i in range(0, n_expts_tot + 1, BLOCK):
+
             offs_n = tl.arange(0, BLOCK) + i
             mask_n0 = offs_n < n_expts_tot
-            mask_n1 = offs_n < n_expts_tot + 1
             hist_tok = tl.load(Hist + offs_n, mask=mask_n0, other=0)
             hist_tile = _cdiv_pow2(hist_tok, tile_dim_log2)
-            tok_starts = tl.cumsum(hist_tok, 0) + x_tok
-            x_tok += tl.sum(hist_tok, 0).to(MDTokStarts.dtype.element_ty)
-            tile_starts = tl.cumsum(hist_tile, 0) + x_tile
-            x_tile += tl.sum(hist_tile, 0).to(MDTileStarts.dtype.element_ty)
 
-            tl.store(Tok_ptrs, tok_starts - hist_tok, mask=mask_n1)
-            tl.store(Tile_ptrs, tile_starts - hist_tile, mask=mask_n1)
-
-            Tok_ptrs += BLOCK
+            tile_starts = tl.cumsum(hist_tile, 0) + x_tile
+            x_tile += tl.sum(hist_tile, 0).to(MDStarts.dtype.element_ty)
+            tl.store(Tile_ptrs, tile_starts - hist_tile)
             Tile_ptrs += BLOCK
 
     else:
-        MDTileInfo += pid_m * tile_infos_stridem
-        TileInfoOut = MDTileInfo + (pid_n - 1) * BLOCK + tl.arange(0, BLOCK)
+
+        pid -= (SIZES + 1)
+        TileInfoOut = MDTileInfo + pid * BLOCK + tl.arange(0, BLOCK)
         tl.store(TileInfoOut, 0xffffffff)
 
 
 @triton.jit
 def _expt_data_compute(Hist, MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2,
-                       BLOCK: tl.constexpr):
-    expt_id = tl.program_id(0)
-    buff_id = tl.program_id(1)
+                       SIZES: tl.constexpr, BLOCK: tl.constexpr):
+
+    pid = tl.program_id(0)
+
+    expt_id = pid // SIZES
+    buff_id = pid % SIZES
 
     MDTileStarts += buff_id * tile_starts_stridem
     MDTileInfo += buff_id * tile_info_stridem
@@ -62,7 +57,7 @@ def _expt_data_compute(Hist, MDTileStarts, tile_starts_stridem, MDTileInfo, tile
 
     tile_off = tl.load(MDTileStarts + expt_id)
     MDTileInfo += tile_off
-    # MDTileInfo += tl.load(MDTilesStart + expt_id)
+
     for block_off in range(0, n_blocks, BLOCK):
         block_offs = block_off + tl.arange(0, BLOCK)
         data = (block_offs << 16) + expt_id
```

python/triton_kernels/triton_kernels/routing_details/_routing_compute.py

Lines changed: 53 additions & 13 deletions

```diff
@@ -1,6 +1,8 @@
 import triton
 import triton.language as tl
 
+from ._expt_data import _expt_data_compute, _expt_data_memset
+
 
 @triton.jit
 def _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size,  # histogram
@@ -18,13 +20,10 @@ def _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size,  # histog
 
 
 @triton.jit
-def _routing_compute_indx_offs(TokensStart, PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M: tl.constexpr):
-    expt_id = tl.program_id(0)
+def _routing_compute_indx_offs(PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M: tl.constexpr, expt_id):
     offs_m = tl.arange(0, BLOCK_M)
-    # initialize first row of the output
-    start = tl.load(TokensStart + expt_id)
     # iterate over input data
-    curr_sum = start
+    curr_sum = 0
     for _ in range(0, shape_pm, BLOCK_M):
         offs = offs_m * stride_pm + expt_id * stride_pn
         curr = tl.load(PartialHist + offs, mask=offs_m < shape_pm)
@@ -47,10 +46,10 @@ def _keyed_add(x, y):
 
 
 @triton.jit
-def _routing_compute_indx(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, stride_pn,
-                          n_tokens_pad, NTokensRaw, BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr):
+def _routing_compute_indx(pid_m, GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm,
+                          stride_pn, TokensStart, n_tokens_pad, NTokensRaw, BLOCK_M: tl.constexpr,
+                          N_EXPTS_ACT: tl.constexpr):
 
-    pid_m = tl.program_id(0)
     n_tokens = n_tokens_pad
     if NTokensRaw is not None:
         n_tokens = tl.load(NTokensRaw)
@@ -75,14 +74,31 @@ def _routing_compute_indx(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx,
     expts_and_inclusive_run_lengths = tl.associative_scan(x, 0, _keyed_add)
     exclusive_run_lengths = (expts_and_inclusive_run_lengths - 1) & 0xffff
 
-    gates = tl.load(PartialOffs + pid_m * stride_pm + expert * stride_pn, mask=(expert != 0xffff))
+    gates = tl.load(PartialOffs + pid_m * stride_pm + expert * stride_pn, mask=mask)
+    gates += tl.load(TokensStart + expert, mask=mask)
     gates += exclusive_run_lengths
 
     tl.store(ScatterIndx + offs, gates, mask=mask)
     tl.store(GatherIndx + gates, offs, mask=mask)
     tl.store(GateScal + gates, gate_scal, mask=mask)
 
 
+@triton.jit
+def _combined_routing_compute(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, stride_pn,
+                              TokensStart, n_tokens_pad, NTokensRaw, BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr,
+                              Hist, MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem,
+                              first_tile_dim_log2, SIZES: tl.constexpr, BLOCK: tl.constexpr, blocks2a):
+
+    pid = tl.program_id(0)
+    if pid < blocks2a:
+        _expt_data_compute(Hist, MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2,
+                           SIZES, BLOCK)
+    else:
+        pid -= blocks2a
+        _routing_compute_indx(pid, GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm,
+                              stride_pn, TokensStart, n_tokens_pad, NTokensRaw, BLOCK_M, N_EXPTS_ACT)
+
+
 @triton.jit
 def _routing_clear_bitmatrix(Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff, BLOCK_N: tl.constexpr):
     pid_m = tl.program_id(0)
@@ -98,13 +114,37 @@ def _routing_clear_bitmatrix(Bitmatrix, stride_bm, stride_bn, shape_bn, cutoff,
 
 
 @triton.jit
-def _routing_memset_indx(Indx, size, sentinel, BLOCK: tl.constexpr, ExpertHist, FinalExpertOffs, hist_size,
-                         BLOCK_N: tl.constexpr):
+def _combined_routing_memset(Indx, size, sentinel, BLOCK: tl.constexpr, ExpertHist, FinalExpertOffs, hist_size,
+                             n_expts_tot, PartialHist, shape_pm, stride_pm, stride_pn, MDStarts, tile_starts_stridem,
+                             blocks1a, MDTileInfo, first_tile_dim_log2, SIZES: tl.constexpr, BLOCK_A: tl.constexpr,
+                             BLOCK_N: tl.constexpr, BLOCK_M: tl.constexpr):
+    """
+    This kernel essentially combines 6 different pieces of functionality,
+    statically branching on the value of tl.program_id(0) to decide which
+    codepath to take.
+
+    pid == 0: create the token cumsum
+    1 <= pid <= SIZES: create a tile cumsum
+    SIZES < pid < blocks1a: initialise MDTileInfo to 0xffffffff
+    blocks1a <= pid < blocks1a + n_expts_tot: compute_indx_offs
+    pid == blocks1a + n_expts_tot: compute_expt_offs
+    pid > blocks1a + n_expts_tot: initialise Indx to sentinel
+
+    As each of these is a relatively trivial workload, launching them from
+    this single trampoline is beneficial as they can execute on different
+    streaming multiprocesses in parallel.
+    """
+
     pid = tl.program_id(0)
 
-    if pid == 0:
+    if pid < blocks1a:
+        _expt_data_memset(ExpertHist, n_expts_tot, MDStarts, tile_starts_stridem, MDTileInfo, first_tile_dim_log2,
+                          SIZES, BLOCK_A)
+    elif pid == n_expts_tot + blocks1a:
         _routing_compute_expt_offs(ExpertHist, FinalExpertOffs, hist_size, BLOCK_N)
+    elif pid < n_expts_tot + blocks1a:
+        _routing_compute_indx_offs(PartialHist, shape_pm, stride_pm, stride_pn, BLOCK_M, pid - blocks1a)
     else:
-        offs = (pid - 1) * BLOCK + tl.arange(0, BLOCK)
+        offs = (pid - n_expts_tot - blocks1a - 1) * BLOCK + tl.arange(0, BLOCK)
         mask = offs < size
         tl.store(Indx + offs, sentinel, mask=mask)
```
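
Tying this back to the launch site in routing.py: `_combined_routing_memset` runs on a 1-D grid of `blocks1a + blocks1b` programs, where `blocks1a` covers the fused `_expt_data_memset` work and `blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1` covers the remaining branches. A small sanity-check sketch with made-up sizes (only the two block-count formulas come from the diff):

```python
# Sanity-check sketch with illustrative numbers: count how many programs fall
# into each pid range described in the docstring of _combined_routing_memset.
from triton import cdiv

n_expts_tot, n_gates_pad = 128, 4096
MEMSET_BLOCK = 1024
block_m_num = 4                               # SIZES
memset_grid = 8                               # numel(block_pid_map) // MEMSET_BLOCK, made up here

blocks1a = memset_grid + block_m_num + 1      # _expt_data_memset portion of the grid
blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
grid = blocks1a + blocks1b

# pid ranges, mirroring the docstring:
#   [0, blocks1a)                       -> _expt_data_memset (cumsums + 0xffffffff fill)
#   [blocks1a, blocks1a + n_expts_tot)  -> _routing_compute_indx_offs
#   blocks1a + n_expts_tot              -> _routing_compute_expt_offs
#   (blocks1a + n_expts_tot, grid)      -> memset Indx to the sentinel
print(grid, blocks1a, blocks1b)
```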
