
Commit f3dd6da

[#10056][chore] AutoDeploy: Enable Nemo SuperV3 accuracy test (#10308)
Signed-off-by: Gal Hubara Agam <[email protected]>
1 parent 5e0e481 commit f3dd6da

File tree

5 files changed: +38 −9 lines changed

tensorrt_llm/_torch/auto_deploy/models/custom/modeling_nemotron_h.py

Lines changed: 28 additions & 4 deletions
@@ -250,14 +250,23 @@ def forward(self, hidden_states):
 
 # Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH
 class NemotronHMLP(nn.Module):
-    def __init__(self, config, layer_idx: int, intermediate_size: Optional[int] = None):
+    def __init__(
+        self,
+        config,
+        layer_idx: int,
+        intermediate_size: Optional[int] = None,
+        is_expert: bool = False,
+    ):
         super().__init__()
         self.config = config
         self.layer_idx = layer_idx
         self.hidden_size = config.hidden_size
         self.intermediate_size = intermediate_size or config.intermediate_size
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
+        # Use latent size for expert MLPs if provided by config (required for SuperV3)
+        use_latent_size = (getattr(self.config, "moe_latent_size", None) is not None) and is_expert
+        input_size = self.config.moe_latent_size if use_latent_size else self.hidden_size
+        self.up_proj = nn.Linear(input_size, self.intermediate_size, bias=config.mlp_bias)
+        self.down_proj = nn.Linear(self.intermediate_size, input_size, bias=config.mlp_bias)
         self.act_fn = ACT2FN[config.mlp_hidden_act]
 
     def forward(self, x):

@@ -271,7 +280,10 @@ def __init__(self, config, layer_idx: Optional[int] = None):
         self.experts = nn.ModuleList(
             [
                 NemotronHMLP(
-                    config, intermediate_size=config.moe_intermediate_size, layer_idx=layer_idx
+                    config,
+                    layer_idx=layer_idx,
+                    intermediate_size=config.moe_intermediate_size,
+                    is_expert=True,
                 )
                 for _ in range(config.n_routed_experts)
             ]

@@ -281,7 +293,19 @@ def __init__(self, config, layer_idx: Optional[int] = None):
             config=config,
             intermediate_size=config.moe_shared_expert_intermediate_size,
             layer_idx=layer_idx,
+            is_expert=False,
         )
+        # Add latent projections when using latent MoE (required for SuperV3)
+        if getattr(config, "moe_latent_size", None) is not None:
+            self.fc1_latent_proj = nn.Linear(
+                config.hidden_size, config.moe_latent_size, bias=config.mlp_bias
+            )
+            self.fc2_latent_proj = nn.Linear(
+                config.moe_latent_size, config.hidden_size, bias=config.mlp_bias
+            )
+        else:
+            self.fc1_latent_proj = nn.Identity()
+            self.fc2_latent_proj = nn.Identity()
 
     def forward(self, hidden_states: torch.Tensor):
         residuals = hidden_states

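Note on the change above: the MoE forward body is not part of these hunks, so how fc1_latent_proj and fc2_latent_proj wrap the routed experts is only implied by the constructor. Below is a minimal shape-check sketch of that implied contract: routed experts (is_expert=True) operate at the smaller moe_latent_size, while the shared expert (is_expert=False) stays at full hidden_size. The dimensions 1024/256/2048 are made-up illustration values, not SuperV3 config values, and the forward wiring is an assumption.

import torch
import torch.nn as nn

hidden_size, latent_size, intermediate_size = 1024, 256, 2048

# Names taken from the diff; sizes are illustrative only.
fc1_latent_proj = nn.Linear(hidden_size, latent_size)    # hidden -> latent
expert_up = nn.Linear(latent_size, intermediate_size)    # expert MLP, is_expert=True
expert_down = nn.Linear(intermediate_size, latent_size)
fc2_latent_proj = nn.Linear(latent_size, hidden_size)    # latent -> hidden

x = torch.randn(2, 8, hidden_size)                       # (batch, seq, hidden)
latent = fc1_latent_proj(x)                              # (2, 8, 256)
expert_out = expert_down(torch.relu(expert_up(latent)))  # experts run in latent space
y = fc2_latent_proj(expert_out)                          # back to (2, 8, 1024)
assert y.shape == x.shape

With moe_latent_size unset, both projections collapse to nn.Identity() and the experts see the full hidden size, which is why the nn.Identity() fallback in the else branch keeps the forward path uniform.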
tests/integration/defs/accuracy/test_llm_api_autodeploy.py

Lines changed: 5 additions & 5 deletions
@@ -235,7 +235,7 @@ def test_fp8(self):
 
 class TestNemotronSuperV3(LlmapiAccuracyTestHarness):
     MODEL_NAME = "nvidia/Nemotron-Super-V3"
-    MODEL_PATH_BF16 = "/scratch/models/super-v3-iter_0440000/hf"  # add to llm_models_root? I don't have permissions
+    MODEL_PATH_BF16 = f"{llm_models_root()}/Nemotron-Super-3-120B-A12B-dev"
 
     def get_default_kwargs(self):
         return {

@@ -264,15 +264,15 @@ def get_default_sampling_params(self):
                               n=beam_width,
                               use_beam_search=beam_width > 1)
 
-    @pytest.mark.skip_less_device_memory(
-        32000)  # might need to require more memory
-    @pytest.mark.skip_less_device(8)
+    # 180GB works, might be able to go lower
+    @pytest.mark.skip_less_device_memory(180000)
+    @pytest.mark.skip_less_device(4)
     def test_bf16(self):
         kwargs = self.get_default_kwargs()
         sampling_params = self.get_default_sampling_params()
         with AutoDeployLLM(model=self.MODEL_PATH_BF16,
                            tokenizer=self.MODEL_PATH_BF16,
-                           world_size=8,
+                           world_size=4,
                            **kwargs) as llm:
             task = MMLU(self.MODEL_NAME)
             task.evaluate(llm, sampling_params=sampling_params)

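The decorators above gate the test on hardware: skip_less_device_memory(180000) and skip_less_device(4) skip the test on machines with less than roughly 180 GB per GPU or fewer than 4 GPUs. Their implementation is not part of this commit; the sketch below shows equivalent gating with plain pytest. The helper name, the MiB unit, and the CUDA-based lookup are assumptions for illustration, not the repo's actual marker code.

import pytest
import torch

def require_device_memory(min_mib: int):
    """Skip unless every visible GPU has at least `min_mib` MiB of total memory.

    Illustrative stand-in for the repo's skip_less_device_memory marker;
    the real marker may differ in unit or lookup mechanism.
    """
    if not torch.cuda.is_available():
        return pytest.mark.skip(reason="CUDA not available")
    smallest = min(
        torch.cuda.get_device_properties(i).total_memory // (1024 * 1024)
        for i in range(torch.cuda.device_count())
    )
    return pytest.mark.skipif(
        smallest < min_mib,
        reason=f"needs >= {min_mib} MiB per GPU, found {smallest} MiB",
    )

@require_device_memory(180000)  # mirrors skip_less_device_memory(180000)
def test_bf16_gated():
    ...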
tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@ l0_dgx_b200:
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ucx[DeepSeek-V3-Lite-fp8]
   - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[latency_adp_lmtp_tp4]
+  # ------------- AutoDeploy tests ---------------
+  - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
 - condition:
     ranges:
       system_gpu_count:

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 2 additions & 0 deletions
@@ -124,6 +124,8 @@ l0_dgx_h100:
   - disaggregated/test_auto_scaling.py::test_worker_restart[http-load_balancing]
   - disaggregated/test_auto_scaling.py::test_minimal_instances[http-round_robin]
   - disaggregated/test_auto_scaling.py::test_disagg_server_restart[http-round_robin]
+  # ------------- AutoDeploy tests ---------------
+  - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
 - condition:
     ranges:
       system_gpu_count:

tests/integration/test_lists/test-db/l0_dgx_h200.yml

Lines changed: 1 addition & 0 deletions
@@ -134,6 +134,7 @@ l0_dgx_h200:
   # ------------- AutoDeploy tests ---------------
   - accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-4]
   - accuracy/test_llm_api_autodeploy.py::TestNemotronMOE::test_bf16
+  - accuracy/test_llm_api_autodeploy.py::TestNemotronSuperV3::test_bf16
 - condition:
     ranges:
       system_gpu_count:
