Skip to content

Commit 830122c

Browse files
JunyiXu-nv authored and mikeiovine committed
[https://nvbugs/5569713][fix] Disable fp8 deep gemm for EXAONE-4.0-32B-FP8 (NVIDIA#8429)
Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
1 parent 49b7e63 commit 830122c

File tree

3 files changed

+56
-3
lines changed

3 files changed

+56
-3
lines changed

tensorrt_llm/_torch/models/modeling_exaone4.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from tensorrt_llm._torch.modules.qk_norm_attention import QKNormRoPEAttention
77
from tensorrt_llm.functional import PositionEmbeddingType
8+
from tensorrt_llm.quantization import QuantAlgo
89

910
from ..attention_backend import AttentionMetadata
1011
from ..attention_backend.interface import (PositionalEmbeddingParams,
@@ -54,7 +55,8 @@ class Exaone4Attention(QKNormRoPEAttention):
5455
def __init__(self,
5556
model_config: ModelConfig[Exaone4Config],
5657
layer_idx: Optional[int] = None,
57-
fuse_qk_norm_rope: bool = False):
58+
fuse_qk_norm_rope: bool = False,
59+
disable_deep_gemm: bool = False):
5860
config = model_config.pretrained_config
5961

6062
self.attention_window_size = None
@@ -88,6 +90,7 @@ def __init__(self,
8890
layer_idx=layer_idx,
8991
dtype=config.torch_dtype,
9092
config=model_config,
93+
disable_deep_gemm=disable_deep_gemm,
9194
)
9295

9396
def forward(
@@ -128,9 +131,17 @@ def __init__(
128131
self.is_quanted = model_config.quant_config and model_config.quant_config.quant_mode.has_any_quant(
129132
)
130133

134+
disable_deep_gemm = False
135+
quant_config = getattr(model_config, "quant_config", None)
136+
if quant_config is not None:
137+
# EXAONE4 fp8 has an illegal memory access issue with deep_gemm.
138+
disable_deep_gemm = getattr(quant_config, "quant_algo",
139+
None) == QuantAlgo.FP8_BLOCK_SCALES
140+
131141
self.self_attn = Exaone4Attention(
132142
model_config,
133143
layer_idx=layer_idx,
144+
disable_deep_gemm=disable_deep_gemm,
134145
)
135146

136147
self.mlp = GatedMLP(
@@ -140,6 +151,7 @@ def __init__(
140151
dtype=config.torch_dtype,
141152
config=model_config,
142153
layer_idx=layer_idx,
154+
disable_deep_gemm=disable_deep_gemm,
143155
)
144156

145157
self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,

tests/integration/test_lists/test-db/l0_b200.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ l0_b200:
7878
- unittest/_torch/modeling -k "modeling_gpt_oss"
7979
- unittest/tools/test_layer_wise_benchmarks.py::test_deepseek_r1_ctx_dep[1]
8080
- unittest/tools/test_layer_wise_benchmarks.py::test_qwen3_next_gen_tep[1]
81+
- unittest/_torch/modeling/test_modeling_exaone4.py::TestEXAONE4::test_llm_load_1_FP8
8182
# ------------- AutoDeploy tests ---------------
8283
- accuracy/test_llm_api_autodeploy.py::TestLlama3_1_8B::test_auto_dtype[False-1]
8384
- unittest/_torch/auto_deploy/unit/singlegpu

tests/unittest/_torch/modeling/test_modeling_exaone4.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import json
2+
import os
3+
import shutil
14
import unittest
25
from copy import deepcopy
36
from dataclasses import dataclass
@@ -50,8 +53,9 @@ class Exaone4Config(PretrainedConfig):
5053
"max_position_embeddings": 131072,
5154
"model_type": "exaone4",
5255
"num_attention_heads": 40,
53-
"num_hidden_layers":
54-
4, #NOTE: For testing, we use 4 instead of 64(all layers)
56+
# NOTE: For testing, we use 32 instead of 64(all layers)
57+
# Increase from 4 to 32 to trigger the deep_gemm kernel issue
58+
"num_hidden_layers": 32,
5559
"num_key_value_heads": 8,
5660
"pad_token_id": 0,
5761
"rms_norm_eps": 1e-05,
@@ -73,6 +77,15 @@ class Exaone4Config(PretrainedConfig):
7377
"attn_implementation": "flash_attention_2"
7478
}
7579

80+
EXAONE4_FP8_QUANT_CONFIG = {
81+
"quantization_config": {
82+
"activation_scheme": "dynamic",
83+
"modules_to_not_convert": None,
84+
"quant_method": "fp8",
85+
"weight_block_size": [128, 128]
86+
},
87+
}
88+
7689

7790
@dataclass(repr=False)
7891
class Scenario:
@@ -387,3 +400,30 @@ def run_forward(input_ids, position_ids, attn_metadata):
387400
if graph_runner is not None:
388401
graph_runner.clear()
389402
kv_cache_manager.shutdown()
403+
404+
@parameterized.expand([None, "FP8"])
405+
def test_llm_load(self, quant_algo):
406+
407+
def dump_config_json(dst_dir, config):
408+
if os.path.exists(dst_dir):
409+
shutil.rmtree(dst_dir)
410+
os.makedirs(dst_dir)
411+
412+
dst_path = os.path.join(dst_dir, 'config.json')
413+
with open(dst_path, 'w', encoding='utf-8') as f:
414+
json.dump(config, f, indent=2, ensure_ascii=False)
415+
416+
config_dict = deepcopy(EXAONE4_SINGLE_LAYER_CONFIG)
417+
if quant_algo == "FP8":
418+
if getSMVersion() < 89:
419+
self.skipTest(
420+
"This test is not supported in pre-Ada architecture")
421+
422+
config_dict.update(EXAONE4_FP8_QUANT_CONFIG)
423+
424+
tmp_model_dir = f"/tmp/exaone4_llm_load_test_model"
425+
dump_config_json(tmp_model_dir, config_dict)
426+
try:
427+
tensorrt_llm.LLM(model=tmp_model_dir, load_format="dummy")
428+
except Exception:
429+
raise RuntimeError("Failed to load model.")

0 commit comments

Comments (0)