 from .moe_silu_and_mul import silu_and_mul_fwd
 from .moe_sum_reduce import moe_sum_reduce
 from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
+from lightllm.utils.torch_ops_utils import direct_register_custom_op
 
-FFN_MOE_CHUNK_SIZE = 8 * 1024
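+# Upper bound on token rows processed per chunk in fused_experts_impl; this
+# also bounds the size of the intermediate activation caches.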
+FFN_MOE_CHUNK_SIZE = 32 * 1024
 
 logger = init_logger(__name__)
 
@@ -355,7 +356,7 @@ def grouped_matmul_kernel(
     tile_n_idx = pid_n
 
     # get the gemm size of the current problem
-    cur_m = tl.load(expert_to_token_num + expert_id, eviction_policy="evict_last")
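+    # number of tokens routed to this expert, i.e. the row count of this GEMM problem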
+    cur_m = tl.load(expert_to_token_num + expert_id)
 
     # do regular gemm here
     offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
@@ -461,9 +462,8 @@ def grouped_matmul(
     out: torch.Tensor,
     mul_routed_weight: bool,
     use_fp8_w8a8: bool,
-    alloc_tensor_func=torch.empty,
     reused_mblock_infos=None,
-    **run_config,
+    run_config: Optional[dict] = None,
 ):
468468 """
469469 token_num_mul_topk_num is int equal token_num * topk_num,
@@ -493,7 +493,8 @@ def grouped_matmul(
     if expert_to_weights_scale.ndim == 3:
         block_size_n = expert_weights.shape[1] // expert_to_weights_scale.shape[1]
         block_size_k = expert_weights.shape[2] // expert_to_weights_scale.shape[2]
-    if not run_config:
+
+    if run_config is None:
         run_config = MoeGroupedGemmKernelConfig.try_to_get_best_config(
             M=token_inputs.shape[0],
             N=n,
525526 else :
526527 _m , _k = token_inputs .shape
527528 assert _k % block_size_k == 0
528- input_scale = alloc_tensor_func ((_m , _k // block_size_k ), dtype = torch .float32 , device = token_inputs .device )
529- qinput_tensor = alloc_tensor_func ((_m , _k ), dtype = expert_weights .dtype , device = token_inputs .device )
529+ input_scale = torch . empty ((_m , _k // block_size_k ), dtype = torch .float32 , device = token_inputs .device )
530+ qinput_tensor = torch . empty ((_m , _k ), dtype = expert_weights .dtype , device = token_inputs .device )
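+        # per-token-group FP8 quantization: one fp32 scale per block_size_k-wide group of each row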
         per_token_group_quant_fp8(token_inputs, block_size_k, qinput_tensor, input_scale)
         token_inputs, token_input_scale = qinput_tensor, input_scale
 
@@ -611,8 +612,7 @@ def fused_experts_impl(
     w2_scale: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
-    alloc_tensor_func=torch.empty,
-    **run_config,
+    run_config: Optional[dict] = None,
 ):
     # Check constraints.
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
@@ -627,18 +627,16 @@ def fused_experts_impl(
     topk_num = topk_ids.shape[1]
     M = min(num_tokens, CHUNK_SIZE)
 
-    intermediate_cache1 = alloc_tensor_func((M, topk_num, N), device=hidden_states.device, dtype=hidden_states.dtype)
-    intermediate_cache2 = alloc_tensor_func(
-        (M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype
-    )
-    intermediate_cache3 = alloc_tensor_func(
-        (M, topk_num, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype
-    )
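+    # intermediate_cache1 and intermediate_cache3 are never live at the same
+    # time, so both view into a single buffer sized for the larger of the two.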
+    cache = torch.empty(M * topk_num * max(N, w2.shape[1]), device=hidden_states.device, dtype=hidden_states.dtype)
+    intermediate_cache1 = cache[: M * topk_num * N].view(M, topk_num, N)
+    intermediate_cache2 = torch.empty((M, topk_num, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
+    intermediate_cache3 = cache[: M * topk_num * w2.shape[1]].view(M, topk_num, w2.shape[1])
+
 
     if inplace:
         out_hidden_states = hidden_states
     else:
-        out_hidden_states = alloc_tensor_func(
+        out_hidden_states = torch.empty(
             hidden_states.shape, device=hidden_states.device, dtype=hidden_states.dtype
         )
 
@@ -647,9 +645,10 @@ def fused_experts_impl(
         curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
         tokens_in_chunk, _ = curr_hidden_states.shape
 
-        intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
-        intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
-        intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
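+        # the caches hold min(num_tokens, CHUNK_SIZE) rows, so only a trailing
+        # partial chunk needs re-slicing; a single short chunk is already exact.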
+        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
+            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
+            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
 
         curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
         curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
@@ -673,8 +672,7 @@ def fused_experts_impl(
             out=intermediate_cache1.view(-1, N),
             mul_routed_weight=False,
             use_fp8_w8a8=use_fp8_w8a8,
-            alloc_tensor_func=alloc_tensor_func,
-            **run_config,
+            run_config=run_config,
         )
 
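+        # gated activation: cache1 (width N) is split into two halves which are
+        # combined as silu(one_half) * other_half into cache2 (width N // 2)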
         silu_and_mul_fwd(intermediate_cache1.view(-1, N), intermediate_cache2.view(-1, N // 2))
@@ -692,12 +690,156 @@ def fused_experts_impl(
             out=intermediate_cache3.view(-1, w2.shape[1]),
             mul_routed_weight=True,
             use_fp8_w8a8=use_fp8_w8a8,
-            alloc_tensor_func=alloc_tensor_func,
             reused_mblock_infos=reused_mblock_infos,
-            **run_config,
+            run_config=run_config,
         )
 
         moe_sum_reduce(
             intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx]
         )
     return out_hidden_states
+
+
+def inplace_fused_experts_impl(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> None:
+    fused_experts_impl(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        True,
+        use_fp8_w8a8,
+        use_int8_w8a16,
+        w1_scale,
+        w2_scale,
+        a1_scale,
+        a2_scale,
+    )
+
+
+def inplace_fused_experts_impl_fake(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> None:
745+
+direct_register_custom_op(
+    "inplace_fused_experts_impl",
+    inplace_fused_experts_impl,
+    ["hidden_states"],
+    inplace_fused_experts_impl_fake,
+)
+
+
+def outplace_fused_experts_impl(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return fused_experts_impl(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        False,
+        use_fp8_w8a8,
+        use_int8_w8a16,
+        w1_scale,
+        w2_scale,
+        a1_scale,
+        a2_scale,
+    )
+
+
+def outplace_fused_experts_impl_fake(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
+
+
+direct_register_custom_op(
+    "outplace_fused_experts_impl",
+    outplace_fused_experts_impl,
+    [],
+    outplace_fused_experts_impl_fake,
+)
+
+
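+# Public entry point: dispatch to the registered custom ops so graph capture
+# sees a single opaque call and the in-place mutation stays visible.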
+def fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    use_fp8_w8a8: bool = False,
+    use_int8_w8a16: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    if inplace:
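+        # the registered op mutates hidden_states in place; return it for chaining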
+        torch.ops.lightllm.inplace_fused_experts_impl(
+            hidden_states,
+            w1,
+            w2,
+            topk_weights,
+            topk_ids,
+            use_fp8_w8a8,
+            use_int8_w8a16,
+            w1_scale,
+            w2_scale,
+            a1_scale,
+            a2_scale,
+        )
+        return hidden_states
+    else:
+        return torch.ops.lightllm.outplace_fused_experts_impl(
+            hidden_states,
+            w1,
+            w2,
+            topk_weights,
+            topk_ids,
+            use_fp8_w8a8,
+            use_int8_w8a16,
+            w1_scale,
+            w2_scale,
+            a1_scale,
+            a2_scale,
+        )