Skip to content

Commit 8cfb363

Browse files
committed
Update llama torch model + llm api pytorch test
Signed-off-by: jintaop <jintaop@nvidia.com>
1 parent 2790927 commit 8cfb363

File tree

2 files changed

+14
-1
lines changed

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,7 @@ def __init__(
635635
)
636636
self.is_nvfp4 = self.is_quanted and model_config.quant_config.quant_mode.has_nvfp4(
637637
)
638-
638+
# Self Attention
639639
self.self_attn = LlamaAttention(
640640
model_config,
641641
layer_idx=layer_idx,

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,19 @@ def test_nvfp4(self):
9595
task.evaluate(llm)
9696
task = MMLU(self.MODEL_NAME)
9797
task.evaluate(llm)
98+
99+
def test_nvfp4_with_norm_quant(self, monkeypatch):
    """Accuracy test for the NVFP4-quantized Llama 3.1 8B with the
    norm+quant (layernorm/NVFP4 fusion) path enabled.

    Runs CnnDailymail and MMLU through the PyTorch LLM API and checks the
    loaded checkpoint reports the NVFP4 quant algo.
    """
    # Check hardware support BEFORE constructing the LLM: building the
    # engine/model is expensive, and the original placement of this skip
    # inside the `with LLM(...)` block paid that cost only to skip.
    sm_version = get_sm_version()
    if sm_version not in (100, 103):
        # NOTE(review): plain string — the original used an f-string with
        # no placeholders.
        pytest.skip("test_nvfp4_with_norm_quant supports SM 100 and 103 only")

    # The fusion toggle must be set BEFORE the LLM is constructed so the
    # model-build/fusion pass observes it; setting it after construction
    # (as the original did) has no effect on the already-built model.
    # "0" means the NVFP4 layernorm fusion is NOT disabled, i.e. enabled.
    monkeypatch.setenv("TRTLLM_DISABLE_NVFP4_LAYERNORM_FUSION", "0")

    model_path = f"{llm_models_root()}/nvfp4-quantized/Meta-Llama-3.1-8B"
    with LLM(model_path) as llm:
        # Sanity-check the checkpoint actually loaded as NVFP4.
        assert llm.args.quant_config.quant_algo == QuantAlgo.NVFP4
        task = CnnDailymail(self.MODEL_NAME)
        task.evaluate(llm)
        task = MMLU(self.MODEL_NAME)
        task.evaluate(llm)
98111

99112
@skip_pre_blackwell
100113
@pytest.mark.parametrize("stream_interval", [4, 64],

0 commit comments

Comments
 (0)