@@ -128,16 +128,15 @@ def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix):
         expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device)
         combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device)
         gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device)
-        token_offs_combined = empty_aligned((block_m_num + 1, n_expts_tot + 1), torch.int32, device, MEMSET_BLOCK_A)
-        block_pid_map = empty_aligned((block_m_num, max_n_tiles(n_expts_tot, n_gates_pad)), torch.int32, device,
-                                      MEMSET_BLOCK_A)
+        token_offs_combined, _ = empty_aligned((block_m_num + 1, n_expts_tot + 1), torch.int32, device, MEMSET_BLOCK_A)
+        block_pid_map, block_pid_map_n_elts = empty_aligned((block_m_num, max_n_tiles(n_expts_tot, n_gates_pad)),
+                                                            torch.int32, device, MEMSET_BLOCK_A)
         # slice padded allocations
         combine_indx = combined_indx[:n_gates_pad]
         dispatch_indx = combined_indx[n_gates_pad:]
         token_offs_raw, token_offs_pad = token_offs_combined[0], token_offs_combined[1:]

         # grid sizes
-        block_pid_map_n_elts = block_pid_map.untyped_storage().size() // block_pid_map.dtype.itemsize
         blocks1a = exact_div(block_pid_map_n_elts, MEMSET_BLOCK_A) + token_offs_combined.shape[0]
         blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1
         blocks2a = n_expts_tot * token_offs_pad.shape[0]
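A note on why the removed storage query and the new return value agree: empty_aligned hands back a view of a larger padded allocation, so the view's own numel() undercounts, while the storage-based count matches the padded tensor's numel(), which is what the helper now returns alongside the view. A minimal check, with made-up shapes for illustration:

import torch

full = torch.empty((4, 1024), dtype=torch.int32)    # padded backing allocation
view = full[:, :1000]                                # the trimmed view handed to the caller
n_elts_old = view.untyped_storage().size() // view.dtype.itemsize  # removed computation: whole storage
assert n_elts_old == full.numel() == 4096            # equals the count empty_aligned now returns
assert view.numel() == 4000                          # the view's own numel would undercount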
@@ -198,7 +197,7 @@ def empty_aligned(shape, dtype, device, pad_size):
     pad = lambda x: cdiv(x, pad_size) * pad_size
     ret = torch.empty((*shape[:-1], pad(shape[-1])), dtype=dtype, device=device)
     ret_slices = (*[slice(None)] * (len(shape) - 1), slice(0, shape[-1]))
-    return ret[ret_slices]
+    return ret[ret_slices], ret.numel()


 def max_n_tiles(n_expts_tot, n_gates):
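With this change, empty_aligned returns both the trimmed view and the element count of the padded backing tensor. A self-contained sketch of the new contract; the cdiv helper and the example sizes below are assumptions for illustration:

import torch

def cdiv(x, y):
    # assumed rounding-up division, matching how cdiv is used elsewhere in this diff
    return (x + y - 1) // y

def empty_aligned(shape, dtype, device, pad_size):
    # pad the last dimension up to a multiple of pad_size; return the trimmed view
    # plus the element count of the padded allocation
    pad = lambda x: cdiv(x, pad_size) * pad_size
    ret = torch.empty((*shape[:-1], pad(shape[-1])), dtype=dtype, device=device)
    ret_slices = (*[slice(None)] * (len(shape) - 1), slice(0, shape[-1]))
    return ret[ret_slices], ret.numel()

view, n_elts = empty_aligned((3, 100), torch.int32, "cpu", 512)
assert view.shape == (3, 100)   # callers still see the requested shape
assert n_elts == 3 * 512        # padded element count, divisible by pad_size by construction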
@@ -217,10 +216,11 @@ def compute_expt_data(expt_hist, n_expts_tot, n_gates):
     MEMSET_BLOCK = 512
     dtype = torch.int32
     device = expt_hist.device
-    token_offs_combined = empty_aligned((block_m_num + 1, n_expts_tot + 1), dtype, device, MEMSET_BLOCK)
-    block_pid_map = empty_aligned((block_m_num, max_n_tiles(n_expts_tot, n_gates)), dtype, device, MEMSET_BLOCK)
+    token_offs_combined, _ = empty_aligned((block_m_num + 1, n_expts_tot + 1), dtype, device, MEMSET_BLOCK)
+    block_pid_map, block_pid_map_size = empty_aligned((block_m_num, max_n_tiles(n_expts_tot, n_gates)), dtype, device,
+                                                      MEMSET_BLOCK)
     token_offs_raw, token_offs_pad = token_offs_combined[0], token_offs_combined[1:]
-    n_memset_blocks = exact_div(block_pid_map.storage().size(), MEMSET_BLOCK)
+    n_memset_blocks = exact_div(block_pid_map_size, MEMSET_BLOCK)

     _expt_data_memset[(token_offs_combined.shape[0] + n_memset_blocks, )](
         expt_hist, n_expts_tot,  #
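For concreteness, the launch-grid arithmetic for _expt_data_memset with made-up sizes; block_m_num and the padded tile count below are hypothetical:

MEMSET_BLOCK = 512
block_m_num = 9
token_offs_rows = block_m_num + 1                      # token_offs_combined.shape[0]
block_pid_map_size = block_m_num * 1024                # padded element count from empty_aligned
n_memset_blocks = block_pid_map_size // MEMSET_BLOCK   # exact_div in the real code -> 18
grid = (token_offs_rows + n_memset_blocks, )           # 28 programs: one per token_offs row,
                                                       # plus one per MEMSET_BLOCK chunk of block_pid_map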