@@ -493,7 +493,7 @@ def grouped_matmul(
493493 if expert_to_weights_scale .ndim == 3 :
494494 block_size_n = expert_weights .shape [1 ] // expert_to_weights_scale .shape [1 ]
495495 block_size_k = expert_weights .shape [2 ] // expert_to_weights_scale .shape [2 ]
496-
496+
497497 if run_config is None :
498498 run_config = MoeGroupedGemmKernelConfig .try_to_get_best_config (
499499 M = token_inputs .shape [0 ],
@@ -626,18 +626,15 @@ def fused_experts_impl(
626626 topk_num = topk_ids .shape [1 ]
627627 M = min (num_tokens , CHUNK_SIZE )
628628
629- cache = torch .empty (M * topk_num * max (N , w2 .shape [1 ]), device = hidden_states .device , dtype = hidden_states .dtype )
630- intermediate_cache1 = cache [:M * topk_num * N ].view (M , topk_num , N )
629+ cache = torch .empty (M * topk_num * max (N , w2 .shape [1 ]), device = hidden_states .device , dtype = hidden_states .dtype )
630+ intermediate_cache1 = cache [: M * topk_num * N ].view (M , topk_num , N )
631631 intermediate_cache2 = torch .empty ((M , topk_num , N // 2 ), device = hidden_states .device , dtype = hidden_states .dtype )
632- intermediate_cache3 = cache [:M * topk_num * w2 .shape [1 ]].view (M , topk_num , w2 .shape [1 ])
633-
632+ intermediate_cache3 = cache [: M * topk_num * w2 .shape [1 ]].view (M , topk_num , w2 .shape [1 ])
634633
635634 if inplace :
636635 out_hidden_states = hidden_states
637636 else :
638- out_hidden_states = torch .empty (
639- hidden_states .shape , device = hidden_states .device , dtype = hidden_states .dtype
640- )
637+ out_hidden_states = torch .empty (hidden_states .shape , device = hidden_states .device , dtype = hidden_states .dtype )
641638
642639 for chunk in range (triton .cdiv (num_tokens , CHUNK_SIZE )):
643640 begin_chunk_idx , end_chunk_idx = (chunk * CHUNK_SIZE , min ((chunk + 1 ) * CHUNK_SIZE , num_tokens ))
@@ -711,7 +708,7 @@ def inplace_fused_experts_impl(
711708 w2_scale : Optional [torch .Tensor ] = None ,
712709 a1_scale : Optional [torch .Tensor ] = None ,
713710 a2_scale : Optional [torch .Tensor ] = None ,
714- )-> None :
711+ ) -> None :
715712 fused_experts_impl (
716713 hidden_states ,
717714 w1 ,
@@ -727,6 +724,7 @@ def inplace_fused_experts_impl(
727724 a2_scale ,
728725 )
729726
727+
730728def inplace_fused_experts_impl_fake (
731729 hidden_states : torch .Tensor ,
732730 w1 : torch .Tensor ,
@@ -739,16 +737,18 @@ def inplace_fused_experts_impl_fake(
739737 w2_scale : Optional [torch .Tensor ] = None ,
740738 a1_scale : Optional [torch .Tensor ] = None ,
741739 a2_scale : Optional [torch .Tensor ] = None ,
742- )-> None :
740+ ) -> None :
743741 pass
744742
743+
745744direct_register_custom_op (
746745 "inplace_fused_experts_impl" ,
747746 inplace_fused_experts_impl ,
748747 ["hidden_states" ],
749748 inplace_fused_experts_impl_fake ,
750749)
751750
751+
752752def outplace_fused_experts_impl (
753753 hidden_states : torch .Tensor ,
754754 w1 : torch .Tensor ,
@@ -761,7 +761,7 @@ def outplace_fused_experts_impl(
761761 w2_scale : Optional [torch .Tensor ] = None ,
762762 a1_scale : Optional [torch .Tensor ] = None ,
763763 a2_scale : Optional [torch .Tensor ] = None ,
764- )-> None :
764+ ) -> None :
765765 return fused_experts_impl (
766766 hidden_states ,
767767 w1 ,
@@ -777,6 +777,7 @@ def outplace_fused_experts_impl(
777777 a2_scale ,
778778 )
779779
780+
780781def outplace_fused_experts_impl_fake (
781782 hidden_states : torch .Tensor ,
782783 w1 : torch .Tensor ,
@@ -789,16 +790,18 @@ def outplace_fused_experts_impl_fake(
789790 w2_scale : Optional [torch .Tensor ] = None ,
790791 a1_scale : Optional [torch .Tensor ] = None ,
791792 a2_scale : Optional [torch .Tensor ] = None ,
792- )-> None :
793+ ) -> None :
793794 return torch .empty_like (hidden_states )
794795
796+
795797direct_register_custom_op (
796798 "outplace_fused_experts_impl" ,
797799 outplace_fused_experts_impl ,
798800 [],
799801 outplace_fused_experts_impl_fake ,
800802)
801803
804+
802805def fused_experts (
803806 hidden_states : torch .Tensor ,
804807 w1 : torch .Tensor ,
@@ -825,7 +828,7 @@ def fused_experts(
825828 w1_scale ,
826829 w2_scale ,
827830 a1_scale ,
828- a2_scale ,
831+ a2_scale ,
829832 )
830833 return hidden_states
831834 else :
0 commit comments