Skip to content

Commit ad81dff

Browse files
committed
add generic create_attention_mask custom op
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
1 parent 85662cf commit ad81dff

File tree

9 files changed

+445
-311
lines changed

9 files changed

+445
-311
lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_gemma3_mask.py

Lines changed: 0 additions & 189 deletions
This file was deleted.
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
"""Generic VLM attention mask generation ops.
2+
3+
This module provides custom ops for generating attention masks for VLM models.
4+
The ops are model-agnostic - they dispatch to model-specific generators
5+
registered in VlmMaskGeneratorRegistry.
6+
7+
Key features:
8+
- Generic dispatcher op that routes to model-specific mask generators
9+
- Model-specific mask creation logic is isolated in registered generators
10+
- sliding_window parameter (backend may use native sliding window instead)
11+
12+
The masks are generated in the flattened format expected by FlashInfer's
13+
BatchPrefillWithPagedKVCacheWrapper: a 1D boolean tensor of shape
14+
[sum(q_len[i] * k_len[i]) for context sequences].
15+
"""
16+
17+
from typing import List
18+
19+
import torch
20+
from torch import Tensor
21+
22+
from .vlm_mask_registry import VlmMaskGeneratorRegistry
23+
24+
# =============================================================================
25+
# Generic dispatcher op - routes to model-specific generators
26+
# =============================================================================
27+
28+
29+
@torch.library.custom_op("auto_deploy::create_attention_mask", mutates_args=())
def create_attention_mask(
    token_info: Tensor,
    qo_indptr: Tensor,
    seq_len: Tensor,
    sliding_window: int,
    model_type: str,
) -> Tensor:
    """Generate an attention mask for VLM models.

    Generic VLM mask dispatcher: looks up the model-specific mask generator
    registered under ``model_type`` in :class:`VlmMaskGeneratorRegistry` and
    delegates mask construction to it.

    Args:
        token_info: Model-specific token information tensor [total_tokens].
            Interpretation depends on the model (e.g., token_type_ids for Gemma3
            where 1 = image token).
        qo_indptr: Tensor [num_contexts + 1] from attention metadata.
            Defines the boundaries of each context sequence in the flattened stream.
        seq_len: Tensor [num_seqs] with sequence lengths.
            Used to distinguish context (seq_len > 1) from generation (seq_len == 1).
        sliding_window: Sliding window size. -1 or 0 = no sliding window.
            Backend may ignore this if it handles sliding window natively.
        model_type: Model type string for registry lookup (e.g., "gemma3").

    Returns:
        custom_mask: Flattened bool mask for attention layers.
    """
    mask_fn = VlmMaskGeneratorRegistry.get(model_type)
    if mask_fn is not None:
        return mask_fn(token_info, qo_indptr, seq_len, sliding_window)
    # No generator registered for this model type: an empty mask signals
    # "no custom masking" to the attention backend.
    return torch.empty(0, dtype=torch.bool, device=token_info.device)
64+
65+
66+
@create_attention_mask.register_fake
def _create_attention_mask_fake(
    token_info: Tensor,
    qo_indptr: Tensor,
    seq_len: Tensor,
    sliding_window: int,
    model_type: str,
) -> Tensor:
    """Fake (meta) implementation of ``create_attention_mask`` for tracing.

    The real mask size is data-dependent (it depends on the runtime values of
    ``seq_len`` and ``qo_indptr``), so it cannot be computed from fake tensors:
    under FakeTensor tracing, ``(seq_len > 1).sum()`` yields a fake tensor and
    branching on it (or passing it to ``torch.empty``) raises a data-dependent
    guard error. Instead we declare the output length as an unbacked dynamic
    size via ``torch.library.get_ctx().new_dynamic_size()``, which is the
    documented pattern for custom ops with data-dependent output shapes.

    Returns:
        A 1D bool tensor whose length is a fresh dynamic size symbol.
    """
    ctx = torch.library.get_ctx()
    # Unbacked symbolic size standing in for sum(q_len[i] * k_len[i]).
    mask_numel = ctx.new_dynamic_size()
    return torch.empty((mask_numel,), dtype=torch.bool, device=token_info.device)
94+
95+
96+
# =============================================================================
97+
# Gemma3-specific mask generation
98+
# =============================================================================
99+
100+
101+
def _get_context_mask_with_bidir_images(image_token_mask: Tensor) -> Tensor:
102+
"""Generate attention mask for a single context sequence (Gemma3 style).
103+
104+
Args:
105+
image_token_mask: Boolean tensor of shape [seq_len] where True = image token.
106+
107+
Returns:
108+
Boolean mask of shape [seq_len, seq_len] where True = attention allowed.
109+
The mask is causal (lower triangular) with bidirectional override for
110+
image-image token pairs.
111+
"""
112+
seq_len = image_token_mask.shape[0]
113+
device = image_token_mask.device
114+
115+
# Base causal mask: lower triangular (query can attend to key if key_pos <= query_pos)
116+
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
117+
118+
# Image-image bidirectional: if both query and key are image tokens, allow attention
119+
is_image_q = image_token_mask.unsqueeze(1) # [seq_len, 1]
120+
is_image_k = image_token_mask.unsqueeze(0) # [1, seq_len]
121+
bidir_image = is_image_q & is_image_k # [seq_len, seq_len]
122+
123+
# Override causal restriction for image-image pairs
124+
mask = mask | bidir_image
125+
126+
return mask
127+
128+
129+
def _gemma3_mask_impl(
    token_info: Tensor,
    qo_indptr: Tensor,
    seq_len: Tensor,
    sliding_window: int,
) -> Tensor:
    """Gemma3-specific mask generation implementation.

    Builds a causal mask with bidirectional attention among image tokens for
    every context sequence (seq_len > 1), then flattens and concatenates the
    per-sequence masks. ``sliding_window`` is accepted but unused here —
    FlashInfer applies it natively via ``window_left``.
    """
    n_ctx = int((seq_len > 1).sum())
    if n_ctx == 0:
        # Generation-only batch: no context sequences, so no custom mask.
        return torch.empty(0, dtype=torch.bool, device=token_info.device)

    # Boundaries of each context sequence within the flattened token stream.
    bounds = qo_indptr[: n_ctx + 1].tolist()

    flat_masks: List[Tensor] = [
        _get_context_mask_with_bidir_images(token_info[lo:hi]).flatten()
        for lo, hi in zip(bounds[:-1], bounds[1:])
    ]
    return torch.cat(flat_masks).contiguous()
164+
165+
166+
@VlmMaskGeneratorRegistry.register("gemma3")
def generate_gemma3_vlm_mask(
    image_token_mask: Tensor,
    qo_indptr: Tensor,
    seq_len: Tensor,
    sliding_window: int,
) -> Tensor:
    """Generate the attention mask for Gemma3 VLM.

    Registered under the "gemma3" key so the generic
    ``auto_deploy::create_attention_mask`` op can dispatch here. For Gemma3
    the token info is a boolean mask (True = image token): image tokens attend
    to each other bidirectionally, while text tokens remain strictly causal.
    ``sliding_window`` is passed through but the FlashInfer backend handles it
    natively (via ``window_left``), so it is ignored in the mask itself.

    Args:
        image_token_mask: Boolean tensor [total_tokens] where True = image token.
        qo_indptr: Tensor [num_contexts + 1] from attention metadata.
        seq_len: Tensor [num_seqs] with sequence lengths.
        sliding_window: Sliding window size (ignored by FlashInfer backend).

    Returns:
        custom_mask: Flattened bool mask for attention layers.
    """
    return _gemma3_mask_impl(image_token_mask, qo_indptr, seq_len, sliding_window)

0 commit comments

Comments
 (0)