
Commit d5156d7

[KERNELS] Change routing code to avoid storage(). (#8357)
(Calls to storage() caused issues when using FakeTensor.)
1 parent 5201154 commit d5156d7

File tree

1 file changed: +8 -8 lines

python/triton_kernels/triton_kernels/routing.py

Lines changed: 8 additions & 8 deletions
@@ -128,16 +128,15 @@ def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix):
         expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
         combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device)
         gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)
-        token_offs_combined = empty_aligned((block_m_num + 1, n_expts_tot + 1), torch.int32, device, MEMSET_BLOCK_A)
-        block_pid_map = empty_aligned((block_m_num, max_n_tiles(n_expts_tot, n_gates_pad)), torch.int32, device,
-                                      MEMSET_BLOCK_A)
+        token_offs_combined, _ = empty_aligned((block_m_num + 1, n_expts_tot + 1), torch.int32, device, MEMSET_BLOCK_A)
+        block_pid_map, block_pid_map_n_elts = empty_aligned((block_m_num, max_n_tiles(n_expts_tot, n_gates_pad)),
+                                                            torch.int32, device, MEMSET_BLOCK_A)
         # slice padded allocations
         combine_indx = combined_indx[:n_gates_pad]
         dispatch_indx = combined_indx[n_gates_pad:]
         token_offs_raw, token_offs_pad = token_offs_combined[0], token_offs_combined[1:]
 
         # grid sizes
-        block_pid_map_n_elts = block_pid_map.untyped_storage().size() // block_pid_map.dtype.itemsize
         blocks1a = exact_div(block_pid_map_n_elts, MEMSET_BLOCK_A) + token_offs_combined.shape[0]
         blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
         blocks2a = n_expts_tot * token_offs_pad.shape[0]
@@ -198,7 +197,7 @@ def empty_aligned(shape, dtype, device, pad_size):
     pad = lambda x: cdiv(x, pad_size) * pad_size
     ret = torch.empty((*shape[:-1], pad(shape[-1])), dtype=dtype, device=device)
     ret_slices = (*[slice(None)] * (len(shape) - 1), slice(0, shape[-1]))
-    return ret[ret_slices]
+    return ret[ret_slices], ret.numel()
 
 
 def max_n_tiles(n_expts_tot, n_gates):
@@ -217,10 +216,11 @@ def compute_expt_data(expt_hist, n_expts_tot, n_gates):
     MEMSET_BLOCK = 512
     dtype = torch.int32
     device = expt_hist.device
-    token_offs_combined = empty_aligned((block_m_num + 1, n_expts_tot + 1), dtype, device, MEMSET_BLOCK)
-    block_pid_map = empty_aligned((block_m_num, max_n_tiles(n_expts_tot, n_gates)), dtype, device, MEMSET_BLOCK)
+    token_offs_combined, _ = empty_aligned((block_m_num + 1, n_expts_tot + 1), dtype, device, MEMSET_BLOCK)
+    block_pid_map, block_pid_map_size = empty_aligned((block_m_num, max_n_tiles(n_expts_tot, n_gates)), dtype, device,
+                                                      MEMSET_BLOCK)
     token_offs_raw, token_offs_pad = token_offs_combined[0], token_offs_combined[1:]
-    n_memset_blocks = exact_div(block_pid_map.storage().size(), MEMSET_BLOCK)
+    n_memset_blocks = exact_div(block_pid_map_size, MEMSET_BLOCK)
 
     _expt_data_memset[(token_offs_combined.shape[0] + n_memset_blocks, )](
         expt_hist, n_expts_tot, #
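
For context, here is a minimal, self-contained sketch of the pattern this commit adopts: empty_aligned returns the padded element count alongside the trimmed view, so callers size their memset grids from that count instead of querying storage()/untyped_storage() on the tensor, which is what tripped up FakeTensor. The empty_aligned body mirrors the patch above; cdiv, exact_div, and the example shape and MEMSET_BLOCK values are simplified stand-ins, not the exact helpers from routing.py.

# Sketch of the "return the padded numel instead of reading storage()" pattern.
import torch


def cdiv(x, y):
    # Ceiling division, as used for padding and grid-size math.
    return (x + y - 1) // y


def exact_div(x, y):
    # Division that asserts there is no remainder.
    assert x % y == 0
    return x // y


def empty_aligned(shape, dtype, device, pad_size):
    # Pad the last dimension up to a multiple of pad_size, then return the view
    # trimmed back to the requested shape together with the *padded* element
    # count (ret.numel()), so callers never touch the underlying storage.
    pad = lambda x: cdiv(x, pad_size) * pad_size
    ret = torch.empty((*shape[:-1], pad(shape[-1])), dtype=dtype, device=device)
    ret_slices = (*[slice(None)] * (len(shape) - 1), slice(0, shape[-1]))
    return ret[ret_slices], ret.numel()


MEMSET_BLOCK = 512
# Example shapes only; routing.py derives these from the expert/gate counts.
block_pid_map, block_pid_map_size = empty_aligned((4, 1000), torch.int32, "cpu", MEMSET_BLOCK)
# Grid size comes from the returned count, not from block_pid_map.storage().
n_memset_blocks = exact_div(block_pid_map_size, MEMSET_BLOCK)
print(block_pid_map.shape, block_pid_map_size, n_memset_blocks)  # (4, 1000) 4096 8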

0 commit comments

Comments
 (0)