Commit 870da18
hyuknm authored and mikeiovine committed

[https://nvbugs/5608489][fix] Fix output unpack issues for Llama3/4 NVFP4 models. (#8679)

Signed-off-by: Yukun He <[email protected]>
Signed-off-by: Mike Iovine <[email protected]>
1 parent a5d39f5 · commit 870da18

File tree

5 files changed: +8 -8 lines changed


tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 2 additions & 2 deletions
@@ -599,7 +599,7 @@ def forward(
             ))

         # Unpack the allreduce output
-        if self.next_attn is not None and self.is_nvfp4:
+        if self.post_feed_forward_fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4:
             act_fp4, act_sf, residual = allreduce_output
             hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
         else:

@@ -790,7 +790,7 @@ def forward(
                 scale=scale,
                 eps=self.next_layer_layernorm.variance_epsilon,
             ))
-        if self.next_attn is not None and self.is_nvfp4:
+        if self.post_mlp_fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4:
             act_fp4, act_sf, residual = all_reduce_output
             hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
         else:
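Why the condition changed: the shape of the tuple returned by the fused allreduce is determined by the fusion op that was actually requested, and the old `next_attn is not None and self.is_nvfp4` check can disagree with that choice in some layer/fusion configurations, leaving a three-way unpack applied to a two-element result. Below is a minimal sketch of the corrected pattern; the enum, wrapper class, and function are simplified stand-ins for the real TensorRT-LLM internals, and the two-element shape of the non-quantizing branch is an assumption based on the surrounding diff context.

from dataclasses import dataclass
from enum import Enum, auto


class AllReduceFusionOp(Enum):
    # Stand-in for TensorRT-LLM's fusion-op enum; only the cases
    # relevant to this fix are modeled.
    RESIDUAL_RMS_NORM = auto()
    RESIDUAL_RMS_NORM_QUANT_NVFP4 = auto()


@dataclass
class Fp4QuantizedTensor:
    # Stand-in for the real wrapper: packed FP4 activations plus
    # their per-block scaling factors.
    fp4_tensor: object
    scaling_factor: object


def unpack_allreduce_output(fusion_op, allreduce_output):
    # Key the unpack off the fusion op that was actually requested,
    # not off properties of the next layer (the pre-fix condition).
    if fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4:
        # The NVFP4-quantizing fusion yields three outputs.
        act_fp4, act_sf, residual = allreduce_output
        hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
    else:
        # The non-quantizing fusion yields two outputs (assumed).
        hidden_states, residual = allreduce_output
    return hidden_states, residual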

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 3 additions & 3 deletions
@@ -652,15 +652,15 @@ def test_nvfp4_tp4(self):

     @pytest.mark.skip_less_device(4)
     @skip_pre_blackwell
-    def test_fp8_tp2pp2(self):
-        model_path = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct-FP8"
+    def test_fp4_tp2pp2(self):
+        model_path = f"{llm_models_root()}/llama-3.3-models/Llama-3.3-70B-Instruct-FP4"
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.5)
         with LLM(model_path,
                  tensor_parallel_size=2,
                  pipeline_parallel_size=2,
                  max_batch_size=32,
                  kv_cache_config=kv_cache_config) as llm:
-            assert llm.args.quant_config.quant_algo == QuantAlgo.FP8
+            assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
             sampling_params = SamplingParams(
                 max_tokens=256,
                 temperature=0.0,
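For context, here is a minimal, self-contained sketch of what the renamed test drives end to end: an NVFP4 Llama-3.3 checkpoint under TP2 x PP2, the configuration that routes decoder-layer allreduces through the NVFP4 fusion fixed above. It assumes the public tensorrt_llm LLM API as used in the test; the checkpoint path and prompt are illustrative stand-ins, not the test's exact values.

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

if __name__ == "__main__":
    with LLM(
        "Llama-3.3-70B-Instruct-FP4",  # hypothetical local NVFP4 checkpoint
        tensor_parallel_size=2,
        pipeline_parallel_size=2,
        max_batch_size=32,
        kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.5),
    ) as llm:
        # Sanity-check that the checkpoint is really NVFP4-quantized,
        # mirroring the assertion in the test above.
        print(llm.args.quant_config.quant_algo)
        outputs = llm.generate(
            ["The capital of France is"],
            SamplingParams(max_tokens=256, temperature=0.0),
        )
        print(outputs[0].outputs[0].text)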

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 1 addition & 1 deletion
@@ -417,7 +417,7 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=False]
 accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_beam_search[enable_cuda_graph=True-enable_padding=True-disable_overlap_scheduler=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp2pp2
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True]
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=False]

tests/integration/test_lists/qa/llm_function_core_sanity.txt

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_prequantized
 accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_auto_dtype
 accuracy/test_llm_api_pytorch.py::TestLlama3_2_3B::test_fp8_prequantized
-accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp2pp2
+accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
 accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_eagle3_tp8[eagle3_one_model=True]

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ l0_dgx_b200:
 - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[llama-3.1-8b-instruct-hf-fp8]
 - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_nixl[DeepSeek-V3-Lite-fp8]
 - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp4
-- accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp8_tp2pp2
+- accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2
 - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_nvfp4_tp4
 - accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-cutlass-auto]
 - condition:
