Commit 2f3e3ae

[https://nvbugs/5516710][fix] fix Llama 3.3 TP PP case (#7717)
Signed-off-by: Yan Chunwei <[email protected]>
1 parent: 015e149

File tree: 4 files changed, +35 −4 lines

tensorrt_llm/_torch/models/modeling_llama.py (4 additions, 2 deletions)

```diff
@@ -560,7 +560,8 @@ def forward(
         # Adjust the scale and fusion pattern.
         if self.next_attn is not None and (self.is_nvfp4
                                            or self.is_fp8_quant):
-            scale = self.next_attn.qkv_proj.input_scale
+            scale = self.next_attn.qkv_proj.input_scale if hasattr(
+                self.next_attn.qkv_proj, 'input_scale') else None
         else:
             self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
             scale = None
@@ -769,7 +770,8 @@ def forward(
         # Adjust the scale and fusion pattern.
         if self.next_attn is not None and (self.is_nvfp4
                                            or self.is_fp8_quant):
-            scale = self.next_attn.qkv_proj.input_scale
+            scale = self.next_attn.qkv_proj.input_scale if hasattr(
+                self.next_attn.qkv_proj, 'input_scale') else None
         else:
             self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
             scale = None
```
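Both hunks apply the same defensive lookup: when the next attention layer's QKV projection carries no `input_scale` attribute (e.g. a projection that is not quantized, which the Llama 3.3 TP+PP run appears to have hit), the scale now falls back to `None` instead of raising `AttributeError`. A self-contained sketch of the pattern, using hypothetical stand-in objects rather than the repository's types:

```python
from types import SimpleNamespace

# Hypothetical stand-ins: a quantized projection exposes input_scale,
# an unquantized one does not.
quantized_proj = SimpleNamespace(input_scale=0.5)
unquantized_proj = SimpleNamespace()


def resolve_scale(qkv_proj):
    # Mirrors the guard in the diff: use the quantization scale when
    # present, otherwise fall back to None.
    return qkv_proj.input_scale if hasattr(qkv_proj, "input_scale") else None


assert resolve_scale(quantized_proj) == 0.5
# An unguarded attribute access here would raise AttributeError.
assert resolve_scale(unquantized_proj) is None
```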

tensorrt_llm/_torch/models/modeling_llama_min_latency.py (4 additions, 2 deletions)

```diff
@@ -826,7 +826,8 @@ def forward(
                 fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8,
                 residual=residual,
                 norm_weight=self.next_layer_layernorm.weight,
-                scale=self.next_attn.qkv_proj.input_scale,
+                scale=self.next_attn.qkv_proj.input_scale if hasattr(
+                    self.next_attn.qkv_proj, 'input_scale') else None,
                 eps=self.next_layer_layernorm.variance_epsilon,
             ))
         elif use_fp4_allreduce and self.next_attn is not None:
@@ -837,7 +838,8 @@
                 RESIDUAL_RMS_NORM_QUANT_NVFP4,
                 residual=residual,
                 norm_weight=self.next_layer_layernorm.weight,
-                scale=self.next_attn.qkv_proj.input_scale,
+                scale=self.next_attn.qkv_proj.input_scale if hasattr(
+                    self.next_attn.qkv_proj, 'input_scale') else None,
                 eps=self.next_layer_layernorm.variance_epsilon,
             ))
         else:
```
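The min-latency path embeds the same guard inline in the `scale=` keyword of both fused-allreduce parameter constructions (FP8 and NVFP4). As a design note, `getattr` with a default is an equivalent, more compact spelling of the conditional; a sketch with a plain dict standing in for the real parameter object (hypothetical helper, not the repository's API):

```python
def fused_allreduce_kwargs(next_attn, residual, norm_weight, eps):
    # Hypothetical helper; the real code constructs AllReduce fusion
    # parameter objects. A plain dict stands in here.
    return {
        "residual": residual,
        "norm_weight": norm_weight,
        # Equivalent to: x.input_scale if hasattr(x, "input_scale") else None
        "scale": getattr(next_attn.qkv_proj, "input_scale", None),
        "eps": eps,
    }
```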

tests/integration/test_lists/test-db/l0_dgx_b200.yml (1 addition, 0 deletions)

```diff
@@ -15,6 +15,7 @@ l0_dgx_b200:
     backend: pytorch
   tests:
   - unittest/_torch/multi_gpu_modeling -k "deepseek"
+  - unittest/_torch/multi_gpu_modeling/test_llama3.py::test_llama_3_3
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp4-attn_backend=FLASHINFER-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=TRTLLM-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[tp4-fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]
```
tests/unittest/_torch/multi_gpu_modeling/test_llama3.py (new file, 26 additions; path inferred from the test-db entry above)

```diff
@@ -0,0 +1,26 @@
+from utils.llm_data import llm_models_root
+from utils.util import similar
+
+from tensorrt_llm import LLM
+
+
+def test_llama_3_3():
+    model_dir = llm_models_root(
+    ) / "llama-3.3-models" / "Llama-3.3-70B-Instruct-FP8"
+    tp = 2
+    pp = 2
+
+    llm = LLM(model_dir, tensor_parallel_size=tp, pipeline_parallel_size=pp)
+    prompts = [
+        "The capital of France is",
+        "The president of the United States is",
+    ]
+
+    outputs = llm.generate(prompts)
+
+    expected_outputs = [
+        " a city of romance, art, fashion, and cuisine. Paris, also known as the City of Light, is a must-visit destination for anyone interested in",
+        " the head of state and head of government of the United States. The president is also the commander-in-chief of the armed forces. The president is elected by the",
+    ]
+    for i, output in enumerate(outputs):
+        assert similar(output.outputs[0].text, expected_outputs[i])
```
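The test runs the FP8 Llama 3.3 70B checkpoint with tensor_parallel_size=2 and pipeline_parallel_size=2 (four GPUs total, matching the l0_dgx_b200 list above) and compares outputs with the test utilities' `similar` helper rather than exact string equality, tolerating minor decoding drift. A rough standard-library sketch of such a fuzzy comparison (an assumption about `utils.util.similar`, not its actual implementation):

```python
from difflib import SequenceMatcher


def similar_sketch(a: str, b: str, threshold: float = 0.8) -> bool:
    # SequenceMatcher.ratio() returns a similarity score in [0.0, 1.0];
    # treat anything at or above the threshold as a match.
    return SequenceMatcher(None, a, b).ratio() >= threshold


assert similar_sketch(
    "The president is elected by the people.",
    "The president is elected by the voters.",
)
```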
