@@ -417,6 +417,18 @@ def reducescatter_or_allreduce(
             sizes=None if use_dp_padding else all_rank_num_tokens)
         return outputs
 
+    def is_post_quant_all2all_supported(self):
+        if not self.use_postquant_alltoall:
+            return False
+        if self.alltoall_method_type == AlltoallMethodType.MNNVL:
+            return False
+        elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
+            return self.has_nvfp4
+        elif self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+            return self.has_fp8_qdq or self.has_nvfp4 or self.has_w4afp8
+        else:
+            return False
+
     def forward_chunk(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
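For context, the new helper encodes a per-backend support matrix for sending already-quantized tensors through alltoall. Below is a minimal standalone sketch of that matrix; the enum values and flag names mirror the diff, while the free-function form, the defaults, and the checks at the bottom are illustrative assumptions, not code from this PR:

from enum import Enum, auto

class AlltoallMethodType(Enum):
    # Stub mirroring the enum referenced in the diff.
    MNNVL = auto()
    DeepEP = auto()
    DeepEPLowLatency = auto()

def is_post_quant_all2all_supported(method, use_postquant_alltoall,
                                    has_fp8_qdq=False, has_nvfp4=False,
                                    has_w4afp8=False):
    # Free-function restatement of the new method: the post-quant path is
    # taken only when the chosen backend can transfer the quantized format.
    if not use_postquant_alltoall:
        return False
    if method == AlltoallMethodType.MNNVL:
        return False  # MNNVL never takes the post-quant alltoall path
    elif method == AlltoallMethodType.DeepEP:
        return has_nvfp4  # DeepEP supports only NVFP4 post-quant
    elif method == AlltoallMethodType.DeepEPLowLatency:
        return has_fp8_qdq or has_nvfp4 or has_w4afp8
    else:
        return False

# DeepEP with only FP8-QDQ falls back to the pre-quant path:
assert not is_post_quant_all2all_supported(
    AlltoallMethodType.DeepEP, True, has_fp8_qdq=True)
# DeepEPLowLatency accepts FP8-QDQ, NVFP4, or W4A-FP8:
assert is_post_quant_all2all_supported(
    AlltoallMethodType.DeepEPLowLatency, True, has_fp8_qdq=True)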
@@ -493,7 +505,8 @@ def forward_chunk(
         use_allgather = not use_all_to_all
 
         # If alltoall is disabled, we also need to disable use_postquant_alltoall
-        use_postquant_alltoall = self.use_postquant_alltoall and use_all_to_all and self.has_any_quant
+        use_postquant_alltoall = use_all_to_all and self.is_post_quant_all2all_supported(
+        )
 
         # Prepare additional information for profiling in case padding is applied when using alltoall.
         # Only the non-alltoall case is considered for profiling in the warmup phase.
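The behavioral change in this hunk: the old gate enabled post-quant alltoall whenever any quantization was active, while the new gate additionally requires that the selected backend supports the active format. Continuing the stub sketch above (names hypothetical), DeepEP with only FP8-QDQ illustrates the difference:

# Old gate: any active quant format enabled post-quant alltoall.
use_all_to_all, use_postquant_flag, has_any_quant = True, True, True
old_gate = use_postquant_flag and use_all_to_all and has_any_quant

# New gate: DeepEP without NVFP4 is rejected, keeping pre-quant alltoall.
new_gate = use_all_to_all and is_post_quant_all2all_supported(
    AlltoallMethodType.DeepEP, use_postquant_flag, has_fp8_qdq=True)

assert old_gate and not new_gate  # the old gate over-enabled this combination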
@@ -613,6 +626,7 @@ def forward_chunk(
             if self.alltoall_method_type == AlltoallMethodType.MNNVL:
                 pass
             elif self.alltoall_method_type == AlltoallMethodType.DeepEP:
+                assert self.has_nvfp4, "DeepEP postquant alltoall should have nvfp4"
                 if x_sf is not None:
                     # Adapter between `x_sf` and DeepEP
                     # TODO: remove the adapter by adding dtype support to DeepEP
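The new assert records the invariant established by the gate: once use_postquant_alltoall comes from is_post_quant_all2all_supported, the DeepEP post-quant branch is reachable only when NVFP4 is enabled. A quick exhaustive check against the stub above (illustrative only, not part of the PR):

from itertools import product

# Whenever the gate passes for DeepEP, NVFP4 must have been enabled.
for flag, fp8, nvfp4, w4a in product([True, False], repeat=4):
    if is_post_quant_all2all_supported(AlltoallMethodType.DeepEP, flag,
                                       has_fp8_qdq=fp8, has_nvfp4=nvfp4,
                                       has_w4afp8=w4a):
        assert nvfp4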