 from tensorrt_llm.functional import PositionEmbeddingType, RotaryScalingType
 from tensorrt_llm.mapping import Mapping

-from ..attention_backend import AttentionMetadata
-from ..attention_backend.interface import (PositionalEmbeddingParams,
+from ..attention_backend import AttentionMetadata, FlashInferAttentionMetadata
+from ..attention_backend.interface import (AttentionMask, CustomAttentionMask,
+                                           PositionalEmbeddingParams,
                                            PredefinedAttentionMask, RopeParams)
 from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
@@ -101,14 +102,19 @@ def forward(
         position_ids: Optional[torch.IntTensor],
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
-        attention_mask: PredefinedAttentionMask = PredefinedAttentionMask.
-        CAUSAL,
+        attention_mask: AttentionMask = PredefinedAttentionMask.CAUSAL,
         mrope_config: Optional[dict] = None,
         all_reduce_params: Optional[AllReduceParams] = None,
         lora_params: Optional[dict] = None,
+        attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:

+        if attention_mask_data is not None:
+            assert isinstance(
+                attn_metadata, FlashInferAttentionMetadata
+            ), "Only FlashInfer backend supports custom attention mask currently."
+            assert attention_mask == CustomAttentionMask.CUSTOM
         return super().forward(position_ids=position_ids,
                                hidden_states=hidden_states,
                                attn_metadata=attn_metadata,
@@ -117,6 +123,7 @@ def forward(
                                all_reduce_params=all_reduce_params,
                                lora_params=lora_params,
                                attention_window_size=self.attention_window_size,
+                               attention_mask_data=attention_mask_data,
                                **kwargs)

     def apply_qk_norm(self, q, k):
@@ -214,6 +221,7 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor] = None,
+        attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:

@@ -223,6 +231,9 @@ def forward(
             position_ids=position_ids,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
+            attention_mask=CustomAttentionMask.CUSTOM if attention_mask_data
+            is not None else PredefinedAttentionMask.CAUSAL,
+            attention_mask_data=attention_mask_data,
             **kwargs,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
@@ -267,6 +278,8 @@ def forward(
         input_ids: Optional[torch.IntTensor] = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
+        local_attention_mask_data: Optional[torch.Tensor] = None,
+        global_attention_mask_data: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -280,9 +293,13 @@ def forward(
         hidden_states = inputs_embeds.to(self.dtype)

         for decoder_layer in self.layers:
-            hidden_states = decoder_layer(position_ids=position_ids,
-                                          hidden_states=hidden_states,
-                                          attn_metadata=attn_metadata)
+            hidden_states = decoder_layer(
+                position_ids=position_ids,
+                hidden_states=hidden_states,
+                attn_metadata=attn_metadata,
+                attention_mask_data=local_attention_mask_data
+                if decoder_layer.self_attn.is_sliding else
+                global_attention_mask_data)

         hidden_states = self.norm(hidden_states)
         return hidden_states
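
Note: the loop above hands the sliding-window ("local") mask to layers whose `self_attn.is_sliding` is set and the full ("global") mask to the remaining layers. A minimal stand-alone sketch of the two mask variants, with a made-up sequence length and window size, purely for illustration:

```python
import torch

# Toy illustration only: Gemma3 interleaves full-attention layers with
# sliding-window layers, so two mask variants are prepared per batch and each
# decoder layer picks one based on `self_attn.is_sliding`.
seq_len, window = 6, 3
idx = torch.arange(seq_len)

global_mask = idx.unsqueeze(0) <= idx.unsqueeze(1)  # plain causal mask (row = query, col = key)
local_mask = global_mask & (idx.unsqueeze(0) > idx.unsqueeze(1) - window)  # causal, limited to the window

print(global_mask.int())
print(local_mask.int())
```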
@@ -301,21 +318,131 @@ def __init__(
             hidden_size=model_config.pretrained_config.hidden_size,
             vocab_size=model_config.pretrained_config.vocab_size)

+    def get_context_mask(
+        self,
+        image_token_mask: torch.BoolTensor,
+        effective_sliding_window: Optional[int] = None,
+    ):
+        """
+        Returns an attention mask such that text tokens attend to each other in a causal fashion, while image
+        tokens attend in a causal fashion as well as to all other image tokens in a bidirectional manner.
+        Args:
+            image_token_mask: A boolean tensor of shape (sequence_length,) where True indicates an image token.
+            effective_sliding_window: The effective sliding window size for the attention mask. Default is None, which means no sliding window.
+                For Gemma3, this is the sliding window size from the config (e.g. 512 for the 1B model).
+        Returns:
+            A boolean attention mask of shape (sequence_length, sequence_length).
+        """
+        device = image_token_mask.device
+        sequence_length = len(image_token_mask)
+        if effective_sliding_window is None or effective_sliding_window >= sequence_length:
+            causal_mask = torch.arange(
+                sequence_length, device=device).unsqueeze(0) <= torch.arange(
+                    sequence_length, device=device).unsqueeze(1)
+        else:
+            attention_mask_1 = (torch.arange(sequence_length,
+                                             device=device).unsqueeze(0)
+                                <= torch.arange(sequence_length,
+                                                device=device).unsqueeze(1))
+            attention_mask_2 = (
+                torch.arange(sequence_length, device=device).unsqueeze(0)
+                > torch.arange(sequence_length, device=device).unsqueeze(1) -
+                effective_sliding_window)
+            causal_mask = attention_mask_1 & attention_mask_2
+
+        # Apply a bidirectional mask for image tokens.
+        token_type_ids = torch.zeros(sequence_length,
+                                     dtype=torch.int32,
+                                     device=device)
+        # 1 for image tokens, 0 for text tokens.
+        token_type_ids[image_token_mask] = 1
+        token_type_mask = token_type_ids.unsqueeze(
+            0) == token_type_ids.unsqueeze(1)
+        # If text token, do not change anything.
+        token_type_mask[token_type_ids == 0] = False
+        causal_mask = causal_mask.masked_fill(token_type_mask, True)
+        return causal_mask
+
+    # ASSUMPTIONS:
+    # 1) Chunked prefill is disabled to avoid chunking image tokens as they need bidirectional attention.
+    # 2) KV cache reuse is disabled to avoid partially matched image tokens (the entire image must be reused to get things correct).
+    def get_flashinfer_attention_mask(
+            self,
+            image_token_mask: torch.BoolTensor,
+            attn_metadata: AttentionMetadata,
+            effective_sliding_window: Optional[int] = None) -> torch.Tensor:
+        """
+        This is specifically needed for context-phase requests. Currently, we don't create a custom mask for generation requests because the FlashInfer backend
+        doesn't use it anyway and there's nothing special we need to do for generation requests.
+        - This function will only be called for a batch when there's at least one context request in the batch with image tokens.
+        - In the context phase, each sample's input_ids may have a mix of image tokens and text tokens, where the tokens corresponding to an image
+          appear as a contiguous blob. Example: torch.IntTensor([2, 3, 4, 5, img_idx, img_idx, img_idx, ..., img_idx, 100])
+        - While the text tokens attend to other tokens in a causal fashion, image tokens attend to others in a causal fashion as well as
+          to other image tokens in a bidirectional manner. Hence the need for custom masking.
+        Args:
+            image_token_mask: A boolean tensor of shape (len(input_ids),) where True indicates an image token. This corresponds to the concatenated
+                list of tokens for all samples in the batch.
+            attn_metadata: The attention metadata for the batch.
+            effective_sliding_window: The effective sliding window size for the attention mask. Default is None, which means no sliding window.
+                For Gemma3, this is the sliding window size from the config (e.g. 512 for the 1B model).
+        Returns:
+            A flattened boolean mask of shape (sum(q_len[i] * k_len[i] for i in range(batch_size)),).
+        """
+
+        assert isinstance(
+            attn_metadata, FlashInferAttentionMetadata
+        ), "Only FlashInfer backend supports custom mask currently."
+        num_contexts = attn_metadata.num_contexts
+        assert num_contexts > 0, "There should be at least one context request in the batch for custom mask."
+
+        qo_indptr = attn_metadata.qo_indptr[:num_contexts + 1]
+        cached_token_lens = attn_metadata.cached_token_lens[:num_contexts]
+        assert (cached_token_lens == 0).all(
+        ), "cached_token_lens should be 0 for context requests since chunked prefill and kv cache reuse must be disabled."
+
+        # Create masks for context requests.
+        context_mask_list = []
+        for i in range(num_contexts):
+            mask_i = self.get_context_mask(
+                image_token_mask=image_token_mask[qo_indptr[i]:qo_indptr[i + 1]],
+                effective_sliding_window=effective_sliding_window,
+            )
+            context_mask_list.append(mask_i.flatten())
+        return torch.cat(context_mask_list, dim=0).contiguous()
+
     def forward(
         self,
         attn_metadata: AttentionMetadata,
         input_ids: torch.IntTensor = None,
         position_ids: Optional[torch.IntTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         return_context_logits: bool = False,
+        image_token_mask: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:

+        local_attention_mask_data = None
+        global_attention_mask_data = None
+        if image_token_mask is not None:
+            global_attention_mask_data = self.get_flashinfer_attention_mask(
+                image_token_mask=image_token_mask,
+                attn_metadata=attn_metadata,
+                effective_sliding_window=None,
+            )
+            local_attention_mask_data = self.get_flashinfer_attention_mask(
+                image_token_mask=image_token_mask,
+                attn_metadata=attn_metadata,
+                effective_sliding_window=self.config.sliding_window,
+            )
+
         output = self.model(
             input_ids=input_ids,
             attn_metadata=attn_metadata,
             position_ids=position_ids,
             inputs_embeds=inputs_embeds,
+            local_attention_mask_data=local_attention_mask_data,
+            global_attention_mask_data=global_attention_mask_data,
         )

         return self.logits_processor.forward(
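
For a concrete picture of what `get_context_mask` and `get_flashinfer_attention_mask` produce, here is a minimal, self-contained sketch on toy data (the request contents, window size, and helper name below are made up for illustration; it does not touch `attn_metadata` or the FlashInfer wrappers):

```python
from typing import Optional

import torch


def toy_context_mask(image_token_mask: torch.Tensor,
                     window: Optional[int] = None) -> torch.Tensor:
    """Causal mask (optionally sliding-window) plus bidirectional attention among image tokens."""
    n = image_token_mask.numel()
    idx = torch.arange(n)
    mask = idx.unsqueeze(0) <= idx.unsqueeze(1)  # row = query, col = key
    if window is not None and window < n:
        mask &= idx.unsqueeze(0) > idx.unsqueeze(1) - window
    # Image tokens additionally attend to every other image token.
    img = image_token_mask
    return mask | (img.unsqueeze(1) & img.unsqueeze(0))


# Two toy context requests: text / 3-token image blob / text, and text / 2-token image blob.
requests = [
    torch.tensor([False, True, True, True, False]),
    torch.tensor([False, False, True, True]),
]
# Per-request masks are flattened and concatenated, matching the
# sum(q_len[i] * k_len[i]) layout described in get_flashinfer_attention_mask.
flat = torch.cat([toy_context_mask(m, window=2).flatten() for m in requests])
print(flat.shape)  # torch.Size([25 + 16]) -> torch.Size([41])
```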