Commit 3209b31

feat: Custom masking utils for Gemma3 VLM (NVIDIA#5853)
Signed-off-by: Balaram Buddharaju <[email protected]>
1 parent 87fe44f commit 3209b31

6 files changed: +336 -25 lines changed


tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 16 additions & 3 deletions
@@ -14,7 +14,7 @@

 from ..utils import get_global_attrs, get_model_extra_attrs
 from .interface import (AttentionBackend, AttentionMask, AttentionMetadata,
-                        PredefinedAttentionMask)
+                        CustomAttentionMask, PredefinedAttentionMask)

 try:
     check_cuda_arch()
@@ -366,6 +366,12 @@ def _plan_with_params(self, plan_params: PlanParams) -> PlanParams:
         is_causal = plan_params.attention_mask_type == AttentionMaskType.causal

         def prefill_plan():
+            # Setting `window_left` to -1 for a custom attention mask is important.
+            # Else, FlashInfer proceeds to use SWA regardless of attention_mask_data.
+            if plan_params.attention_mask_data is not None:
+                window_left = -1
+            else:
+                window_left = plan_params.window_left
             prefill_wrapper.plan(
                 self.qo_indptr[:self.num_contexts + 1],
                 self.paged_kv_indptr_prefill[:self.num_contexts + 1],
@@ -377,9 +383,10 @@ def prefill_plan():
                 self.page_size,
                 causal=is_causal,
                 sm_scale=plan_params.sm_scale,
-                window_left=plan_params.window_left,
+                window_left=window_left,
                 q_data_type=plan_params.q_dtype,
                 kv_data_type=plan_params.kv_dtype,
+                custom_mask=plan_params.attention_mask_data,
             )

         if plan_params in self._plan_params_to_wrappers:
@@ -473,8 +480,14 @@ def forward(self,
                 *,
                 attention_window_size: Optional[int] = None,
                 attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
+                attention_mask_data: Optional[torch.Tensor] = None,
                 **kwargs) -> torch.Tensor:
-        if attention_mask == PredefinedAttentionMask.CAUSAL:
+        if attention_mask == CustomAttentionMask.CUSTOM:
+            assert attention_mask_data is not None, "attention_mask_data is required for custom attention mask."
+            attention_mask_type = int(AttentionMaskType.custom_mask)
+            attention_mask_data = attention_mask_data if attention_mask_data.ndim == 1 else attention_mask_data.flatten(
+            )
+        elif attention_mask == PredefinedAttentionMask.CAUSAL:
             attention_mask_type = int(AttentionMaskType.causal)
             attention_mask_data = None
         elif attention_mask == PredefinedAttentionMask.FULL:
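
The comment in the prefill plan above is the key behavioral point: when custom mask data is supplied, `window_left` must be forced to -1 or FlashInfer keeps applying sliding-window attention and ignores `attention_mask_data`. A minimal standalone sketch of that decision, where `_PlanParamsSketch` is a hypothetical stand-in for the backend's `PlanParams` used only to isolate the rule:

# Sketch only; `_PlanParamsSketch` is illustrative, not the backend's real class.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class _PlanParamsSketch:
    window_left: int
    attention_mask_data: Optional[torch.Tensor] = None


def resolve_window_left(plan_params: _PlanParamsSketch) -> int:
    # A custom mask takes precedence: window_left=-1 disables SWA so FlashInfer
    # honors attention_mask_data instead of silently applying the sliding window.
    return -1 if plan_params.attention_mask_data is not None else plan_params.window_left


assert resolve_window_left(_PlanParamsSketch(window_left=512)) == 512
assert resolve_window_left(
    _PlanParamsSketch(window_left=512,
                      attention_mask_data=torch.ones(16, dtype=torch.bool))) == -1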

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 8 additions & 2 deletions
@@ -501,8 +501,14 @@ class PredefinedAttentionMask(str, Enum):
     FULL = "full"


-# May extend to custom attention mask type
-AttentionMask = Union[PredefinedAttentionMask]
+class CustomAttentionMask(str, Enum):
+    """
+    Custom attention mask types
+    """
+    CUSTOM = "custom"
+
+
+AttentionMask = Union[PredefinedAttentionMask, CustomAttentionMask]


 class AttentionBackend(Generic[TMetadata]):
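
With the widened `AttentionMask` union, call sites can accept either mask family. A hedged sketch of how a caller might select the mask type (mirroring the decoder-layer logic later in this commit); the import path matches the interface module changed above, while the helper itself is illustrative only:

# Illustrative helper; only the imported names come from the commit.
from typing import Optional

import torch

from tensorrt_llm._torch.attention_backend.interface import (
    AttentionMask, CustomAttentionMask, PredefinedAttentionMask)


def choose_mask_type(attention_mask_data: Optional[torch.Tensor]) -> AttentionMask:
    # Fall back to the predefined causal mask when no custom mask data is supplied.
    return (CustomAttentionMask.CUSTOM
            if attention_mask_data is not None else PredefinedAttentionMask.CAUSAL)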

tensorrt_llm/_torch/models/modeling_gemma3.py

Lines changed: 134 additions & 7 deletions
@@ -10,8 +10,9 @@
 from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
 from tensorrt_llm.mapping import Mapping

-from ..attention_backend import AttentionMetadata
-from ..attention_backend.interface import (PositionalEmbeddingParams,
+from ..attention_backend import AttentionMetadata, FlashInferAttentionMetadata
+from ..attention_backend.interface import (AttentionMask, CustomAttentionMask,
+                                           PositionalEmbeddingParams,
                                            PredefinedAttentionMask, RopeParams)
 from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
@@ -101,14 +102,19 @@ def forward(
         position_ids: Optional[torch.IntTensor],
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
-        CAUSAL,
+        attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
         mrope_config: Optional[dict] = None,
         all_reduce_params: Optional[AllReduceParams] = None,
         lora_params: Optional[dict] = None,
+        attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:

+        if attention_mask_data is not None:
+            assert isinstance(
+                attn_metadata, FlashInferAttentionMetadata
+            ), "Only FlashInfer backend supports custom attention mask currently."
+            assert attention_mask == CustomAttentionMask.CUSTOM
         return super().forward(position_ids=position_ids,
                                hidden_states=hidden_states,
                                attn_metadata=attn_metadata,
@@ -117,6 +123,7 @@ def forward(
                                all_reduce_params=all_reduce_params,
                                lora_params=lora_params,
                                attention_window_size=self.attention_window_size,
+                               attention_mask_data=attention_mask_data,
                                **kwargs)

     def apply_qk_norm(self, q, k):
@@ -214,6 +221,7 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor] = None,
+        attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:

@@ -223,6 +231,9 @@ def forward(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
+            attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
+            is not None else PredefinedAttentionMask.CAUSAL,
+            attention_mask_data=attention_mask_data,
             **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
@@ -267,6 +278,8 @@ def forward(
         input_ids: Optional[torch.IntTensor] = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        local_attention_mask_data: Optional[torch.Tensor] = None,
+        global_attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -280,9 +293,13 @@ def forward(
         hidden_states = inputs_embeds.to(self.dtype)

         for decoder_layer in self.layers:
-            hidden_states = decoder_layer(position_ids=position_ids,
-                                          hidden_states=hidden_states,
-                                          attn_metadata=attn_metadata)
+            hidden_states = decoder_layer(
+                position_ids=position_ids,
+                hidden_states=hidden_states,
+                attn_metadata=attn_metadata,
+                attention_mask_data=local_attention_mask_data
+                if decoder_layer.self_attn.is_sliding else
+                global_attention_mask_data)

         hidden_states = self.norm(hidden_states)
         return hidden_states
@@ -301,21 +318,131 @@ def __init__(
             hidden_size=model_config.pretrained_config.hidden_size,
             vocab_size=model_config.pretrained_config.vocab_size)

+    def get_context_mask(
+        self,
+        image_token_mask: torch.BoolTensor,
+        effective_sliding_window: Optional[int] = None,
+    ):
+        """
+        Returns an attention mask such that text tokens attend to each other in a causal fashion while image
+        tokens attend in a causal fashion as well as to all other image tokens in a bidirectional manner.
+        Args:
+            image_token_mask: A boolean tensor of shape (sequence_length,) where True indicates an image token.
+            effective_sliding_window: The effective sliding window size for the attention mask. Default is None, which means no sliding window.
+                For Gemma3, this is the sliding window size from the config (e.g. 512 for the 1B model).
+        Returns:
+            A boolean attention mask of shape (sequence_length, sequence_length).
+        """
+        device = image_token_mask.device
+        sequence_length = len(image_token_mask)
+        if effective_sliding_window is None or effective_sliding_window >= sequence_length:
+            causal_mask = torch.arange(
+                sequence_length, device=device).unsqueeze(0) <= torch.arange(
+                    sequence_length, device=device).unsqueeze(1)
+        else:
+            attention_mask_1 = (torch.arange(sequence_length,
+                                             device=device).unsqueeze(0)
+                                <= torch.arange(sequence_length,
+                                                device=device).unsqueeze(1))
+            attention_mask_2 = (
+                torch.arange(sequence_length, device=device).unsqueeze(0)
+                > torch.arange(sequence_length, device=device).unsqueeze(1) -
+                effective_sliding_window)
+            causal_mask = attention_mask_1 & attention_mask_2
+
+        # Apply a bidirectional mask for image tokens.
+        token_type_ids = torch.zeros(sequence_length,
+                                     dtype=torch.int32,
+                                     device=device)
+        # 1 for image tokens, 0 for text tokens.
+        token_type_ids[image_token_mask] = 1
+        token_type_mask = token_type_ids.unsqueeze(
+            0) == token_type_ids.unsqueeze(1)
+        # If text token, do not change anything.
+        token_type_mask[token_type_ids == 0] = False
+        causal_mask = causal_mask.masked_fill(token_type_mask, True)
+        return causal_mask
+
+    # ASSUMPTIONS:
+    # 1) Chunked prefill is disabled to avoid chunking image tokens as they need bidirectional attention.
+    # 2) KV cache reuse is disabled to avoid partially matched image tokens (the entire image must be reused to get things correct).
+    def get_flashinfer_attention_mask(
+            self,
+            image_token_mask: torch.BoolTensor,
+            attn_metadata: AttentionMetadata,
+            effective_sliding_window: Optional[int] = None) -> torch.Tensor:
+        """
+        This is specifically needed for context-phase requests. Currently, we don't create a custom mask for generation requests because the FlashInfer backend
+        doesn't use it anyway and there's nothing special we need to do for generation requests.
+        - This function will only be called for a batch when there's at least one context request in the batch with image tokens.
+        - In the context phase, each sample's input_ids may have a mix of image tokens and text tokens, where the tokens corresponding to an image
+          appear as a contiguous blob. Example: torch.IntTensor([2, 3, 4, 5, img_idx, img_idx, img_idx, ..., img_idx, 100])
+        - While the text tokens attend to other tokens in a causal fashion, image tokens attend to others in a causal fashion as well as
+          to other image tokens in a bidirectional manner. Hence the need for custom masking.
+        Args:
+            image_token_mask: A boolean tensor of shape (len(input_ids),) where True indicates an image token. This corresponds to the concatenated
+                list of tokens for all samples in the batch.
+            attn_metadata: The attention metadata for the batch.
+            effective_sliding_window: The effective sliding window size for the attention mask. Default is None, which means no sliding window.
+                For Gemma3, this is the sliding window size from the config (e.g. 512 for the 1B model).
+        Returns:
+            A flattened boolean mask of shape (sum(q_len[i] * k_len[i] for i in range(batch_size)),).
+        """
+
+        assert isinstance(
+            attn_metadata, FlashInferAttentionMetadata
+        ), "Only FlashInfer backend supports custom mask currently."
+        num_contexts = attn_metadata.num_contexts
+        assert num_contexts > 0, "There should be at least one context request in the batch for custom mask."
+
+        qo_indptr = attn_metadata.qo_indptr[:num_contexts + 1]
+        cached_token_lens = attn_metadata.cached_token_lens[:num_contexts]
+        assert (cached_token_lens == 0).all(
+        ), "cached_token_lens should be 0 for context requests since chunked prefill and kv cache reuse must be disabled."
+
+        # Create masks for context requests.
+        context_mask_list = []
+        for i in range(num_contexts):
+            mask_i = self.get_context_mask(
+                image_token_mask=image_token_mask[qo_indptr[i]:qo_indptr[i +
+                                                                         1]],
+                effective_sliding_window=effective_sliding_window,
+            )
+            context_mask_list.append(mask_i.flatten())
+        return torch.cat(context_mask_list, dim=0).contiguous()
+
     def forward(
         self,
         attn_metadata: AttentionMetadata,
         input_ids: torch.IntTensor = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         return_context_logits: bool = False,
+        image_token_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:

+        local_attention_mask_data = None
+        global_attention_mask_data = None
+        if image_token_mask is not None:
+            global_attention_mask_data = self.get_flashinfer_attention_mask(
+                image_token_mask=image_token_mask,
+                attn_metadata=attn_metadata,
+                effective_sliding_window=None,
+            )
+            local_attention_mask_data = self.get_flashinfer_attention_mask(
+                image_token_mask=image_token_mask,
+                attn_metadata=attn_metadata,
+                effective_sliding_window=self.config.sliding_window,
+            )
+
         output = self.model(
             input_ids=input_ids,
             attn_metadata=attn_metadata,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
+            local_attention_mask_data=local_attention_mask_data,
+            global_attention_mask_data=global_attention_mask_data,
         )

         return self.logits_processor.forward(
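
The docstrings above fully describe the mask semantics: text tokens attend causally (optionally limited to a sliding window), image tokens additionally attend bidirectionally to every other image token, and the per-request masks are flattened and concatenated for FlashInfer. A self-contained sketch of that rule on a toy sequence, which mirrors but is not the committed implementation:

# Plain-PyTorch sketch of the masking rule documented in get_context_mask().
from typing import Optional

import torch


def toy_context_mask(image_token_mask: torch.Tensor,
                     effective_sliding_window: Optional[int] = None) -> torch.Tensor:
    seq_len = len(image_token_mask)
    pos = torch.arange(seq_len)
    # mask[i, j] is True when query i may attend to key j (here: j <= i).
    causal = pos.unsqueeze(0) <= pos.unsqueeze(1)
    if effective_sliding_window is not None and effective_sliding_window < seq_len:
        # Keep only keys inside the sliding window behind each query.
        causal &= pos.unsqueeze(0) > pos.unsqueeze(1) - effective_sliding_window
    # Image tokens also see every other image token, regardless of position.
    both_image = image_token_mask.unsqueeze(0) & image_token_mask.unsqueeze(1)
    return causal | both_image


# Example: 2 text tokens, a 3-token image blob, then 1 text token.
mask = toy_context_mask(torch.tensor([False, False, True, True, True, False]))
assert bool(mask[2, 4])      # image token 2 attends to the later image token 4
assert not bool(mask[0, 1])  # text token 0 still cannot attend to the future token 1
# As in get_flashinfer_attention_mask(), per-request masks are then flattened and
# concatenated into one boolean vector for FlashInfer's custom_mask argument.
flat_mask = mask.flatten()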

tensorrt_llm/_torch/models/modeling_gemma3vl.py

Lines changed: 16 additions & 3 deletions
@@ -101,7 +101,7 @@ def __call__(
         input_ids = preprocess_outputs[0]["mm_processor_kwargs"]["input_ids"]
         mm_features = self._process(pixel_values)
         multimodal_data = {}
-        multimodal_data["multimodal_embedding"] = mm_features
+        multimodal_data["multimodal_embedding"] = mm_features.squeeze(dim=0)
         return input_ids[0].to(torch.int32).tolist(), {
             "multimodal_data": multimodal_data
         }
@@ -129,6 +129,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], *args,

         self.model_config = model_config
         self.vocab_size = config.text_config.vocab_size
+        self.sliding_window = config.text_config.sliding_window
         self.model_dtype = getattr(config.text_config, "torch_dtype",
                                    torch.float16)
         logger.info(f"[Gemma3Model::__init__]{self.dtype=} {self.model_dtype=}")
@@ -172,12 +173,24 @@ def forward(
             mm_embed
         ) == num_context_requests, "Number of multimodal features (if provided) should be equal to number of context requests"

+        mm_token_ids = torch.tensor([self.image_token_index
+                                     ]).to(input_ids.device)
+        mm_token_mask = None
+        if len(mm_embed) > 0:
+            # Get token type ids. 0 corresponds to text tokens, 1 corresponds to image tokens.
+            mm_token_mask = torch.isin(input_ids, mm_token_ids)
         input_ids, inputs_embeds = fuse_input_embeds(
             embedding_layer=self.llm.model.embed_tokens,
             input_ids=input_ids,
             mm_embeds=mm_embed,
             mm_token_ids=torch.tensor([self.image_token_index
                                        ]).to(input_ids.device))
-        logits = self.llm.forward(attn_metadata, input_ids, position_ids,
-                                  inputs_embeds, return_context_logits)
+        logits = self.llm.forward(
+            attn_metadata=attn_metadata,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            return_context_logits=return_context_logits,
+            image_token_mask=mm_token_mask,
+        )
         return logits
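
The VLM forward derives `mm_token_mask` by checking which `input_ids` equal the image placeholder id via `torch.isin`. A tiny sketch of that step; the `image_token_index` value below is made up for illustration, since the real id comes from the Hugging Face Gemma3 config:

# Illustrative only; image_token_index is a hypothetical placeholder id.
import torch

image_token_index = 262144
input_ids = torch.tensor([2, 9, image_token_index, image_token_index, 17])
mm_token_ids = torch.tensor([image_token_index])
mm_token_mask = torch.isin(input_ids, mm_token_ids)
# mm_token_mask -> tensor([False, False,  True,  True, False]); this is what the VLM
# forward passes to the language model as image_token_mask.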

tensorrt_llm/_torch/modules/attention.py

Lines changed: 8 additions & 6 deletions
@@ -11,7 +11,8 @@

 from ..attention_backend import (AttentionInputType, AttentionMetadata,
                                  TrtllmAttention, TrtllmAttentionMetadata)
-from ..attention_backend.interface import (PositionalEmbeddingParams,
+from ..attention_backend.interface import (AttentionMask,
+                                           PositionalEmbeddingParams,
                                            PredefinedAttentionMask)
 from ..attention_backend.utils import create_attention, get_attention_backend
 from ..distributed import AllReduceParams
@@ -226,12 +227,12 @@ def forward(
         position_ids: Optional[torch.IntTensor],
         hidden_states: Union[torch.Tensor, Fp4QuantizedTensor],
         attn_metadata: AttentionMetadata,
-        attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
-        CAUSAL,
+        attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
         mrope_config: Optional[dict] = None,
         all_reduce_params: Optional[AllReduceParams] = None,
         lora_params: Optional[dict] = None,
         attention_window_size: Optional[int] = None,
+        attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         """
@@ -241,12 +242,12 @@ def forward(
             position_ids (Optional[torch.IntTensor]): The position IDs.
             hidden_states (torch.Tensor): The hidden states.
             attn_metadata (AttentionMetadata): The attention metadata.
-            attention_mask (PredefinedAttentionMask): The attention mask type.
+            attention_mask (AttentionMask): The attention mask type.
             mrope_config (Optional[dict]): The MROPE configuration.
             all_reduce_params (Optional[AllReduceParams]): The all reduce parameters.
             lora_params (Optional[dict]): The LoRA parameters.
             attention_window_size (Optional[int]): The attention window size.
-
+            attention_mask_data (Optional[torch.Tensor]): The attention mask data.
         Returns:
             torch.Tensor: The output tensor.
         """
@@ -284,7 +285,8 @@ def forward(
             out_scale_sf=out_scale_sf,
             attention_mask=attention_mask,
             mrope_config=mrope_config,
-            attention_window_size=attention_window_size)
+            attention_window_size=attention_window_size,
+            attention_mask_data=attention_mask_data)
         hidden_states = attn_output
         attn_output = self.o_proj(attn_output,
                                   all_reduce_params=all_reduce_params,
