Commit 4740419

[None][fix] Enabled simultaneous support for low-precision combine and MTP. (NVIDIA#9091)
Signed-off-by: Yilin Zhang <18275976+yilin-void@users.noreply.github.com>
1 parent: 0dbf394 · commit: 4740419

File tree

1 file changed (+8, -2 lines)

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

Lines changed: 8 additions & 2 deletions
@@ -382,6 +382,13 @@ def is_post_quant_all2all_supported(self):
         else:
             return False
 
+    def is_low_precision_combine_supported(self):
+        if not self.use_low_precision_combine:
+            return False
+        if self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
+            return self.has_fp8_qdq or self.has_nvfp4 or self.has_w4afp8
+        return False
+
     def forward_chunk(
         self,
         x: Union[torch.Tensor, Fp4QuantizedTensor],
@@ -671,8 +678,7 @@ def forward_chunk(
                 final_hidden_states = final_hidden_states.view(
                     self.expert_size_per_partition,
                     num_tokens_per_expert_for_fused_moe, self.hidden_size)
-                if self.use_low_precision_combine:
-                    assert self.has_nvfp4 or self.has_w4afp8 or self.has_fp8_qdq, "Low precision combine only supports nvfp4, w4afp8 and fp8 qdq"
+                if self.is_low_precision_combine_supported():
                     precision = "fp8"
                     global_scales = None
                     if self.has_nvfp4:
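For context on the behavioral change: previously, requesting low-precision combine with a quantization mode the combine path does not handle tripped an assert; with the new is_low_precision_combine_supported() check, forward_chunk instead falls back to the regular full-precision combine, which appears to be what lets low-precision combine stay enabled alongside MTP configurations that would have hit the old assert. Below is a minimal, self-contained sketch of that guard logic. CombineConfig and the locally defined AlltoallMethodType enum are illustrative stand-ins for the real fused-MoE module and its imported enum; only the flag names mirror the diff.

# Minimal sketch of the new guard's semantics; the class and enum here are
# illustrative stand-ins, not the real fused_moe_wide_ep module.
from enum import Enum, auto

class AlltoallMethodType(Enum):
    MNNVL = auto()
    DeepEP = auto()
    DeepEPLowLatency = auto()

class CombineConfig:
    def __init__(self, use_low_precision_combine, alltoall_method_type,
                 has_fp8_qdq=False, has_nvfp4=False, has_w4afp8=False):
        self.use_low_precision_combine = use_low_precision_combine
        self.alltoall_method_type = alltoall_method_type
        self.has_fp8_qdq = has_fp8_qdq
        self.has_nvfp4 = has_nvfp4
        self.has_w4afp8 = has_w4afp8

    def is_low_precision_combine_supported(self):
        # Same three conditions as the new method in the diff: the feature is
        # requested, the DeepEP low-latency alltoall is in use, and the quant
        # mode is one the low-precision combine path can handle.
        if not self.use_low_precision_combine:
            return False
        if self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
            return self.has_fp8_qdq or self.has_nvfp4 or self.has_w4afp8
        return False

# Supported setup: the low-precision combine branch is taken.
ok = CombineConfig(True, AlltoallMethodType.DeepEPLowLatency, has_nvfp4=True)
print(ok.is_low_precision_combine_supported())        # True

# Unsupported setup: before this commit the assert fired; now the caller
# simply skips the low-precision branch and combines in full precision.
fallback = CombineConfig(True, AlltoallMethodType.MNNVL)
print(fallback.is_low_precision_combine_supported())  # False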
