Commit 2bd09ed

fix: Skip rope scaling for local layers in Gemma3 VLM (NVIDIA#5857)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
1 parent c24eb67 commit 2bd09ed

File tree

3 files changed: +112 -18 lines changed


tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 37 additions & 5 deletions
@@ -1,3 +1,4 @@
+import math
 import os
 import weakref
 from dataclasses import dataclass, field
@@ -39,6 +40,8 @@ class PlanParams:
 
     attention_mask_type: AttentionMaskType
     attention_mask_data: Optional[torch.Tensor] = None
+    sm_scale: Optional[float] = None
+    window_left: Optional[int] = None
 
 
 @dataclass(kw_only=True)
@@ -309,13 +312,23 @@ def plan(self,
              q_dtype: torch.dtype,
              kv_dtype: torch.dtype,
              attention_mask_type: int,
+             q_scaling: Optional[float] = None,
+             attention_window_size: Optional[int] = None,
              attention_mask_data: Optional[torch.Tensor] = None) -> PlanParams:
+
+        sm_scale = None
+        if q_scaling is not None:
+            sm_scale = 1 / (math.sqrt(head_dim) * q_scaling)
+
         plan_params = PlanParams(
             num_heads=num_heads,
             num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             q_dtype=q_dtype,
             kv_dtype=kv_dtype,
+            sm_scale=sm_scale,
+            window_left=attention_window_size
+            if attention_window_size is not None else -1,
             attention_mask_type=AttentionMaskType(attention_mask_type),
             attention_mask_data=attention_mask_data)
         return self._plan_with_params(plan_params)
@@ -363,6 +376,8 @@ def prefill_plan():
                 plan_params.head_dim,
                 self.page_size,
                 causal=is_causal,
+                sm_scale=plan_params.sm_scale,
+                window_left=plan_params.window_left,
                 q_data_type=plan_params.q_dtype,
                 kv_data_type=plan_params.kv_dtype,
             )
@@ -398,6 +413,8 @@ def decode_plan():
                 plan_params.num_kv_heads,
                 plan_params.head_dim,
                 self.page_size,
+                sm_scale=plan_params.sm_scale,
+                window_left=plan_params.window_left,
                 q_data_type=plan_params.q_dtype,
                 kv_data_type=plan_params.kv_dtype,
             )
@@ -431,13 +448,15 @@ def __init__(
         head_dim: int,
         num_kv_heads: Optional[int] = None,
         quant_config: Optional[QuantConfig] = None,
+        q_scaling: Optional[float] = None,
        skip_create_weights_in_init: bool = False,
         **kwargs,
     ):
         super().__init__(layer_idx, num_heads, head_dim, num_kv_heads,
                          quant_config, **kwargs)
         if not skip_create_weights_in_init:
             self.update_quant_config(self.quant_config)
+        self.q_scaling = q_scaling
 
     def update_quant_config(self, new_quant_config: Optional[QuantConfig]):
         self.quant_config = new_quant_config
@@ -452,6 +471,7 @@ def forward(self,
                 v: Optional[torch.Tensor],
                 metadata: FlashInferAttentionMetadata,
                 *,
+                attention_window_size: Optional[int] = None,
                 attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
                 **kwargs) -> torch.Tensor:
         if attention_mask == PredefinedAttentionMask.CAUSAL:
@@ -463,10 +483,18 @@ def forward(self,
         else:
             raise ValueError("Unexpected attention mask type")
 
-        return forward_pattern(q, k, v, self.num_heads, self.head_dim,
-                               self.num_kv_heads, self.layer_idx,
-                               self.has_fp8_kv_cache, attention_mask_type,
-                               attention_mask_data)
+        return forward_pattern(q=q,
+                               k=k,
+                               v=v,
+                               num_heads=self.num_heads,
+                               head_dim=self.head_dim,
+                               num_kv_heads=self.num_kv_heads,
+                               layer_idx=self.layer_idx,
+                               has_fp8_kv_cache=self.has_fp8_kv_cache,
+                               attention_mask_type=attention_mask_type,
+                               q_scaling=self.q_scaling,
+                               attention_mask_data=attention_mask_data,
+                               attention_window_size=attention_window_size)
 
 
 @torch.library.custom_op("trtllm::flashinfer_forward", mutates_args=())
@@ -480,7 +508,9 @@ def forward_pattern(
     layer_idx: int,
     has_fp8_kv_cache: bool,
     attention_mask_type: int,
-    attention_mask_data: Optional[torch.Tensor],
+    q_scaling: Optional[float] = None,
+    attention_mask_data: Optional[torch.Tensor] = None,
+    attention_window_size: Optional[int] = None,
 ) -> torch.Tensor:
     '''
     Wrapping the flashinfer forward as a custom op is required to fix `torch.compile` graph breaks,
@@ -548,6 +578,8 @@ def decode_forward(plan_params: PlanParams):
         head_dim,
         q_dtype=q.dtype,
         kv_dtype=kv_cache.dtype,
+        q_scaling=q_scaling,
+        attention_window_size=attention_window_size,
         attention_mask_type=attention_mask_type,
         attention_mask_data=attention_mask_data)
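
Note on the new scaling plumbing: when q_scaling is set, plan() converts it into FlashInfer's sm_scale as 1 / (sqrt(head_dim) * q_scaling), and when attention_window_size is unset, window_left falls back to -1 (FlashInfer's "no sliding window" value). Below is a minimal sketch of that conversion using the Gemma3-27B numbers from the test config further down (head_dim=128, query_pre_attn_scalar=168); the q_scaling expression is an assumption about how the model maps query_pre_attn_scalar onto TRT-LLM's q_scaling convention, not code from this commit.

    import math
    from typing import Optional

    def flashinfer_sm_scale(head_dim: int, q_scaling: Optional[float]) -> float:
        # plan() leaves sm_scale as None when q_scaling is None and lets FlashInfer
        # apply its default; this sketch just makes that default explicit.
        if q_scaling is None:
            return 1.0 / math.sqrt(head_dim)
        return 1.0 / (math.sqrt(head_dim) * q_scaling)

    # Assumed mapping for Gemma3-27B: q_scaling chosen so the effective softmax
    # scale becomes 1/sqrt(query_pre_attn_scalar) instead of 1/sqrt(head_dim).
    head_dim, query_pre_attn_scalar = 128, 168
    q_scaling = math.sqrt(query_pre_attn_scalar / head_dim)

    print(flashinfer_sm_scale(head_dim, q_scaling))   # ~0.0772
    print(1.0 / math.sqrt(query_pre_attn_scalar))     # same value, 1/sqrt(168)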

tensorrt_llm/_torch/models/modeling_gemma3.py

Lines changed: 4 additions & 2 deletions
@@ -7,7 +7,7 @@
 from transformers import Gemma3TextConfig
 from transformers.activations import ACT2FN
 
-from tensorrt_llm.functional import PositionEmbeddingType
+from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
 from tensorrt_llm.mapping import Mapping
 
 from ..attention_backend import AttentionMetadata
@@ -64,7 +64,9 @@ def __init__(
         rope_params = RopeParams.from_config(config)
         self.attention_window_size = None
         if is_sliding:
-            rope_params.theta = 10000
+            rope_params.theta = config.rope_local_base_freq
+            rope_params.scale_type = RotaryScalingType.none
+            rope_params.scale = 1.0
             self.attention_window_size = config.sliding_window - 1  # Gemma3 sliding window isn't inclusive.
         pos_embd_params = PositionalEmbeddingParams(
             type=PositionEmbeddingType.rope_gpt_neox,
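
This hunk is the actual fix: a local (sliding-window) Gemma3 layer now takes its theta from config.rope_local_base_freq and explicitly disables rope scaling, so any rope_scaling in the config (for example the linear factor 8.0 in the 27B VLM text config below) applies only to global layers. A minimal sketch of that selection follows, with a hypothetical helper and a simplified stand-in for RopeParams; it mirrors the change above but is not code from this commit.

    from dataclasses import dataclass

    @dataclass
    class SimpleRopeParams:
        # Simplified stand-in for TRT-LLM's RopeParams; only the fields touched above.
        theta: float
        scale_type: str  # "none" or e.g. "linear" in this sketch
        scale: float

    def rope_params_for_layer(config, is_sliding: bool) -> SimpleRopeParams:
        """Hypothetical helper mirroring the Gemma3Attention.__init__ change above."""
        if is_sliding:
            # Local layers: local base frequency, no positional scaling.
            return SimpleRopeParams(theta=config.rope_local_base_freq,
                                    scale_type="none",
                                    scale=1.0)
        # Global layers: global theta plus whatever rope_scaling the config requests.
        scaling = getattr(config, "rope_scaling", None) or {}
        return SimpleRopeParams(theta=config.rope_theta,
                                scale_type=scaling.get("rope_type", "none"),
                                scale=scaling.get("factor", 1.0))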

tests/unittest/_torch/modeling/test_modeling_gemma3.py

Lines changed: 71 additions & 11 deletions
@@ -17,9 +17,7 @@
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.mapping import Mapping
 
-# This is copied from https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json.
-# Updated to have 1 local layer and 1 global layer. Sliding window size updated to 4.
-GEMMA3_1B_MINI_CONFIG = {
+GEMMA3_1B_CONFIG = {
     "architectures": ["Gemma3ForCausalLM"],
     "attention_bias": False,
     "attention_dropout": 0.0,
@@ -36,29 +34,68 @@
     "max_position_embeddings": 32768,
     "model_type": "gemma3_text",
     "num_attention_heads": 4,
-    "num_hidden_layers": 2,  # Modified for testing.
+    "num_hidden_layers": 26,
     "num_key_value_heads": 1,
     "pad_token_id": 0,
     "query_pre_attn_scalar": 256,
     "rms_norm_eps": 1e-06,
     "rope_local_base_freq": 10000,
     "rope_scaling": None,
     "rope_theta": 1000000,
-    "sliding_window": 4,  # Modified for testing.
-    "sliding_window_pattern": 2,  # Modified for testing.
+    "sliding_window": 512,
+    "sliding_window_pattern": 6,
     "torch_dtype": "bfloat16",
     "transformers_version": "4.50.0.dev0",
     "use_cache": True,
     "vocab_size": 262144
 }
 
+GEMMA3_27B_CONFIG = {
+    "architectures": ["Gemma3ForConditionalGeneration"],
+    "boi_token_index": 255999,
+    "eoi_token_index": 256000,
+    "eos_token_id": [1, 106],
+    "image_token_index": 262144,
+    "initializer_range": 0.02,
+    "mm_tokens_per_image": 256,
+    "model_type": "gemma3",
+    "text_config": {
+        "head_dim": 128,
+        "hidden_size": 5376,
+        "intermediate_size": 21504,
+        "model_type": "gemma3_text",
+        "num_attention_heads": 32,
+        "num_hidden_layers": 62,
+        "num_key_value_heads": 16,
+        "query_pre_attn_scalar": 168,
+        "rope_scaling": {
+            "factor": 8.0,
+            "rope_type": "linear"
+        },
+        "sliding_window": 1024
+    },
+    "torch_dtype": "bfloat16",
+    "transformers_version": "4.50.0.dev0",
+    "vision_config": {
+        "hidden_size": 1152,
+        "image_size": 896,
+        "intermediate_size": 4304,
+        "model_type": "siglip_vision_model",
+        "num_attention_heads": 16,
+        "num_hidden_layers": 27,
+        "patch_size": 14,
+        "vision_use_head": False
+    }
+}
+
 
 @dataclass(repr=False)
 class Scenario:
     backend: str
+    config_name: str
 
     def __repr__(self) -> str:
-        return f"backend:{self.backend.lower()}"
+        return f"backend:{self.backend.lower()}_config:{self.config_name.lower()}"
 
 
 class TestGemma3(unittest.TestCase):
@@ -95,7 +132,8 @@ def get_kv_cache_manager(self, dtype: torch.dtype, config: Gemma3Config,
 
     def test_gemma3_sanity(self):
 
-        config_dict = deepcopy(GEMMA3_1B_MINI_CONFIG)
+        config_dict = deepcopy(
+            GEMMA3_1B_CONFIG)  # Using 1B config for sanity test.
         gemma3_config = Gemma3Config.from_dict(config_dict)
 
         dtype = gemma3_config.torch_dtype
@@ -174,8 +212,12 @@ def test_gemma3_sanity(self):
         kv_cache_manager.shutdown()
 
     @parameterized.expand([
-        Scenario(backend="TRTLLM"),
-        Scenario(backend="VANILLA"),
+        Scenario(backend="TRTLLM", config_name="1B"),
+        Scenario(backend="VANILLA", config_name="1B"),
+        Scenario(backend="FLASHINFER", config_name="1B"),
+        Scenario(backend="TRTLLM", config_name="27B"),
+        Scenario(backend="VANILLA", config_name="27B"),
+        Scenario(backend="FLASHINFER", config_name="27B"),
     ], lambda testcase_func, param_num, param:
        f"{testcase_func.__name__}[{param.args[0]}]")
     @torch.no_grad()
@@ -184,14 +226,31 @@ def test_gemma3_allclose_to_hf(self, scenario: Scenario) -> None:
         Compare output to HF.
         """
         backend = scenario.backend
+        config_name = scenario.config_name
         metadata_cls = get_attention_backend(backend).Metadata
 
         torch.random.manual_seed(0)
-        config_dict = deepcopy(GEMMA3_1B_MINI_CONFIG)
+
+        # Select the appropriate config based on the scenario
+        if config_name == "1B":
+            config_dict = deepcopy(GEMMA3_1B_CONFIG)
+        elif config_name == "27B":
+            config_dict = deepcopy(GEMMA3_27B_CONFIG)
+        else:
+            raise ValueError(f"Unknown config_name: {config_name}")
+
         gemma3_config = Gemma3Config.from_dict(config_dict)
+        if config_name == "27B":
+            gemma3_config.text_config.torch_dtype = gemma3_config.torch_dtype
+            gemma3_config = gemma3_config.text_config
         dtype = gemma3_config.torch_dtype
         device = torch.device('cuda')
 
+        # 2-layer network with one local (sliding window=4) and one global layer.
+        gemma3_config.num_hidden_layers = 2
+        gemma3_config.sliding_window = 4
+        gemma3_config.sliding_window_pattern = 2
+
         num_blocks = 1
         tokens_per_block = 128
         max_seq_len = num_blocks * tokens_per_block
@@ -253,6 +312,7 @@ def test_gemma3_allclose_to_hf(self, scenario: Scenario) -> None:
             position_ids=position_ids,
             past_key_values=hf_cache,
             use_cache=True)
+
         torch.testing.assert_close(logits,
                                    ref.logits[:, -1].float(),
                                    atol=0.05,
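
The trimmed two-layer model in the test above relies on Gemma3's layer interleaving: with sliding_window_pattern = N, every N-th layer is global and the rest are local. The modulo convention in the sketch below follows my reading of the Hugging Face Gemma3 implementation, so treat it as an assumption rather than code from this commit.

    def is_sliding_layer(layer_idx: int, sliding_window_pattern: int) -> bool:
        # Assumed HF convention: every `sliding_window_pattern`-th layer is global,
        # all remaining layers use local sliding-window attention.
        return bool((layer_idx + 1) % sliding_window_pattern)

    # The 2-layer test config (sliding_window_pattern=2) yields one layer of each kind:
    assert is_sliding_layer(0, sliding_window_pattern=2)      # layer 0: local, window 4, theta 10000
    assert not is_sliding_layer(1, sliding_window_pattern=2)  # layer 1: global, theta 1e6, rope scaling applies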
