215 changes: 159 additions & 56 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py
@@ -1,8 +1,11 @@
import os
from functools import cached_property
from typing import Dict, List, Optional, Union

import torch
from torch import nn

from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe
from tensorrt_llm._utils import get_sm_version

from ...custom_ops.trtllm_gen_custom_ops import \
@@ -106,10 +109,28 @@ def __init__(
assert len(
self.initial_local_expert_ids) == self.expert_size_per_partition

self.alltoall_workspace = None
self.alltoall_prepare_workspace = None
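# When the alltoall path is enabled, the shared MNNVL workspaces are allocated
# once here at construction; they back the alltoallv prepare, dispatch, and
# combine calls issued from forward_impl.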
if self.enable_alltoall:
MnnvlMemory.initialize()
self.alltoall_workspace = MnnvlMoe.get_moe_workspaces(
model_config.mapping)
self.alltoall_prepare_workspace = MnnvlMoe.get_moe_prepare_workspace(
model_config.mapping)

self._weights_created = False
if not model_config.skip_create_weights_in_init:
self.create_weights()

@cached_property
def enable_alltoall(self):
mapping = self.mapping
routing_experts = self.routing_method.experts_per_token
return (mapping.moe_ep_size > routing_experts
and mapping.enable_attention_dp and mapping.tp_size > 1
and os.environ.get("TRTLLM_MOE_DISABLE_ALLTOALLV", "0") != "1"
and MnnvlMemory.supports_mnnvl())
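# The MNNVL alltoall path is taken only when every condition above holds:
# moe_ep_size exceeds the number of experts routed per token, attention DP is
# enabled with tp_size > 1, TRTLLM_MOE_DISABLE_ALLTOALLV is not set to "1",
# and MNNVL memory is supported on this system. Otherwise forward_impl falls
# back to the post-quant allgather path (when attention DP is used with
# parallel_size > 1) or performs no post-quant communication at all.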

def _check_configs(self):
assert self.has_deepseek_fp8_block_scales \
or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
@@ -175,6 +196,48 @@ def load_weights(self, weights: List[Dict]):
def post_load_weights(self):
self.quant_method.post_load_weights(self)

def _quantize_for_post_quant_comm(self, x):
"""Quantize inputs prior to post-communication (alltoall/allgather).
Returns: (x, x_sf, x_row, x_col)
"""
x_row = x.shape[0]
x_col = x.shape[1]
x_sf = None
if self.has_w4a8_mxfp4_fp8:
x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
x, self.fc31_input_dequant[0])
x_row, x_col = x.shape[0], x.shape[1]
elif self.has_nvfp4:
if isinstance(x, Fp4QuantizedTensor):
assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
x_row = x.shape[0]
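# note: two fp4 values are packed into each uint8, so the logical
# column count is twice the storage width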
x_col = x.shape[1] * 2
x, x_sf = x.fp4_tensor, x.scaling_factor
else:
x_row = x.shape[0]
x_col = x.shape[1]
x, x_sf = torch.ops.trtllm.fp4_quantize(
x, self.fc31_input_scale, self.scaling_vector_size, False,
False)
elif self.has_w4a8_mxfp4_mxfp8:
x, x_sf = torch.ops.trtllm.mxfp8_quantize(
x, False, alignment=self.quant_method.weight_alignment)
x_row, x_col = x.shape[0], x.shape[1]
elif self.has_deepseek_fp8_block_scales:
# No change required before communication
pass
elif self.has_w4a16_mxfp4:
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
x = torch.nn.functional.pad(x, (0, pad_size))
elif self.has_w4a8_nvfp4_fp8:
x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
x, 1.0 / self.fc31_input_scale)
else:
raise ValueError(
f"unsupported quantization mode for post communication: {self.quant_config.quant_mode}"
)
return x, x_sf, x_row, x_col
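# Illustrative sketch of the contract above, mirroring the call site in
# forward_impl below: the logical (x_row, x_col) shape is what the scaling
# factors are viewed against before communication, e.g.
#   x, x_sf, x_row, x_col = self._quantize_for_post_quant_comm(x)
#   if x_sf is not None:
#       x_sf = x_sf.view(x_row, ceil_div(x_col, self.scaling_vector_size))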

def forward_impl(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
@@ -202,55 +265,80 @@ def forward_impl(
topk_group = None
routed_scaling_factor = None

run_post_quant_allgather = self.use_dp and self.parallel_size > 1
run_post_quant_allgather = (self.use_dp and self.parallel_size > 1
and not self.enable_alltoall)
post_quant_comm = run_post_quant_allgather or self.enable_alltoall

x_sf = None
token_selected_experts = None
token_final_scales = None
x_row = x.shape[0]
x_col = x.shape[1]
if run_post_quant_allgather:
# apply routing
token_count = x.shape[0]
alltoall_info = None

if post_quant_comm:
token_selected_experts, token_final_scales = self.routing_method.apply(
router_logits)
token_final_scales = token_final_scales.to(torch.bfloat16)
assert token_final_scales.dtype == torch.bfloat16
assert token_selected_experts.dtype == torch.int32
# quantize inputs
if self.has_w4a8_mxfp4_fp8:
x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
x, self.fc31_input_dequant[0])
# Update x_row and x_col to the padded shape
x_row, x_col = x.shape[0], x.shape[1]
elif self.has_nvfp4:
if isinstance(x, Fp4QuantizedTensor):
assert not x.is_sf_swizzled, "Fp4QuantizedTensor should not be swizzled before communication"
x_row = x.shape[0]
# note: we use uint8 to store 2 fp4 values
x_col = x.shape[1] * 2
x, x_sf = x.fp4_tensor, x.scaling_factor
else:
x_row = x.shape[0]
x_col = x.shape[1]
x, x_sf = torch.ops.trtllm.fp4_quantize(
x, self.fc31_input_scale, self.scaling_vector_size,
False, False)
elif self.has_w4a8_mxfp4_mxfp8:
x, x_sf = torch.ops.trtllm.mxfp8_quantize(
x, False, alignment=self.quant_method.weight_alignment)
# Update x_row and x_col to the padded shape
x_row, x_col = x.shape[0], x.shape[1]
elif self.has_deepseek_fp8_block_scales:
pass
elif self.has_w4a16_mxfp4:
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
x = torch.nn.functional.pad(x, (0, pad_size))
token_selected_experts = token_selected_experts.to(torch.int32)
if token_final_scales is not None:
token_final_scales = token_final_scales.to(torch.bfloat16)

x, x_sf, x_row, x_col = self._quantize_for_post_quant_comm(x)

if self.enable_alltoall:
assert all_rank_num_tokens is not None, "all_rank_num_tokens required for alltoall"

max_num_token = max(
all_rank_num_tokens) if all_rank_num_tokens else token_count

if token_final_scales is None:
token_final_scales = torch.ones_like(token_selected_experts,
dtype=torch.float32)
else:
raise ValueError(
f"unsupported quantization mode with run_post_quant_allgather: {self.quant_config.quant_mode}"
)
token_final_scales = token_final_scales.to(torch.float32)

assert self.alltoall_prepare_workspace is not None, "alltoall_prepare_workspace should be initialized"
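# Compute the alltoallv dispatch metadata (alltoall_info) directly from the
# selected experts, without an extra allgather of the routing results.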
alltoall_info, _ = MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
token_selected_experts,
None,
self.alltoall_prepare_workspace,
max_num_token,
self.ep_rank,
self.ep_size,
self.num_experts,
self.num_slots,
top_k,
)

#allgather for attention DP
if x_sf is not None:
x_sf = x_sf.view(x_row, ceil_div(x_col,
self.scaling_vector_size))
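# e.g. with scaling_vector_size = 16 and a logical width of 7168 columns
# (hypothetical hidden size), each row carries ceil_div(7168, 16) = 448
# scaling factors after this view.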

x, x_sf, token_selected_experts, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv(
[x, x_sf, token_selected_experts, token_final_scales],
alltoall_info,
self.alltoall_workspace,
self.ep_rank,
self.ep_size,
)

torch.ops.trtllm.memset_expert_ids(
token_selected_experts,
alltoall_info.recv_rank_count_cumsum,
max_num_token,
top_k,
self.num_slots,
self.ep_size,
)

if x_sf is not None:
x_sf = x_sf.flatten()

if token_final_scales is not None:
token_final_scales = token_final_scales.to(torch.bfloat16)

elif run_post_quant_allgather:
if x_sf is not None:
x_sf = x_sf.view(x_row, ceil_div(x_col,
self.scaling_vector_size))
@@ -265,15 +353,18 @@ def forward_impl(
if x_sf is not None:
x_sf = x_sf.flatten()

router_logits_arg = router_logits if not post_quant_comm else None
routing_bias_arg = routing_bias if not post_quant_comm else None

# TODO: since the routing kernel is integrated into moe_runner for fp8,
# here we just route the I/Os for moe_runner
if self.has_deepseek_fp8_block_scales:
assert do_finalize, "fp8_block_scale_moe_runner does not support do_finalize=False"
x_val, x_scale = torch.ops.trtllm.fp8_quantize_1x128(x)

final_hidden_states = torch.ops.trtllm.fp8_block_scale_moe_runner(
router_logits if not run_post_quant_allgather else None,
routing_bias if not run_post_quant_allgather else None,
router_logits_arg,
routing_bias_arg,
x_val,
x_scale,
self.w3_w1_weight,
@@ -297,7 +388,7 @@ def forward_impl(
scale_factor_use_ue8m0 = False
is_scale_factor_swizzled = False # use linear layout here

if not run_post_quant_allgather:
if not post_quant_comm:
hidden_states_fp4, hidden_states_scale_linear_fp4 = (
torch.ops.trtllm.fp4_quantize(
x,
@@ -310,8 +401,8 @@ def forward_impl(
hidden_states_fp4, hidden_states_scale_linear_fp4 = x, x_sf

outputs = torch.ops.trtllm.fp4_block_scale_moe_runner(
router_logits if not run_post_quant_allgather else None,
routing_bias if not run_post_quant_allgather else None,
router_logits_arg,
routing_bias_arg,
hidden_states_fp4,
hidden_states_scale_linear_fp4.view(torch.float8_e4m3fn),
self.w3_w1_weight,
@@ -343,7 +434,7 @@ def forward_impl(
final_hidden_states = outputs[0]
elif self.has_w4a16_mxfp4:
assert x.dtype == torch.bfloat16
if not run_post_quant_allgather:
if not post_quant_comm:
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
x = torch.nn.functional.pad(x, (0, pad_size))
else:
@@ -352,8 +443,8 @@ def forward_impl(
intermediate_size_per_partition_padded = self.w3_w1_weight.shape[
-2] // 2
final_hidden_states = torch.ops.trtllm.bf16_mxe2m1_block_scale_moe_runner(
router_logits if not run_post_quant_allgather else None,
routing_bias if not run_post_quant_allgather else None,
router_logits_arg,
routing_bias_arg,
x,
self.w3_w1_weight,
self.w3_w1_weight_scale,
@@ -383,15 +474,15 @@ def forward_impl(
hidden_size].contiguous()
elif self.has_w4a8_nvfp4_fp8:

if not run_post_quant_allgather:
if not post_quant_comm:
hidden_states_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
x, 1.0 / self.fc31_input_scale)
else:
hidden_states_fp8 = x

outputs = torch.ops.trtllm.fp8_fp4_block_scale_moe_runner(
router_logits,
routing_bias,
router_logits_arg,
routing_bias_arg,
hidden_states_fp8,
self.w3_w1_weight,
self.w3_w1_weight_scale.view(torch.float8_e4m3fn),
@@ -423,7 +514,7 @@ def forward_impl(
final_hidden_states = outputs[0]
elif self.has_w4a8_mxfp4_fp8:
pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
if not run_post_quant_allgather:
if not post_quant_comm:
x = torch.nn.functional.pad(x, (0, pad_size))
x, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
x, self.fc31_input_gate_dequant[0])
@@ -433,8 +524,8 @@ def forward_impl(
-2] // 2

final_hidden_states = torch.ops.trtllm.e4m3_mxe2m1_block_scale_moe_runner(
router_logits if not run_post_quant_allgather else None,
routing_bias if not run_post_quant_allgather else None,
router_logits_arg,
routing_bias_arg,
x,
self.w3_w1_weight,
self.w3_w1_weight_scale,
@@ -466,7 +557,7 @@ def forward_impl(
final_hidden_states = final_hidden_states[:, :self.
hidden_size].contiguous()
elif self.has_w4a8_mxfp4_mxfp8:
if not run_post_quant_allgather:
if not post_quant_comm:
# TRTLLM-Gen uses linear SF layout for the mxfp8 input.
mxfp8_x, sf = torch.ops.trtllm.mxfp8_quantize(
x, False, alignment=self.quant_method.weight_alignment)
@@ -477,8 +568,8 @@ def forward_impl(
-2] // 2

final_hidden_states = torch.ops.trtllm.mxe4m3_mxe2m1_block_scale_moe_runner(
router_logits if not run_post_quant_allgather else None,
routing_bias if not run_post_quant_allgather else None,
router_logits_arg,
routing_bias_arg,
mxfp8_x,
sf,
self.w3_w1_weight,
@@ -511,6 +602,18 @@ def forward_impl(
"TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_mxfp8 and w4a8_mxfp4_fp8 dtypes."
)

# Combine results when alltoall dispatch was used: route expert outputs
# back to the ranks that own the original tokens
if self.enable_alltoall and alltoall_info is not None:
final_hidden_states = MnnvlMoe.mnnvl_moe_alltoallv_combine(
final_hidden_states,
alltoall_info,
self.alltoall_workspace,
ep_rank=self.ep_rank,
ep_size=self.ep_size,
top_k=top_k,
token_count=token_count,
)

final_hidden_states = self.reducescatter_or_allreduce(
final_hidden_states,
all_rank_num_tokens=all_rank_num_tokens,