@@ -715,25 +715,15 @@ def forward_fake(
         use_dp_padding: Optional[bool] = None,
         **kwargs,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
-        if not self.enable_alltoall:
-            return super().forward_fake(
-                x,
-                router_logits,
-                do_finalize=do_finalize,
-                output_dtype=output_dtype,
-                all_rank_num_tokens=all_rank_num_tokens,
-                use_dp_padding=use_dp_padding,
-                **kwargs,
-            )
-        else:
-            is_nvfp4_input = isinstance(x, Fp4QuantizedTensor)
-            data_type = output_dtype if is_nvfp4_input else x.dtype
-            num_tokens = all_rank_num_tokens[
-                self.mapping.tp_rank] if all_rank_num_tokens else x.shape[0]
-            hidden_size = x.shape[1] * (2 if is_nvfp4_input else 1)
-            top_k = self.routing_method.experts_per_token
-            return x.new_empty((num_tokens, top_k, hidden_size),
-                               dtype=data_type)
+        return super().forward_fake(
+            x,
+            router_logits,
+            do_finalize=do_finalize,
+            output_dtype=output_dtype,
+            all_rank_num_tokens=all_rank_num_tokens,
+            use_dp_padding=use_dp_padding,
+            **kwargs,
+        )
 
     def load_weights(self, weights: List[Dict]):
         assert self._weights_created
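
Note: the deleted `enable_alltoall` branch existed only to report the fake-tensor output shape, `(num_tokens, top_k, hidden_size)`, for shape inference. The sketch below restates that logic as a standalone helper for reference; `fake_moe_output` is a hypothetical name, not part of the codebase, and the unified path assumes `super().forward_fake()` now reports a compatible shape for the alltoall case as well.

```python
# Minimal sketch of the shape logic the deleted branch implemented.
# `fake_moe_output` is a hypothetical helper for illustration; the real
# code now delegates to the base class's forward_fake instead.
from typing import List, Optional

import torch


def fake_moe_output(x: torch.Tensor,
                    is_nvfp4_input: bool,
                    top_k: int,
                    tp_rank: int,
                    output_dtype: Optional[torch.dtype] = None,
                    all_rank_num_tokens: Optional[List[int]] = None
                    ) -> torch.Tensor:
    # NVFP4 inputs pack two 4-bit values per byte, so the logical hidden
    # size is twice the stored width; plain inputs keep their own dtype.
    data_type = output_dtype if is_nvfp4_input else x.dtype
    # With a per-rank token list (attention DP), use this rank's count;
    # otherwise fall back to the local batch dimension.
    num_tokens = (all_rank_num_tokens[tp_rank]
                  if all_rank_num_tokens else x.shape[0])
    hidden_size = x.shape[1] * (2 if is_nvfp4_input else 1)
    # One row per (token, selected expert) pair, i.e. the unfinalized
    # MoE output, where top_k = routing_method.experts_per_token.
    return x.new_empty((num_tokens, top_k, hidden_size), dtype=data_type)
```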