@@ -462,6 +462,7 @@ def grouped_matmul(
     out: torch.Tensor,
     mul_routed_weight: bool,
     use_fp8_w8a8: bool,
+    alloc_tensor_func=torch.empty,
     reused_mblock_infos=None,
     run_config: Optional[dict] = None,
 ):
@@ -525,8 +526,8 @@ def grouped_matmul(
         else:
             _m, _k = token_inputs.shape
             assert _k % block_size_k == 0
-            input_scale = torch.empty((_m, _k // block_size_k), dtype=torch.float32, device=token_inputs.device)
-            qinput_tensor = torch.empty((_m, _k), dtype=expert_weights.dtype, device=token_inputs.device)
+            input_scale = alloc_tensor_func((_m, _k // block_size_k), dtype=torch.float32, device=token_inputs.device)
+            qinput_tensor = alloc_tensor_func((_m, _k), dtype=expert_weights.dtype, device=token_inputs.device)
             per_token_group_quant_fp8(token_inputs, block_size_k, qinput_tensor, input_scale)
             token_inputs, token_input_scale = qinput_tensor, input_scale

@@ -611,6 +612,7 @@ def fused_experts_impl(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
+    alloc_tensor_func=torch.empty,
     run_config: Optional[dict] = None,
 ):
     # Check constraints.
@@ -625,26 +627,29 @@ def fused_experts_impl(
     CHUNK_SIZE = FFN_MOE_CHUNK_SIZE
     topk_num = topk_ids.shape[1]
     M = min(num_tokens, CHUNK_SIZE)
-
-    cache = torch.empty(M * topk_num * max(N, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype)
-    intermediate_cache1 = cache[: M * topk_num * N].view(M, topk_num, N)
-    intermediate_cache2 = torch.empty((M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
-    intermediate_cache3 = cache[: M * topk_num * w2.shape[1]].view(M, topk_num, w2.shape[1])
+
+    intermediate_cache13_shared = alloc_tensor_func((M, topk_num, max(N, w2.shape[1])), device=hidden_states.device, dtype=hidden_states.dtype)
+    intermediate_cache1 = intermediate_cache13_shared.view(-1)[:(M * topk_num * N)].view(M, topk_num, N)
+    intermediate_cache2 = alloc_tensor_func(
+        (M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
+    )
+    intermediate_cache3 = intermediate_cache13_shared.view(-1)[:(M * topk_num * w2.shape[1])].view(M, topk_num, w2.shape[1])

     if inplace:
         out_hidden_states = hidden_states
     else:
-        out_hidden_states = torch.empty(hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype)
+        out_hidden_states = alloc_tensor_func(
+            hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype
+        )

     for chunk in range(triton.cdiv(num_tokens, CHUNK_SIZE)):
         begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, min((chunk + 1) * CHUNK_SIZE, num_tokens))
         curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
         tokens_in_chunk, _ = curr_hidden_states.shape

-        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
-            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
-            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
-            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+        intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
+        intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+        intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]

         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
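Note (not part of the diff): in the new code, `intermediate_cache1` and `intermediate_cache3` are both views into the single `intermediate_cache13_shared` allocation, so the up-projection and down-projection intermediates reuse one buffer instead of two. A minimal standalone sketch of that aliasing, with made-up sizes:

```python
import torch

# Illustrative sizes only (not taken from the patch).
M, topk_num, N, out_dim = 4, 2, 6, 3

# One buffer sized for the larger of the two intermediates.
shared = torch.empty((M, topk_num, max(N, out_dim)))

# Both caches are views into the same storage: flatten, slice to the
# needed number of elements, then reshape. No extra memory is allocated.
cache1 = shared.view(-1)[: M * topk_num * N].view(M, topk_num, N)
cache3 = shared.view(-1)[: M * topk_num * out_dim].view(M, topk_num, out_dim)

assert cache1.data_ptr() == shared.data_ptr() == cache3.data_ptr()
```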
@@ -668,6 +673,7 @@ def fused_experts_impl(
             out=intermediate_cache1.view(-1, N),
             mul_routed_weight=False,
             use_fp8_w8a8=use_fp8_w8a8,
+            alloc_tensor_func=alloc_tensor_func,
             run_config=run_config,
         )

@@ -686,6 +692,7 @@ def fused_experts_impl(
             out=intermediate_cache3.view(-1, w2.shape[1]),
             mul_routed_weight=True,
             use_fp8_w8a8=use_fp8_w8a8,
+            alloc_tensor_func=alloc_tensor_func,
             reused_mblock_infos=reused_mblock_infos,
             run_config=run_config,
         )
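The `alloc_tensor_func` hook defaults to `torch.empty`, so existing callers are unchanged; any callable with a `torch.empty`-like signature (shape, `dtype=...`, `device=...`) can be passed instead, for example an allocator that hands out buffers from a preallocated pool. A hedged sketch of a hypothetical caller (the `CountingAllocator` below is illustrative, not part of the patch):

```python
import torch

class CountingAllocator:
    """Hypothetical allocator with a torch.empty-compatible signature.

    A real replacement might hand out slices of a preallocated pool
    (e.g. to keep allocations stable across graph replays); this one
    only counts calls to show where the hook is exercised.
    """

    def __init__(self):
        self.calls = 0

    def __call__(self, shape, dtype=None, device=None):
        self.calls += 1
        return torch.empty(shape, dtype=dtype, device=device)

alloc = CountingAllocator()

# Sketch of a call site; the MoE tensors and remaining arguments come
# from the surrounding code and are omitted here.
# fused_experts_impl(..., alloc_tensor_func=alloc, run_config=None)
```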