@@ -235,24 +235,25 @@ def prefilled_group_gemm(
235235 recv_topk_weights : torch .Tensor ,
236236 hidden_dtype = torch .bfloat16 ,
237237 ):
238+ device = recv_x [0 ].device
238239 w1 , w1_scale = self .w1
239240 w2 , w2_scale = self .w2
240241 _ , K = recv_x [0 ].shape
241242 _ , N , _ = w1 .shape
242243 # scatter
243244 all_tokens = sum (num_recv_tokens_per_expert_list ) # calculate the total number of padded tokens.
244245 # gather_out shape [receive_num_tokens, hidden]
245- gather_out = torch .empty_like (recv_x [0 ], device = recv_x [ 0 ]. device , dtype = hidden_dtype )
246+ gather_out = torch .empty_like (recv_x [0 ], device = device , dtype = hidden_dtype )
246247 if all_tokens > 0 :
247248 input_tensor = [
248- torch .empty ((all_tokens , K ), device = recv_x [ 0 ]. device , dtype = recv_x [0 ].dtype ),
249- torch .empty ((all_tokens , K // 128 ), device = recv_x [ 0 ]. device , dtype = torch .float32 ),
249+ torch .empty ((all_tokens , K ), device = device , dtype = recv_x [0 ].dtype ),
250+ torch .empty ((all_tokens , K // 128 ), device = device , dtype = torch .float32 ),
250251 ]
251252 # once m_indices has been filled correctly:
252253 # m_indices records which expert each token uses, e.g. [0, 0, 0, 0, .... 1, 1, 1, 1,...., cur_expert_num - 1, ..]
253254 # the count of 0 is num_recv_tokens_per_expert_list[0], the count of 1 is num_recv_tokens_per_expert_list[1]
254255 # ...
255- m_indices = torch .empty (all_tokens , device = recv_x [ 0 ]. device , dtype = torch .int32 )
256+ m_indices = torch .empty (all_tokens , device = device , dtype = torch .int32 )
256257 # output_index shape [receive_num_tokens, topk_num]
257258 # output_index is used to record each token's index in input_tensor
258259 output_index = torch .empty_like (recv_topk_idx )
@@ -276,19 +277,19 @@ def prefilled_group_gemm(
276277 )
277278 input_tensor [1 ] = tma_align_input_scale (input_tensor [1 ])
278279 # groupgemm (contiguous layout)
279- gemm_out_a = torch .empty ((all_tokens , N ), device = recv_x [ 0 ]. device , dtype = hidden_dtype )
280+ gemm_out_a = torch .empty ((all_tokens , N ), device = device , dtype = hidden_dtype )
280281
281282 deep_gemm .m_grouped_gemm_fp8_fp8_bf16_nt_contiguous (input_tensor , (w1 , w1_scale ), gemm_out_a , m_indices )
282283
283284 # silu_and_mul_fwd + quant
284285 # TODO fused kernel
285- silu_out = torch .empty ((all_tokens , N // 2 ), device = recv_x [ 0 ]. device , dtype = hidden_dtype )
286+ silu_out = torch .empty ((all_tokens , N // 2 ), device = device , dtype = hidden_dtype )
286287
287288 silu_and_mul_fwd (gemm_out_a .view (- 1 , N ), silu_out )
288289 qsilu_out , qsilu_out_scale = tma_aligned_quantize (silu_out )
289290
290291 # groupgemm (contiguous layout)
291- gemm_out_b = torch .empty ((all_tokens , K ), device = recv_x [ 0 ]. device , dtype = hidden_dtype )
292+ gemm_out_b = torch .empty ((all_tokens , K ), device = device , dtype = hidden_dtype )
292293
293294 deep_gemm .m_grouped_gemm_fp8_fp8_bf16_nt_contiguous (
294295 (qsilu_out , qsilu_out_scale ), (w2 , w2_scale ), gemm_out_b , m_indices
0 commit comments