
Commit 6a1a02e

add register for different sdpa
1 parent e59a529 commit 6a1a02e

2 files changed: +147, -145 lines

tools/llm/run_llm.py

Lines changed: 4 additions & 1 deletion
@@ -58,7 +58,10 @@ def get_model(args):
         .eval()
         .cuda()
     )
-    register_sdpa.register_sdpa_pass_with_model_config(model_config=model.config)
+    if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None:
+        register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
+    else:
+        register_sdpa._SDPA_MAPPING["default"](model_config=model.config)
 
     if args.precision == "FP16":
         model = model.to(torch.float16)
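
The lookup above is the whole dispatch mechanism: args.model is the Hugging Face model id passed on the command line, and any id without a dedicated entry falls back to the generic pass. A minimal sketch of the same pattern factored into a helper follows; the helper name is illustrative and not part of this commit, and it assumes register_sdpa is already imported as in run_llm.py.

from typing import Any


def register_sdpa_pass_for_model(model_name: str, model_config: Any) -> None:
    # Hypothetical helper: fall back to the generic pass when no
    # model-specific registration exists in _SDPA_MAPPING.
    register_fn = register_sdpa._SDPA_MAPPING.get(
        model_name, register_sdpa._SDPA_MAPPING["default"]
    )
    register_fn(model_config=model_config)


# e.g. register_sdpa_pass_for_model("google/gemma-3-1b-it", model.config)
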
Lines changed: 143 additions & 144 deletions

@@ -1,7 +1,7 @@
 import copy
 import logging
 import operator
-from typing import Callable, Sequence, Tuple
+from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type
 
 import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -13,7 +13,7 @@
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
-from transformers import Gemma3TextConfig
+from transformers import AutoConfig, Gemma3TextConfig
 
 from .sdpa_converter import *
 
@@ -34,52 +34,130 @@
     torch.ops.aten._scaled_dot_product_flash_attention.default,
 }
 
+from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+    get_lowering_pass_config,
+)
+
 
-def register_sdpa_pass_with_model_config(index: int = 0, model_config=None):
-    """
-    Register the SDPA replacement pass with a specific model configuration.
+def _process_sdpa_node(
+    gm: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    settings: CompilationSettings,
+    sliding_window_size: Optional[int] = None,
+    use_gqa: bool = False,
+) -> torch.fx.GraphModule:
+    """Helper function to process SDPA nodes with common logic."""
+
+    if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default:
+        if len(node.args) == 7:
+            (
+                query,
+                key,
+                value,
+                attn_mask,
+                compute_log_sumexp,
+                dropout_p,
+                is_causal,
+            ) = node.args
+        elif len(node.args) == 5:
+            query, key, value, attn_mask, is_causal = node.args
+            dropout_p = 0.0
+        else:
+            raise ValueError(
+                f"Unexpected number of arguments for {node.target} in the graph"
+            )
+    elif node.target == torch.ops.aten._scaled_dot_product_flash_attention.default:
+        if len(node.args) == 6:
+            (
+                query,
+                key,
+                value,
+                dropout_p,
+                is_causal,
+                return_debug_mask,
+            ) = node.args
+        elif len(node.args) == 5:
+            query, key, value, dropout_p, is_causal = node.args
+        elif len(node.args) == 3:
+            query, key, value = node.args
+            dropout_p = 0.0
+            is_causal = True
+        else:
+            raise ValueError(
+                f"Unexpected number of arguments for {node.target} in the graph"
+            )
+    else:
+        return gm
 
-    Args:
-        model_config: The model configuration object (e.g., from transformers.AutoConfig)
-        index: Position in the lowering pass list (default: 0)
+    # Always set causal to True and generate attn_mask inside the sdpa operator
+    attn_mask = None
+    is_causal = True
+    dropout_p = 0.0
 
-    Example:
-        from transformers import AutoConfig
-        config = AutoConfig.from_pretrained("microsoft/DialoGPT-medium")
-        register_sdpa_pass_with_model_config(config)
-    """
-    from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
-        _aten_lowering_pass,
-        _remove_lowering_pass,
+    logger.warning(
+        f"SDPA converter configuration: attn_mask={attn_mask}, dropout_p={dropout_p}, "
+        f"is_causal={is_causal}, sliding_window_size={sliding_window_size}, use_gqa={use_gqa}"
     )
 
-    # Create a new pass with the model configuration
-    @_aten_lowering_pass(index=index, model_config=model_config)
-    def replace_variants_of_sdpa_with_config(
-        gm: torch.fx.GraphModule, settings: CompilationSettings
-    ) -> torch.fx.GraphModule:
-        """Replace scaled_dot_product_attention with model-specific configuration"""
+    modified_input_args = (
+        query,
+        key,
+        value,
+        attn_mask,
+        dropout_p,
+        is_causal,
+    )
 
-        # Access the model configuration from the decorator parameters
-        from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
-            get_lowering_pass_config,
+    # Create a new node with torch.nn.functional.scaled_dot_product_attention
+    with gm.graph.inserting_after(node):
+        new_node = gm.graph.call_function(
+            torch.nn.functional.scaled_dot_product_attention,
+            args=modified_input_args,
+            kwargs={
+                "scale": node.kwargs.get("scale", None),
+                "use_fp32_acc": settings.use_fp32_acc,
+                "sliding_window_size": sliding_window_size,
+            },
         )
 
-        config = get_lowering_pass_config(replace_variants_of_sdpa_with_config)
+    # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
+    new_node.meta = copy.copy(node.meta)
+    # Check if there's a getitem node following this attention node
+    for user in list(node.users):
+        if user.op == "call_function" and user.target == operator.getitem:
+            # If the getitem is extracting the first element (the output tensor)
+            if user.args[1] == 0:
+                # Replace all uses of the getitem with the new attention node
+                user.replace_all_uses_with(new_node)
+                new_node.meta["val"] = new_node.meta["val"][0]
+    # Replace all uses of the original node with the new node
+    node.replace_all_uses_with(new_node)
 
-        model_config = config.get("model_config", None)
-        layer_types = []
+    gm.graph.erase_node(node)
+    return gm
+
+
+def register_gemma3_sdpa_pass(index: int = 0, model_config: Any = None) -> None:
+    @_aten_lowering_pass(index=index, model_config=model_config)
+    def gemma3_sdpa_pass(
+        gm: torch.fx.GraphModule, settings: CompilationSettings
+    ) -> torch.fx.GraphModule:
+        """SDPA pass specifically for Gemma3 models with sliding window attention."""
+        config = get_lowering_pass_config(gemma3_sdpa_pass)
         sliding_window = None
-        # Extract model-specific parameters
-        if model_config is not None:
-            if isinstance(model_config, Gemma3TextConfig):
-                sliding_window = getattr(model_config, "sliding_window", None)
-                layer_types = getattr(model_config, "layer_types", None)
-                logger.info(f"Model config: {sliding_window=} {layer_types=}")
-            else:
+        layer_types = None
+        model_config = config.get("model_config", None)
+        if not isinstance(model_config, Gemma3TextConfig):
            logger.warning(
-                "No model configuration provided, using default SDPA replacement behavior"
+                f"Expected Gemma3TextConfig, got {type(model_config)}, will use default SDPA replacement instead"
+            )
+        else:
+            sliding_window = getattr(model_config, "sliding_window", None)
+            layer_types = getattr(model_config, "layer_types", None)
+            logger.debug(
+                f"got Gemma3 config: sliding_window={sliding_window}, layer_types={layer_types}"
            )
+
         index = 0
         for node in gm.graph.nodes:
             if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
@@ -94,116 +172,37 @@ def replace_variants_of_sdpa_with_config
                     sliding_window_size = sliding_window
                 index += 1
 
-                if (
-                    node.target
-                    == torch.ops.aten._scaled_dot_product_efficient_attention.default
-                ):
-                    if len(node.args) == 7:
-                        (
-                            query,
-                            key,
-                            value,
-                            attn_mask,
-                            compute_log_sumexp,
-                            dropout_p,
-                            is_causal,
-                        ) = node.args
-                    elif len(node.args) == 5:
-                        query, key, value, attn_mask, is_causal = node.args
-                        dropout_p = 0.0
-
-                    else:
-                        raise ValueError(
-                            f"Unexpected number of arguments for {node.target} in the graph"
-                        )
-                elif (
-                    node.target
-                    == torch.ops.aten._scaled_dot_product_flash_attention.default
-                ):
-                    if len(node.args) == 6:
-                        (
-                            query,
-                            key,
-                            value,
-                            dropout_p,
-                            is_causal,
-                            return_debug_mask,
-                        ) = node.args
-                    if len(node.args) == 5:
-                        query, key, value, dropout_p, is_causal = node.args
-                    elif len(node.args) == 3:
-                        query, key, value = node.args
-                        dropout_p = 0.0
-                        is_causal = True
-                    else:
-                        raise ValueError(
-                            f"Unexpected number of arguments for {node.target} in the graph"
-                        )
-
-                # always set_causal to True and generate attn_mask inside the sdpa operator, do not use the attn_mask from the transformers.
-                attn_mask = None
-                is_causal = True
-                dropout_p = 0.0
-
-                logger.warning(
-                    f"This current version of SDPA converter only supports {attn_mask=}, {dropout_p=} and {is_causal=} and {sliding_window_size=} configuration. This could cause issues with accuracy for models with different configurations."
-                )
-                modified_input_args = (
-                    query,
-                    key,
-                    value,
-                    attn_mask,
-                    dropout_p,
-                    is_causal,
+                # Process the node
+                logger.debug(
+                    f"Applying Gemma3-specific SDPA replacement with {node.name=}, {node.target=}, {sliding_window_size=}"
                 )
-                # Create a new node with torch.nn.functional.scaled_dot_product_attention
-                # The input args is (query, key, value, attn_mask, dropout_p, is_causal). kwargs has scale
-                with gm.graph.inserting_after(node):
-                    new_node = gm.graph.call_function(
-                        torch.nn.functional.scaled_dot_product_attention,
-                        args=modified_input_args,
-                        kwargs={
-                            "scale": node.kwargs.get("scale", None),
-                            "use_fp32_acc": settings.use_fp32_acc,
-                            "sliding_window_size": sliding_window_size,
-                        },
-                    )
-
-                # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
-                new_node.meta = copy.copy(node.meta)
-                # Check if there's a getitem node following this attention node
-                for user in list(node.users):
-                    if (
-                        user.op == "call_function"
-                        and user.target == operator.getitem
-                    ):
-                        # If the getitem is extracting the first element (the output tensor)
-                        if user.args[1] == 0:
-                            # Replace all uses of the getitem with the new attention node
-                            user.replace_all_uses_with(new_node)
-                            new_node.meta["val"] = new_node.meta["val"][0]
-                # Replace all uses of the original node with the new node
-                node.replace_all_uses_with(new_node)
-
-                gm.graph.erase_node(node)
-
-        # Clean up the graph
+                gm = _process_sdpa_node(gm, node, settings, sliding_window_size)
+
         clean_up_graph_after_modifications(gm)
+        logger.debug("Applied Gemma3-specific SDPA replacement")
+        return gm
 
-        if model_config:
-            logger.debug(
-                f"Replaced variants of scaled_dot_product_attention for {getattr(model_config, 'model_type', 'unknown')} model"
-            )
-        else:
-            logger.debug(
-                "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
-            )
-        add_attn_mask_as_output = False
-        if add_attn_mask_as_output:
-            add_one_attn_mask_as_output(gm)
+
+def register_default_sdpa_pass(index: int = 0, model_config: Any = None) -> None:
+    @_aten_lowering_pass(index=index, model_config=model_config)
+    def default_sdpa_pass(
+        gm: torch.fx.GraphModule,
+        settings: CompilationSettings,
+    ) -> torch.fx.GraphModule:
+        """Default SDPA pass for models without specific implementations."""
+
+        for node in gm.graph.nodes:
+            if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
+                # Process the node with default logic
+                gm = _process_sdpa_node(gm, node, settings)
+
+        clean_up_graph_after_modifications(gm)
+        logger.debug("Applied default SDPA replacement")
         return gm
 
-    logger.info(
-        f"Registered SDPA pass with model config: {getattr(model_config, 'model_type', 'unknown')}"
-    )
-    return replace_variants_of_sdpa_with_config
+
+# Global registry for SDPA passes
+_SDPA_MAPPING: Dict[str, Callable] = {
+    "google/gemma-3-1b-it": register_gemma3_sdpa_pass,
+    "default": register_default_sdpa_pass,
+}
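
Since _SDPA_MAPPING is an ordinary module-level dict keyed by Hugging Face model id, supporting another model only requires a register function of the same shape plus a new mapping entry. A hedged sketch follows; the model id and pass body are illustrative, not part of this commit, and it assumes _aten_lowering_pass, REPLACEABLE_ATEN_OPS, _process_sdpa_node, and clean_up_graph_after_modifications are in scope as in the file above.

def register_llama_sdpa_pass(index: int = 0, model_config: Any = None) -> None:
    # Illustrative only: a model-specific pass that currently reuses the
    # shared node-processing logic from _process_sdpa_node.
    @_aten_lowering_pass(index=index, model_config=model_config)
    def llama_sdpa_pass(
        gm: torch.fx.GraphModule, settings: CompilationSettings
    ) -> torch.fx.GraphModule:
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
                gm = _process_sdpa_node(gm, node, settings)
        clean_up_graph_after_modifications(gm)
        return gm


# Hypothetical registry entry; run_llm.py would select it via args.model.
_SDPA_MAPPING["meta-llama/Llama-3.2-1B"] = register_llama_sdpa_pass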
