Skip to content

Commit df120f0

Browse files
committed
generic attn mask op
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
1 parent ad81dff commit df120f0

File tree

10 files changed

+84
-33
lines changed

10 files changed

+84
-33
lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -224,6 +224,7 @@ def flashinfer_mha_with_cache(
224224
q: torch.Tensor,
225225
k: torch.Tensor,
226226
v: torch.Tensor,
227+
custom_mask: Optional[torch.Tensor],
227228
# STANDARD METADATA
228229
batch_info: torch.Tensor,
229230
cu_seqlen: torch.Tensor,
@@ -244,10 +245,6 @@ def flashinfer_mha_with_cache(
244245
v_scale: float,
245246
window_left: int, # FlashInfer inclusive sliding window (use -1 to disable)
246247
logits_soft_cap: float, # FlashInfer logits softcap (use 0.0 to disable)
247-
# VLM CUSTOM MASK (optional, for Gemma3 etc.)
248-
# Contains bidirectional attention for image tokens. Sliding window is
249-
# handled separately by the window_left parameter.
250-
custom_mask: Optional[torch.Tensor],
251248
) -> torch.Tensor:
252249
# reshape to standard [b*s, n_heads, head_dim] layout
253250
head_dim = k_cache.shape[-1]

tensorrt_llm/_torch/auto_deploy/custom_ops/vlm_mask_ops.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -57,8 +57,10 @@ def create_attention_mask(
5757
# Dispatch to model-specific generator
5858
generator = VlmMaskGeneratorRegistry.get(model_type)
5959
if generator is None:
60-
# No model-specific generator - return empty mask (no custom masking)
61-
return torch.empty(0, dtype=torch.bool, device=token_info.device)
60+
raise ValueError(
61+
f"No model-specific generator found for model type: {model_type}. \
62+
Registered model types: {VlmMaskGeneratorRegistry.registered_model_types()}."
63+
)
6264

6365
return generator(token_info, qo_indptr, seq_len, sliding_window)
6466

@@ -163,7 +165,7 @@ def _gemma3_mask_impl(
163165
return torch.cat(masks).contiguous()
164166

165167

166-
@VlmMaskGeneratorRegistry.register("gemma3")
168+
@VlmMaskGeneratorRegistry.register("gemma3_text")
167169
def generate_gemma3_vlm_mask(
168170
image_token_mask: Tensor,
169171
qo_indptr: Tensor,

tensorrt_llm/_torch/auto_deploy/custom_ops/vlm_mask_registry.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -128,3 +128,8 @@ def has(cls, model_type: str) -> bool:
128128
True if a generator is registered, False otherwise.
129129
"""
130130
return model_type in cls._registry
131+
132+
@classmethod
133+
def registered_model_types(cls) -> list:
134+
"""Return a list of all registered model types."""
135+
return list(cls._registry.keys())

tensorrt_llm/_torch/auto_deploy/export/export.py

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -227,6 +227,7 @@ def run_forward_for_capture(
227227
*,
228228
patch_configs: Optional[Dict[str, Union[dict, Any]]] = None,
229229
patch_list: Optional[List[str]] = None,
230+
post_export_callback: Optional[Callable[[nn.Module], None]] = None,
230231
) -> nn.Module:
231232
"""A wrapper to run the provided closure over the model on the meta device with patches.
232233
@@ -244,6 +245,8 @@ def run_forward_for_capture(
244245
patch_configs: Optional patch configurations. If None, all registered patches
245246
will be applied with default settings.
246247
patch_list: Optional list of patch names to apply with default settings.
248+
post_export_callback: Optional callback called after capture but before patches are reverted.
249+
Receives the captured module as argument.
247250
"""
248251
# run capture with patches and lifted to meta
249252
with apply_export_patches(patch_configs, patch_list), lift_to_meta(model) as state_dict:
@@ -259,6 +262,10 @@ def run_forward_for_capture(
259262
else:
260263
mod_after_capture = capture_fn(model, args, kwargs)
261264

265+
# Call post_export_callback while patches are still active
266+
if post_export_callback is not None:
267+
post_export_callback(mod_after_capture)
268+
262269
# load state_dict into egm
263270
# NOTE: export might have removed unused params/buffers (hence we allow unexpected keys)
264271
if mod_after_capture is not model:
@@ -283,6 +290,7 @@ def torch_export_to_gm(
283290
strict: bool = False,
284291
patch_configs: Optional[Dict[str, Union[dict, Any]]] = None,
285292
patch_list: Optional[List[str]] = None,
293+
post_export_callback: Optional[Callable[[nn.Module], None]] = None,
286294
) -> fx.GraphModule:
287295
"""torch's export with wrapping into GraphModule + useful additions to the resulting module.
288296
@@ -306,6 +314,8 @@ def torch_export_to_gm(
306314
will be applied with default settings.
307315
patch_list: Optional list of patch names to apply with default settings.
308316
Cannot be used together with patch_configs.
317+
post_export_callback: Optional callback called after export but before patches are reverted.
318+
Receives the exported GraphModule as argument.
309319
"""
310320

311321
def _capture_fn(model, args, kwargs):
@@ -316,7 +326,14 @@ def _capture_fn(model, args, kwargs):
316326

317327
# run capture with export
318328
egm = run_forward_for_capture(
319-
model, _capture_fn, args, kwargs, clone, patch_list=patch_list, patch_configs=patch_configs
329+
model,
330+
_capture_fn,
331+
args,
332+
kwargs,
333+
clone,
334+
patch_list=patch_list,
335+
patch_configs=patch_configs,
336+
post_export_callback=post_export_callback,
320337
)
321338

322339
# Export strips away all methods not traced during forward. The model could have

tensorrt_llm/_torch/auto_deploy/models/factory.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,9 +67,24 @@ def _init_dynamic_shape_lookup(self) -> Dict[str, DynamicShape]:
6767
"""Initialize the lookup for the dynamic shapes of keyword arguments."""
6868
raise NotImplementedError("Subclasses must implement this method.")
6969

70+
def post_export(self, sub_mod: nn.Module, sub_gm: GraphModule):
71+
"""Called after export but BEFORE patches are reverted.
72+
73+
Args:
74+
sub_mod: The submodule from which the graph was captured+exported.
75+
sub_gm: The graph module that was exported.
76+
77+
This method is called while export patches are still active, allowing access to
78+
patch-set metadata on the module (e.g., _vlm_input_names). Override this method
79+
to set metadata on the GraphModule that depends on patch state.
80+
81+
Default implementation does nothing.
82+
"""
83+
pass
84+
7085
@abstractmethod
7186
def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule):
72-
"""Post-process the subgraph module.
87+
"""Post-process the subgraph module AFTER patches are reverted.
7388
7489
Args:
7590
sub_mod: The submodule from which the graph was captured+exported.

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -560,6 +560,14 @@ def __call__(self, module, state_dict, *args, **kwargs) -> None:
560560
class TextModelExportInfo(SubModuleExportInfo):
561561
"""An export configuration for the text model portion of a VLM."""
562562

563+
def post_export(self, sub_mod: nn.Module, sub_gm: GraphModule):
564+
"""Called after export but BEFORE patches are reverted.
565+
566+
Sets VLM metadata on the GraphModule while patches are still active,
567+
so we can read _vlm_input_names from the module class.
568+
"""
569+
self._set_vlm_metadata(sub_mod, sub_gm)
570+
563571
def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule):
564572
"""Post-process the subgraph module and make sure the embedding remains available."""
565573
# make sure get_input_embeddings function is available in the graph module
@@ -588,10 +596,6 @@ def post_process(self, sub_mod: nn.Module, sub_gm: GraphModule):
588596
torch._assert, args=(n_embed_tokens, "Avoid embedding getting deleted from graph.")
589597
)
590598

591-
# Set VLM metadata on the GraphModule for runtime use.
592-
# This is read by ADExecutor to determine which inputs to inject from multimodal_data.
593-
self._set_vlm_metadata(sub_mod, sub_gm)
594-
595599
def _set_vlm_metadata(self, sub_mod: nn.Module, sub_gm: GraphModule):
596600
"""Set VLM-related metadata on the GraphModule.
597601

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -742,7 +742,6 @@ def _find_vlm_graphmodule(mod: torch.nn.Module) -> Optional[torch.nn.Module]:
742742
# Store on engine for external access
743743
engine._vlm_inputs = vlm_inputs
744744
engine._vlm_model_type = vlm_model_type
745-
746745
# Detect if the model is a VLM that expects custom masks
747746
# This is relevant for FlashInfer backend with VLM models
748747
engine._expects_vlm_custom_masks = (

tensorrt_llm/_torch/auto_deploy/transform/library/export_to_gm.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -192,6 +192,11 @@ def _is_child(child: str, parent: str) -> bool:
192192
# torch.export can get confused by keyword arguments that are not explicitly defined in
193193
# the signature but are captured through generic **kwargs. By overwriting the signature,
194194
# we ensure each argument is explicitly defined in the signature.
195+
196+
# Create callback to call post_export while patches are still active
197+
def _post_export_cb(exported_gm):
198+
e_info.post_export(sub_mod, exported_gm)
199+
195200
with set_exact_signature(sub_mod, captured_kwargs):
196201
sub_gm = torch_export_to_gm(
197202
sub_mod,
@@ -201,6 +206,7 @@ def _is_child(child: str, parent: str) -> bool:
201206
clone=self.config.clone_state_dict,
202207
strict=self.config.strict,
203208
patch_list=self.config.patch_list,
209+
post_export_callback=_post_export_cb,
204210
)
205211

206212
# Ensure runtime calls from HF into this exported GraphModule do not fail due to

tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py

Lines changed: 9 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -184,18 +184,18 @@ def _maybe_init_vlm_custom_mask(
184184

185185
self._vlm_custom_mask_node = custom_mask_node
186186

187-
def _maybe_append_flashinfer_vlm_custom_mask(self, cached_attn_op, args: Tuple) -> Tuple:
188-
"""Append FlashInfer VLM custom mask arg (or None) to `args`.
187+
def _get_vlm_custom_mask_node(self, cached_attn_op) -> Optional[Node]:
188+
"""Get the VLM custom mask node for FlashInfer attention ops.
189189
190190
All layers receive the same mask - it provides bidirectional attention
191191
for image tokens. Sliding window is handled separately by window_left.
192+
193+
Returns:
194+
The custom mask node, or None if not a FlashInfer op or no VLM.
192195
"""
193196
if not self._is_flashinfer_cached_attn_op(cached_attn_op):
194-
return args
195-
196-
# Append the custom mask node (or None if no VLM)
197-
custom_mask = getattr(self, "_vlm_custom_mask_node", None)
198-
return (*args, custom_mask)
197+
return None
198+
return getattr(self, "_vlm_custom_mask_node", None)
199199

200200
def _process_metadata_extra(
201201
self, gm: GraphModule, cm: CachedSequenceInterface, any_source_attn_node: Node
@@ -239,18 +239,16 @@ def _insert_cached_attn_node(
239239
"""Insert a cached attention node into the graph."""
240240
with gm.graph.inserting_before(attn_node):
241241
cached_attn_op = self.attn_descriptor.get_cached_attention_op()
242+
custom_mask_node = self._get_vlm_custom_mask_node(cached_attn_op)
242243
args = (
243244
*qkv_nodes,
245+
custom_mask_node,
244246
*meta_nodes_std,
245247
*meta_nodes_extra,
246248
*cache_nodes,
247249
*buffer_nodes,
248250
*constants,
249251
)
250-
# FlashInfer cached attention op optionally accepts a custom mask arg for VLM.
251-
# The mask provides bidirectional attention for image tokens. Sliding window
252-
# is handled separately by FlashInfer's window_left parameter.
253-
args = self._maybe_append_flashinfer_vlm_custom_mask(cached_attn_op, args)
254252
cached_attn_node = gm.graph.call_function(
255253
cached_attn_op,
256254
args=args,

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_flashinfer_attention_op.py

Lines changed: 16 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,8 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
9898
q,
9999
k,
100100
v,
101+
# VLM CUSTOM MASK
102+
None, # custom_mask
101103
# STANDARD METADATA
102104
batch_info,
103105
qo_indptr,
@@ -118,7 +120,6 @@ def test_flashinfer_attention_op_context(seq_length, n_heads, batch_size, dtype,
118120
1.0,
119121
-1, # window_left (disabled)
120122
0.0, # logits_soft_cap (disabled)
121-
None, # custom_mask
122123
)
123124

124125
# Use torch backend as clean reference
@@ -236,6 +237,8 @@ def test_flashinfer_attention_op_decode(
236237
q,
237238
k,
238239
v,
240+
# VLM CUSTOM MASK
241+
None, # custom_mask
239242
# STANDARD METADATA
240243
batch_info,
241244
qo_indptr,
@@ -256,7 +259,6 @@ def test_flashinfer_attention_op_decode(
256259
1.0,
257260
-1, # window_left (disabled)
258261
0.0, # logits_soft_cap (disabled)
259-
None, # custom_mask
260262
)
261263

262264
assert torch.allclose(
@@ -364,6 +366,8 @@ def test_flashinfer_attention_context_and_generate(
364366
q_1,
365367
k_1,
366368
v_1,
369+
# VLM CUSTOM MASK
370+
None, # custom_mask
367371
# STANDARD METADATA
368372
batch_info,
369373
qo_indptr,
@@ -384,7 +388,6 @@ def test_flashinfer_attention_context_and_generate(
384388
1.0,
385389
-1, # window_left (disabled)
386390
0.0, # logits_soft_cap (disabled)
387-
None, # custom_mask
388391
)
389392

390393
# Generate reference outputs
@@ -446,6 +449,8 @@ def test_flashinfer_attention_context_and_generate(
446449
q_3,
447450
k_3,
448451
v_3,
452+
# VLM CUSTOM MASK
453+
None, # custom_mask
449454
# STANDARD METADATA
450455
batch_info,
451456
qo_indptr,
@@ -466,7 +471,6 @@ def test_flashinfer_attention_context_and_generate(
466471
1.0,
467472
-1, # window_left (disabled)
468473
0.0, # logits_soft_cap (disabled)
469-
None, # custom_mask
470474
)
471475

472476
# Generate reference outputs
@@ -564,6 +568,8 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
564568
q,
565569
k,
566570
v,
571+
# VLM CUSTOM MASK
572+
None, # custom_mask
567573
# STANDARD METADATA
568574
batch_info,
569575
qo_indptr,
@@ -584,7 +590,6 @@ def test_flashinfer_attention_op_context_input_pos(seq, batch_size, n_heads, dty
584590
1.0,
585591
-1, # window_left (disabled)
586592
0.0, # logits_soft_cap (disabled)
587-
None, # custom_mask
588593
)
589594

590595
# Generate ref
@@ -720,6 +725,8 @@ def test_flashinfer_attention_with_fp8_cache(
720725
q,
721726
k,
722727
v,
728+
# VLM CUSTOM MASK
729+
None, # custom_mask
723730
# STANDARD METADATA
724731
batch_info,
725732
qo_indptr,
@@ -740,7 +747,6 @@ def test_flashinfer_attention_with_fp8_cache(
740747
V_SCALE,
741748
-1, # window_left (disabled)
742749
0.0, # logits_soft_cap (disabled)
743-
None, # custom_mask
744750
)
745751

746752
y = flashinfer_output.view(BATCH_SIZE, SEQ_LEN, N_HEADS, D_HEAD)
@@ -824,6 +830,8 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
824830
q,
825831
k,
826832
v,
833+
# VLM CUSTOM MASK
834+
None, # custom_mask
827835
# STANDARD METADATA
828836
batch_info,
829837
qo_indptr,
@@ -844,7 +852,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
844852
1.0,
845853
-1, # window_left (disabled)
846854
0.0, # logits_soft_cap (disabled)
847-
None, # custom_mask
848855
)
849856

850857
# Compute reference
@@ -914,6 +921,8 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
914921
q_gen,
915922
k_gen,
916923
v_gen,
924+
# VLM CUSTOM MASK
925+
None, # custom_mask
917926
# STANDARD METADATA
918927
batch_info,
919928
qo_indptr2,
@@ -934,7 +943,6 @@ def test_flashinfer_attention_with_paged_kvcache(seq_lengths, n_heads, dtype, de
934943
1.0,
935944
-1, # window_left (disabled)
936945
0.0, # logits_soft_cap (disabled)
937-
None, # custom_mask
938946
)
939947

940948
# Compute reference

0 commit comments

Comments
 (0)