Commit 095b7a3

[https://nvbugs/5521253][fix] Enable Gemma3 12B & 27B on SM100 (#8666)
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 9f1d274 commit 095b7a3

File tree

3 files changed: +15 -3 lines

cpp/kernels/fmha_v2/setup.py

Lines changed: 10 additions & 0 deletions
@@ -6379,6 +6379,16 @@ def enumerate_kernels():
                 and kspec.version == 2
                 and kspec.cross_mha == False
                 and kspec.flash_attention == False)
+            # Gemma3 VL support.
+            or (kspec.sm == 100
+                and kspec.dtype in ['fp16', 'bf16', 'fp16_fp32', 'e4m3', 'e4m3_fp32']
+                and kspec.head_size == 72
+                and kspec.head_size_v == 0
+                and kspec.sage_block_sizes is None
+                and kspec.version == 2
+                and kspec.cross_mha == False
+                and kspec.flash_attention == True
+                and kspec.input_layout != InputLayout.SEPARATE_Q_K_V)
             # Deepseek MLA (generation 576/512 paged)
             or (kspec.sm in [90, 100, 120]
                 and kspec.dtype in ['bf16', 'e4m3_fp32']
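
The added clause whitelists SM100 flash-attention kernel specs with head size 72, the non-power-of-2 head size Gemma3 needs (per the "Gemma3 VL support." comment). As a rough illustration of how such an enumeration filter behaves, here is a minimal, self-contained Python sketch; KSpec and its fields are simplified stand-ins for the real kernel-spec objects in fmha_v2/setup.py, not the actual implementation.

# Minimal sketch (not the real setup.py): how an enumeration filter of this
# kind selects kernel specs. KSpec and its fields are simplified stand-ins
# for the actual kernel-spec structure in fmha_v2/setup.py.
from dataclasses import dataclass

@dataclass
class KSpec:
    sm: int
    dtype: str
    head_size: int
    flash_attention: bool

def is_gemma3_vl_spec(kspec: KSpec) -> bool:
    # Mirrors the added clause: SM100 flash-attention kernels with head size 72.
    return (kspec.sm == 100
            and kspec.dtype in ['fp16', 'bf16', 'fp16_fp32', 'e4m3', 'e4m3_fp32']
            and kspec.head_size == 72
            and kspec.flash_attention)

specs = [KSpec(100, 'bf16', 72, True), KSpec(100, 'bf16', 128, True)]
print([s.head_size for s in specs if is_gemma3_vl_spec(s)])  # -> [72]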

cpp/tensorrt_llm/kernels/fmhaDispatcher.cpp

Lines changed: 4 additions & 1 deletion
@@ -46,7 +46,10 @@ QkvLayout AttentionInputLayoutToQkvLayout(AttentionInputLayout layout)
 
 FmhaDispatcher::FmhaDispatcher(MHARunnerFixedParams fixedParams)
     : mFixedParams(fixedParams)
-    , mUseTllmGen(tensorrt_llm::common::isSM100Family())
+    // TRTLLM-GEN only supports power of 2 head sizes.
+    // The exception will fall back to fmha v2.
+    // Please update fmha_v2/setup.py if you want to add more supported head sizes.
+    , mUseTllmGen(tensorrt_llm::common::isSM100Family() && fixedParams.headSize != 72)
 {
     if (mUseTllmGen)
     {
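
With this change the dispatcher selects the TRTLLM-GEN backend on the SM100 family only for head sizes it supports; head size 72 falls back to fmha v2, which is why the setup.py change above adds matching kernels. Below is a small Python model of that decision, purely illustrative; the authoritative logic is the C++ initializer shown in the diff, and the function name and arguments here are stand-ins, not TensorRT-LLM APIs.

# Illustrative Python model of the dispatch decision (not a TensorRT-LLM API).
def use_tllm_gen(is_sm100_family: bool, head_size: int) -> bool:
    # TRTLLM-GEN covers power-of-2 head sizes on the SM100 family; head size 72
    # (needed for Gemma3) is the exception routed to fmha v2.
    return is_sm100_family and head_size != 72

for hs in (64, 72, 128):
    backend = "TRTLLM-GEN" if use_tllm_gen(True, hs) else "fmha v2"
    print(f"SM100 family, head_size={hs}: {backend}")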

tests/integration/defs/test_e2e.py

Lines changed: 1 addition & 2 deletions
@@ -2447,8 +2447,7 @@ def test_ptp_quickstart_advanced_mixed_precision(llm_root, llm_venv):
                  marks=pytest.mark.skip_less_device_memory(80000)),
     pytest.param("gemma-3-27b-it",
                  "gemma/gemma-3-27b-it",
-                 marks=(pytest.mark.skip_less_device_memory(80000),
-                        skip_post_blackwell)),
+                 marks=pytest.mark.skip_less_device_memory(80000)),
     pytest.param(
         "Nano-v2-VLM",
         "Nano-v2-VLM",
