@@ -11,7 +11,7 @@ class NemotronHHfWeightMapper(HfWeightMapper):

def preprocess_weights(self, weights: dict) -> dict:
config = self.config.pretrained_config
tp_size = self.config.mapping.tp_size
tp_size = 1 if self.config.mapping.enable_attention_dp else self.config.mapping.tp_size
tp_rank = self.config.mapping.tp_rank
d_inner = config.mamba_head_dim * config.mamba_num_heads

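The one-line weight-mapper change above establishes the pattern this PR applies everywhere: when attention DP is enabled, tensor-parallel sharding is skipped, so the code treats the effective TP size as 1. A minimal sketch of that pattern (the helper name is hypothetical; `mapping` is a `tensorrt_llm.mapping.Mapping`):

```python
def effective_tp_size(mapping) -> int:
    """Attention DP keeps weights and states unsharded, so TP sharding is skipped."""
    # Mirrors the expression used in the weight mapper, the Mamba mixer,
    # and the Mamba cache manager in this PR.
    return 1 if mapping.enable_attention_dp else mapping.tp_size
```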
52 changes: 38 additions & 14 deletions tensorrt_llm/_torch/models/modeling_nemotron_h.py
@@ -32,7 +32,7 @@
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import MoEWeightLoadingMode, create_moe
from ..modules.linear import Linear
from ..modules.linear import Linear, TensorParallelMode
from ..modules.mamba.mamba2_mixer import Mamba2Mixer
from ..modules.mlp import MLP
from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -85,8 +85,10 @@ def __init__(
self,
model_config: ModelConfig[NemotronHConfig],
layer_idx: int,
reduce_output: bool = False,
):
config = model_config.pretrained_config

super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
@@ -97,6 +99,7 @@ def __init__(
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
reduce_output=reduce_output,
)

def forward(
@@ -154,6 +157,7 @@ def __init__(
shared_expert_intermediate_size = (
config.moe_shared_expert_intermediate_size *
config.n_shared_experts)

self.shared_experts = MLP(
hidden_size=config.hidden_size,
intermediate_size=shared_expert_intermediate_size,
@@ -163,7 +167,8 @@ def __init__(
config=model_config,
layer_idx=self.layer_idx,
reduce_output=False,
)
overridden_tp_size=1
if model_config.mapping.enable_attention_dp else None)
# Setup MoE gate.
self.gate = DeepseekV3Gate(
self.hidden_size,
@@ -193,11 +198,14 @@ def __init__(
activation_type=self.activation_type,
)

# AllReduce for combining shared and routed expert outputs in multi-GPU settings.
self.allreduce = AllReduce(
mapping=model_config.mapping,
strategy=model_config.allreduce_strategy,
)
if not model_config.mapping.enable_attention_dp:
# AllReduce for combining shared and routed expert outputs in multi-GPU settings.
self.allreduce = AllReduce(
mapping=model_config.mapping,
strategy=model_config.allreduce_strategy,
)
else:
self.allreduce = None

# Setup latent projection layers.
# These layers should NOT be TP-sharded to ensure MoE receives
@@ -322,7 +330,11 @@ def __init__(
elif layer_type == "-":
self.mixer = MLPLayer(model_config, layer_idx)
elif layer_type == "*":
self.mixer = TransformerLayer(model_config, layer_idx)
self.mixer = TransformerLayer(
model_config,
layer_idx,
reduce_output=not model_config.mapping.enable_attention_dp
and model_config.mapping.tp_size > 1)
elif layer_type == "E":
self.mixer = NemotronHMOE(model_config,
layer_idx=layer_idx,
@@ -365,12 +377,24 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
aux_stream_list[2],
}

# calculate embeddings
self.embed_tokens = Embedding(
config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype,
)
if model_config.mapping.enable_attention_dp:
# When attention_dp is enabled, we cannot do all_reduce since
# the problem sizes on different ranks are different.
# So, we don't do parallelism here.
self.embed_tokens = Embedding(
config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype,
)
else:
self.embed_tokens = Embedding(
config.vocab_size,
config.hidden_size,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gather_output=True,
)

# create layers
layers = []
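The embedding and all-reduce changes in this file share one rationale: under attention DP each rank processes its own, differently sized set of tokens, so collectives that require identical tensor shapes on every rank (the gathered TP embedding, the attention output reduction controlled by `reduce_output`, and the MoE shared/routed combine) are no longer well defined and are disabled. A toy illustration of the shape mismatch, with purely made-up sizes:

```python
import torch

hidden_size = 8
# Under attention DP, ranks are scheduled different numbers of tokens, e.g.:
hidden_rank0 = torch.randn(5, hidden_size)  # rank 0: 5 tokens this step
hidden_rank1 = torch.randn(3, hidden_size)  # rank 1: 3 tokens this step

# An all-reduce (or all-gather) needs identical shapes on every rank, which is
# why the embedding above falls back to a plain, non-parallel layer and
# `self.allreduce` is set to None when enable_attention_dp is True.
assert hidden_rank0.shape != hidden_rank1.shape
```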
21 changes: 15 additions & 6 deletions tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
@@ -20,6 +20,7 @@
from torch import nn

from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
from tensorrt_llm.mapping import Mapping

from ...attention_backend import AttentionMetadata
from ...model_config import ModelConfig
@@ -57,8 +58,20 @@ def __init__(

config = config or ModelConfig()
self.mapping = config.mapping
tp_rank = config.mapping.tp_rank
tp_size = config.mapping.tp_size

if config.mapping.enable_attention_dp:
self.mapping = Mapping(
world_size=config.mapping.pp_size,
tp_size=1,
pp_size=config.mapping.pp_size,
rank=config.mapping.rank,
gpus_per_node=config.mapping.gpus_per_node,
enable_attention_dp=True,
)
tp_size = 1
else:
self.mapping = config.mapping
tp_size = config.mapping.tp_size

d_inner = head_dim * nheads
d_in_proj = 2 * d_inner + 2 * n_groups * d_state + nheads
@@ -80,10 +93,6 @@ def __init__(
self.remove_padding = remove_padding
self.apply_silu = apply_silu

# tp
self.tp_size = tp_size
self.tp_rank = tp_rank

# paged state parameters
self.slot_mapping = None
self.is_paged_state = False
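For the Mamba mixer, the PR builds a replacement `Mapping` with `tp_size=1` rather than just overriding an integer, presumably because `self.mapping` is passed on to the mixer's projection layers for their own sharding decisions. A condensed sketch of that override, mirroring the diff above but factored into a hypothetical helper:

```python
from tensorrt_llm.mapping import Mapping

def mamba_local_mapping(mapping: Mapping) -> Mapping:
    """Return a TP=1 mapping when attention DP is on; otherwise pass through."""
    if not mapping.enable_attention_dp:
        return mapping
    return Mapping(
        world_size=mapping.pp_size,  # TP collapses to 1, so only PP remains
        tp_size=1,
        pp_size=mapping.pp_size,
        rank=mapping.rank,
        gpus_per_node=mapping.gpus_per_node,
        enable_attention_dp=True,
    )
```

With this mapping in place, `tp_size` resolves to 1 inside the mixer, so `d_in_proj` and the conv/SSM state dimensions stay unsharded on each rank.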
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
@@ -45,7 +45,7 @@ def __init__(
self.mamba_ssm_cache_dtype = ssm_cache_dtype

# get tp size
tp_size = mapping.tp_size
tp_size = mapping.tp_size if not mapping.enable_attention_dp else 1

# derive mamba parameters for conv and ssm states
d_inner = head_dim * num_heads
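The cache manager applies the same effective-TP rule so that the conv/SSM state buffers it allocates match the unsharded per-rank shapes the mixer now produces under attention DP. A hedged sketch of the sizing consequence; the `// tp_size` division in the non-DP branch is an assumption based on how `tp_size` is used elsewhere in this PR, not copied from this file:

```python
def per_rank_mamba_inner_dim(head_dim: int, num_heads: int, mapping) -> int:
    # Same effective-TP rule as in the diff above.
    tp_size = 1 if mapping.enable_attention_dp else mapping.tp_size
    d_inner = head_dim * num_heads
    # Assumption: without attention DP the state is TP-sharded, so each rank
    # holds d_inner // tp_size; with attention DP it holds the full d_inner.
    return d_inner // tp_size
```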
19 changes: 14 additions & 5 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -5276,6 +5276,7 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
(4, 4, False, True, True),
(4, 1, False, False, True),
(4, 4, True, False, True),
(4, 4, True, True, True),
(4, 1, True, True, True),
(4, 4, False, True, False),
(4, 1, False, False, False),
@@ -5285,9 +5286,6 @@ class TestNemotronV3Super(LlmapiAccuracyTestHarness):
)
def test_auto_dtype_4gpus(self, tp_size, ep_size, attention_dp,
overlap_scheduler, cuda_graph):
if attention_dp:
pytest.skip(
"Attention DP is not supported for Nemotron-3-Super yet")

kv_cache_config = KvCacheConfig(enable_block_reuse=False,
mamba_ssm_cache_dtype="float32")
@@ -5314,7 +5312,18 @@ def test_auto_dtype_4gpus(self, tp_size, ep_size, attention_dp,

@skip_pre_blackwell
@pytest.mark.skip_less_mpi_world_size(8)
def test_nvfp4_8gpus(self):
@pytest.mark.parametrize(
"attention_dp",
[
False,
True,
],
ids=[
"attention_dp_off",
"attention_dp_on",
],
)
def test_nvfp4_8gpus(self, attention_dp):
# Use this test to track the best performance config.
# The optimized config is still under investigation.
# Adding this test as placeholder.
@@ -5329,7 +5338,7 @@ def test_auto_dtype_4gpus(self, tp_size, ep_size, attention_dp,
tensor_parallel_size=8,
moe_expert_parallel_size=8,
pipeline_parallel_size=1,
enable_attention_dp=False,
enable_attention_dp=attention_dp,
cuda_graph_config=CudaGraphConfig(max_batch_size=32,
enable_padding=True),
disable_overlap_scheduler=False,
4 changes: 3 additions & 1 deletion tests/integration/test_lists/qa/llm_function_core.txt
@@ -247,7 +247,9 @@ accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-False]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-True-True]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on]
accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_off]

# multimodal accuracy tests
accuracy/test_llm_api_pytorch_multimodal.py::TestQwen2_VL_7B::test_auto_dtype
5 changes: 3 additions & 2 deletions tests/integration/test_lists/test-db/l0_dgx_b200.yml
@@ -23,7 +23,7 @@ l0_dgx_b200:
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_mxfp4_mxfp8[enable_configurable_moe-True-8-64-TRTLLM]
- unittest/_torch/modules/test_fused_moe.py::test_fused_moe_wfp4a16[enable_configurable_moe-TRTLLM-2880-dtype0]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-False-True-True]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-True]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-True-True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_trtllm-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[dep4_latency_moe_cutlass-torch_compile=False]
@@ -85,7 +85,7 @@ l0_dgx_b200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV32::test_nvfp4_multi_gpus[baseline_mtp1] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype[False] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_nvfp4_8gpus[attention_dp_on] TIMEOUT (60)
- condition:
ranges:
system_gpu_count:
@@ -163,6 +163,7 @@ l0_dgx_b200:
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-False]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-False-False-True]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-4-True-False-True]
- accuracy/test_llm_api_pytorch.py::TestNemotronV3Super::test_auto_dtype_4gpus[4-1-True-True-True]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_trtllm-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_nvfp4[tep4_latency_moe_cutlass-torch_compile=True]