@@ -387,6 +387,7 @@ def grouped_matmul(
387387
388388 expert_num , n , k = expert_weights .shape
389389 assert token_inputs .shape [1 ] == k
390+ assert expert_to_weights_scale .shape [0 ] == expert_num
390391 assert expert_to_token_index .shape == expert_to_weights .shape
391392 assert token_inputs .is_contiguous ()
392393 assert expert_to_token_num .is_contiguous ()
@@ -520,7 +521,7 @@ def fused_experts_impl(
520521
521522 intermediate_cache1 = alloc_tensor_func ((M , topk_num , N ), device = hidden_states .device , dtype = hidden_states .dtype )
522523 intermediate_cache2 = alloc_tensor_func (
523- (M * topk_num , N // 2 ), device = hidden_states .device , dtype = hidden_states .dtype
524+ (M , topk_num , N // 2 ), device = hidden_states .device , dtype = hidden_states .dtype
524525 )
525526 intermediate_cache3 = alloc_tensor_func (
526527 (M , topk_num , w2 .shape [1 ]), device = hidden_states .device , dtype = hidden_states .dtype
@@ -567,10 +568,10 @@ def fused_experts_impl(
567568 ** run_config ,
568569 )
569570
570- ops .silu_and_mul (intermediate_cache2 , intermediate_cache1 .view (- 1 , N ))
571+ ops .silu_and_mul (intermediate_cache2 . view ( - 1 , N // 2 ) , intermediate_cache1 .view (- 1 , N ))
571572
572573 grouped_matmul (
573- intermediate_cache2 ,
574+ intermediate_cache2 . view ( - 1 , N // 2 ) ,
574575 a2_scale ,
575576 expert_to_token_num ,
576577 expert_to_tokens ,
0 commit comments