
Commit 9636ea6

fix: resolving SDPA lowering position issue (output mismatch)
Parent: 3a58d2b

4 files changed: +77 additions, −26 deletions

tools/llm/README.md

Lines changed: 4 additions & 1 deletion
@@ -74,11 +74,14 @@ This codebase can be extended to
 
 ## Limitations
 - We do not currently support sliding window attention (used in Gemma3 and Qwen 3 models) yet.
+- **Flash Attention Limitation**: Some models (e.g., Eagle2-2B) internally use flash attention operations (`torch.ops.flash_attn._flash_attn_forward.default`) which require the `flash-attn` package to be installed. Without flash-attn, these models will fail to load or run properly.
 
 ## Requirements
 
 - Torch-TensorRT 2.8.0
 - Transformers v4.52.3
 - For VLM models (run_vlm.py):
   - `pip install qwen-vl-utils` (for Qwen2.5-VL-3B-Instruct model)
-  - `pip install flash-attn --no-build-isolation -v` (for Eagle2-2B model)
+  - **Flash Attention**: For models using flash attention operations (e.g., Eagle2-2B), install one of the following:
+    - **Fast installation (recommended)**: `pip install flash-attn==2.8.1` (pre-built wheel, should work)
+    - **Source build (slow)**: `pip install flash-attn --no-build-isolation -v` (fallback if pre-built wheels fail)
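
The added limitation suggests a simple pre-flight check before loading such a model. A minimal sketch follows; the guard and its error text are illustrative, not part of this commit:

import importlib.util

# Hypothetical guard for models that call torch.ops.flash_attn._flash_attn_forward (e.g. Eagle2-2B).
if importlib.util.find_spec("flash_attn") is None:
    raise RuntimeError(
        "This model requires the 'flash-attn' package: try `pip install flash-attn==2.8.1` "
        "or build from source with `pip install flash-attn --no-build-isolation -v`."
    )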

tools/llm/run_vlm.py

Lines changed: 7 additions & 12 deletions
@@ -37,15 +37,10 @@
 import requests
 import torch
 import torch_tensorrt
-
-# we "monkey-patch" the global attention function map for Qwen2.
-# This ensures that any part of the code (including torch.export) requesting
-# "flash_attention_2" will receive the "sdpa" implementation instead.
-# This patch is global for the script's execution context.
-import transformers.models.qwen2.modeling_qwen2 as mq
 from PIL import Image
-from torchtrt_ext import register_sdpa
-from transformers import AutoConfig, AutoModel, AutoProcessor
+from transformers import AutoModel, AutoProcessor
+from transformers.models.qwen2 import modeling_qwen2 as mq
+from transformers.models.siglip import modeling_siglip as ms
 from utils import (
     export_llm,
     generate_mm,
@@ -59,8 +54,7 @@
 # Eagle2's language model (Qwen2) implicitly defaults to "flash_attention_2"
 # due to settings in its remote code and config.json. This prevents direct
 # compilation with SDPA. To work around this without modifying the library,
-
-
+ms.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = ms.ALL_ATTENTION_FUNCTIONS["sdpa"]
 mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = mq.ALL_ATTENTION_FUNCTIONS["sdpa"]
 # --- END WORKAROUND ---
 
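
Taken on its own, the workaround is just a re-mapping of the transformers attention-function registry. A minimal standalone sketch of the ordering it relies on; the sanity check is illustrative, not part of the commit:

# Remap the registry *before* the checkpoint is instantiated or exported, so every
# "flash_attention_2" lookup resolves to the SDPA kernel. Mirrors the patch in run_vlm.py
# for the two modules Eagle2-2B uses (Qwen2 language model, SigLIP vision tower).
from transformers.models.qwen2 import modeling_qwen2 as mq
from transformers.models.siglip import modeling_siglip as ms

for module in (mq, ms):
    module.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = module.ALL_ATTENTION_FUNCTIONS["sdpa"]

# Sanity check: both names now resolve to the same callable.
assert mq.ALL_ATTENTION_FUNCTIONS["flash_attention_2"] is mq.ALL_ATTENTION_FUNCTIONS["sdpa"]
# Any AutoModel.from_pretrained(...) call for the Eagle2 checkpoint should happen only after this point.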

@@ -259,8 +253,6 @@ def _compile_lm(
     seq_len = torch.export.Dim("seq", min=1, max=max_seq_len)
     position_ids = torch.arange(input_embeds.shape[1]).unsqueeze(0).to(device)
 
-    dyn_shapes = {"inputs_embeds": {1: seq_len}, "position_ids": {1: seq_len}}
-
     use_fp32_acc = False
     use_explicit_typing = False
     if args.precision == "FP16":

@@ -594,6 +586,9 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer):
 # -------------------------------------------------------------------------#
 # Register static cache lowering passes if requested
 # Cache is not applied to vision model.
+print("--- Registering SDPA lowering pass locally for LM compilation ---")
+from torchtrt_ext import register_sdpa
+
 if args.cache == "static_v1":
     import static_cache_v1  # noqa: F401
 elif args.cache not in ("", None):
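
The fix defers `from torchtrt_ext import register_sdpa` to this point so the SDPA lowering pass is registered only for the language-model compilation, leaving the already-compiled vision model untouched. A toy sketch of that import-side-effect pattern; the registry and pass names below are hypothetical stand-ins, not Torch-TensorRT internals:

# Toy illustration only: LOWERING_PASSES and sdpa_lowering_pass are hypothetical.
LOWERING_PASSES = []

def compile_submodule(name):
    # Stand-in for the real compile call: it applies whatever passes are registered *now*.
    print(f"compiling {name} with passes: {[p.__name__ for p in LOWERING_PASSES]}")

compile_submodule("vision_model")        # SDPA pass not registered yet -> vision outputs unchanged

def sdpa_lowering_pass(gm):              # registered as a side effect of the deferred import
    return gm

LOWERING_PASSES.append(sdpa_lowering_pass)
compile_submodule("language_model")      # the pass now applies, but only to the LM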

tools/llm/static_cache_v1.py

Lines changed: 7 additions & 7 deletions
@@ -39,12 +39,12 @@ def add_kv_as_outputs(gm, kv_cache_for_graph: List[Tuple[torch.Tensor, torch.Ten
     # Get the current output args (typically a tuple)
     current_outputs = output_node.args[0]
 
-    # If the current output is a tuple, extend it with our new outputs
-    if isinstance(current_outputs, tuple):
-        new_outputs = current_outputs + tuple(kv_cache_for_graph)
-    else:
-        # If there's only one output or it's not a tuple, create a new tuple
-        new_outputs = (current_outputs,) + tuple(kv_cache_for_graph)
+    # Ensure the original output is always treated as a tuple to avoid ambiguity
+    if not isinstance(current_outputs, tuple):
+        current_outputs = (current_outputs,)
+
+    # Extend the tuple with our new outputs
+    new_outputs = current_outputs + tuple(kv_cache_for_graph)
 
     gm.graph.output(new_outputs)
     gm.graph.erase_node(output_node)
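
For readers unfamiliar with this kind of graph surgery, here is a self-contained toy version of the same pattern on a traced `torch.fx` module; the toy module and the re-exposed placeholder stand in for the model and its KV-cache tensors:

import torch
from torch import fx

class Toy(torch.nn.Module):
    def forward(self, x):
        return x + 1  # single, non-tuple output

gm = fx.symbolic_trace(Toy())

# Locate the existing output node and normalize its args to a tuple,
# mirroring add_kv_as_outputs above.
output_node = next(n for n in gm.graph.nodes if n.op == "output")
current_outputs = output_node.args[0]
if not isinstance(current_outputs, tuple):
    current_outputs = (current_outputs,)

# Stand-ins for the KV-cache tensors: here we just re-expose the input placeholder.
extra_outputs = tuple(n for n in gm.graph.nodes if n.op == "placeholder")
new_outputs = current_outputs + extra_outputs

gm.graph.output(new_outputs)       # emit the new, widened output node
gm.graph.erase_node(output_node)   # drop the old one
gm.recompile()

print(gm(torch.ones(2)))           # (tensor([2., 2.]), tensor([1., 1.]))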
@@ -98,7 +98,7 @@ def get_static_tensor(tensor: torch.Tensor):
     start_idx_input = _add_graph_input(gm, "start_idx", torch.tensor(0))
     end_idx_input = _add_graph_input(gm, "end_idx", torch.tensor(1))
 
-    # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, is_causal, key_cache1, value_cache1, key_cache2, value_cache2, ..
+    # Get the max sequence length from the first key_cache node. The order of nodes is: input_ids, position_ids, key_cache1, value_cache1, ...
     input_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"]
     input_ids_meta = input_nodes[0].meta["val"]
     seq_len = input_ids_meta.shape[1]
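
The corrected comment documents the placeholder order the pass depends on. A small sketch of inspecting that order and reading `seq_len` from placeholder metadata on an exported toy model; real graphs additionally carry the KV-cache, `start_idx`, and `end_idx` placeholders added later by the pass:

import torch

class TinyLM(torch.nn.Module):
    def forward(self, input_ids, position_ids):
        return input_ids + position_ids

example = (torch.zeros(1, 8, dtype=torch.long), torch.zeros(1, 8, dtype=torch.long))
ep = torch.export.export(TinyLM(), example)

# Placeholders appear in argument order: input_ids first, then position_ids.
placeholders = [n for n in ep.graph.nodes if n.op == "placeholder"]
for node in placeholders:
    print(node.name, tuple(node.meta["val"].shape))

# seq_len is read from the first placeholder's metadata, as in static_cache_v1.py.
seq_len = placeholders[0].meta["val"].shape[1]
print("seq_len =", seq_len)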

tools/llm/utils.py

Lines changed: 59 additions & 6 deletions
@@ -47,7 +47,45 @@ def export_llm(model, inputs, min_seq_len=1, max_seq_len=16):
     return ep
 
 
-def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule, device: str = "cuda:0"):
+def export_llm_no_position_ids(model, inputs, min_seq_len=1, max_seq_len=16):
+    """
+    Exports the LLM model into an ExportedProgram with dynamic shapes.
+    In the case of guard failures due to some PyTorch kernel implements, we also
+    try to re-export the graph by expressing them as runtime assert nodes
+    """
+    with torch.no_grad():
+        # max=1024 has contraint violation error. https://github.com/pytorch/pytorch/issues/125604
+        seq_len = torch.export.Dim("seq_len", min=min_seq_len, max=max_seq_len)
+        try:
+            print("Trying to export the model using torch.export.export()..")
+            # strict=False only enables aotautograd tracing and excludes dynamo.
+            ep = torch.export.export(
+                model,
+                args=(inputs,),
+                dynamic_shapes=({1: seq_len},),
+                strict=False,
+            )
+        except:
+            print(
+                "Trying torch.export._trace._export to trace the graph since torch.export.export() failed"
+            )
+            # This API is used to express the constraint violation guards as asserts in the graph.
+            ep = torch.export._trace._export(
+                model,
+                args=(inputs,),
+                dynamic_shapes=({1: seq_len},),
+                strict=False,
+                allow_complex_guards_as_runtime_asserts=True,
+            )
+
+    return ep
+
+
+def get_zeroed_static_cache_inputs(
+    model: "torch.fx.GraphModule",
+    device: str = "cuda:0",
+    has_position_ids: bool = True,
+):
     """
     Extracts and returns zeroed static KV cache tensors from a torch.fx.GraphModule. This should only be used for static cache_v1 and static cache_v2.
 
@@ -56,15 +94,26 @@ def get_zeroed_static_cache_inputs(model: torch.fx.GraphModule, device: str = "c
 
     Args:
         model (torch.fx.GraphModule): The exported model graph containing KV cache placeholders
+        device (str): Device to create the zeroed tensors on.
+        has_position_ids (bool): Whether position_ids is present as an input. Default: True
 
     Returns:
         tuple: A tuple of zeroed tensors corresponding to the KV cache placeholders in the graph
     """
     # placeholder nodes are expected to be in the following order:
-    # input_ids, kv_cache_key, kv_cache_value, start_idx, end_idx
+    # input_ids, position_ids, kv_cache_key, kv_cache_value, ..., start_idx, end_idx
     placeholder_nodes = [node for node in model.graph.nodes if node.op == "placeholder"]
-    # The first two inputs are input_ids, position_ids. The last two inputs are start_idx, end_idx. In between are the KV cache tensors.
-    kv_cache_inputs = placeholder_nodes[2:-2]
+
+    # By default, assume input_ids and position_ids are present as the first two inputs.
+    # If has_position_ids is False, only input_ids is present.
+    if has_position_ids:
+        kv_start = 2
+    else:
+        kv_start = 1
+    # The last two inputs are start_idx, end_idx.
+    kv_end = -2
+
+    kv_cache_inputs = placeholder_nodes[kv_start:kv_end]
     zeroed_kv_cache_inputs = []
     for input in kv_cache_inputs:
         zeroed_kv_cache_inputs.append(
@@ -458,7 +507,9 @@ def generate_mm_with_static_cache(
     )
 
     # ───────────────────── KV-cache initialization ─────────────────────
-    kv_cache = get_zeroed_static_cache_inputs(model.language_model, device=device)
+    kv_cache = get_zeroed_static_cache_inputs(
+        model.language_model, device=device, has_position_ids=True
+    )
     start_idx = 0
     end_idx = seq_embeds.size(1)
     generated = 0

@@ -710,7 +761,9 @@ def generate_mm_qwen2_5_vl_with_static_cache(
         with_timing=False,
     )
 
-    kv_cache = get_zeroed_static_cache_inputs(model.model, device=device)
+    kv_cache = get_zeroed_static_cache_inputs(
+        model.model, device=device, has_position_ids=True
+    )
     start_idx = 0
     end_idx = seq_embeds.size(1)
     generated = 0
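
To make the effect of the new `has_position_ids` flag concrete, a self-contained toy check of the slicing it controls; `kv_slice` is a stand-in for the indexing inside `get_zeroed_static_cache_inputs`, and the placeholder names mimic the documented order:

# Toy demonstration: the flag only changes where the KV-cache slice starts in the placeholder list.
def kv_slice(placeholder_names, has_position_ids=True):
    # The last two placeholders are always start_idx and end_idx.
    kv_start = 2 if has_position_ids else 1
    return placeholder_names[kv_start:-2]

with_pos = ["input_ids", "position_ids", "key_cache1", "value_cache1", "start_idx", "end_idx"]
without_pos = ["inputs_embeds", "key_cache1", "value_cache1", "start_idx", "end_idx"]

print(kv_slice(with_pos, has_position_ids=True))     # ['key_cache1', 'value_cache1']
print(kv_slice(without_pos, has_position_ids=False)) # ['key_cache1', 'value_cache1']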
