Commit f9b66ad

register lowering pass with model config
1 parent 51a60a5 commit f9b66ad

5 files changed: +195 -135 lines changed


py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py

Lines changed: 23 additions & 3 deletions
@@ -1,5 +1,5 @@
 import logging
-from typing import Callable, Optional, Sequence, Union
+from typing import Any, Callable, Optional, Sequence, Union
 
 import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -55,20 +55,28 @@
 def _aten_lowering_pass(
     *args: LoweringPassSignature,
     index: Optional[int] = None,
+    **kwargs: Any,
 ) -> Union[
     LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature]
 ]:
     """Adds a lowering pass to the registry, at a specified index if desired
 
     If no index is specified, the lowering pass is inserted at the end of the list
+
+    Additional keyword arguments can be passed to configure the lowering pass behavior.
+    These will be stored as metadata on the pass function.
     """
 
     def add_lowering_pass(
         lowering_pass: LoweringPassSignature,
     ) -> LoweringPassSignature:
+        # Store additional parameters as metadata on the function
+        if kwargs:
+            lowering_pass._lowering_pass_config = kwargs
+
         ATEN_POST_LOWERING_PASSES.add_pass_with_index(lowering_pass, index)
         logger.debug(
-            f"Added lowering pass {lowering_pass} to list at index {index}, current passlist: {ATEN_POST_LOWERING_PASSES}"
+            f"Added lowering pass {lowering_pass} to list at index {index} with config {kwargs}, current passlist: {ATEN_POST_LOWERING_PASSES}"
         )
         return lowering_pass
 
@@ -83,7 +91,7 @@ def add_lowering_pass(
                 f"aten_lowering_pass decorator called with invalid arguments {args} "
                 "To specify an index to insert the pass, use the keyword 'index='"
             )
-    # If no arguments are specified, the decorator was called with an index keyword
+    # If no arguments are specified, the decorator was called with keyword arguments
     else:
         return add_lowering_pass
 
@@ -97,6 +105,18 @@ def _remove_lowering_pass(*, index: int) -> None:
     return
 
 
+def get_lowering_pass_config(lowering_pass: LoweringPassSignature) -> dict[str, Any]:
+    """Get the configuration parameters for a lowering pass function
+
+    Args:
+        lowering_pass: The lowering pass function
+
+    Returns:
+        Dictionary containing the configuration parameters, or empty dict if none
+    """
+    return getattr(lowering_pass, "_lowering_pass_config", {})
+
+
 def post_lowering(
     gm: torch.fx.GraphModule, settings: CompilationSettings = CompilationSettings()
 ) -> torch.fx.GraphModule:

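For reference, a minimal sketch of how the new keyword-argument plumbing can be used. Only _aten_lowering_pass and get_lowering_pass_config come from this commit; the example pass body and its block_size parameter are hypothetical placeholders.

# Hypothetical usage sketch of the kwargs-based pass configuration added above.
# Only _aten_lowering_pass and get_lowering_pass_config are introduced by this commit;
# the pass body and the block_size keyword are illustrative placeholders.
import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    _aten_lowering_pass,
    get_lowering_pass_config,
)


@_aten_lowering_pass(index=0, block_size=64)  # extra kwargs are stored as pass metadata
def my_example_pass(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    # Read back the configuration the decorator attached to this function.
    config = get_lowering_pass_config(my_example_pass)
    block_size = config.get("block_size", 32)  # hypothetical parameter
    # ... transform gm using block_size ...
    return gm
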
tools/llm/run_llm.py

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ def get_model(args):
         .eval()
         .cuda()
     )
+    register_sdpa.register_sdpa_pass_with_model_config(model_config=model.config)
 
     if args.precision == "FP16":
         model = model.to(torch.float16)

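To see where this call sits in the runner flow, here is a condensed, illustrative sketch of the get_model-style usage; the model id and the import style are placeholders, not part of the commit.

# Condensed, illustrative version of the change to get_model(); the model id
# and import style are placeholders, not part of the commit.
import torch
from transformers import AutoModelForCausalLM

from torchtrt_ext import register_sdpa  # assumes tools/llm is on sys.path

model = (
    AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")
    .eval()
    .cuda()
)
# New in this commit: hand the HF config to the SDPA lowering pass so that
# per-layer attention types (e.g. Gemma3 sliding windows) are known at lowering time.
register_sdpa.register_sdpa_pass_with_model_config(model_config=model.config)
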
tools/llm/torchtrt_ext/register_sdpa.py

Lines changed: 165 additions & 123 deletions
@@ -13,6 +13,7 @@
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
+from transformers import Gemma3TextConfig
 
 from .sdpa_converter import *
 
@@ -34,134 +35,175 @@
 }
 
 
-@_aten_lowering_pass
-def replace_variants_of_sdpa(
-    gm: torch.fx.GraphModule, settings: CompilationSettings
-) -> torch.fx.GraphModule:
-    """Replace scaled_dot_product_attention with an equivalent
-    implementation which can be accurately converted to TRT
+def register_sdpa_pass_with_model_config(index: int = 0, model_config=None):
     """
+    Register the SDPA replacement pass with a specific model configuration.
 
-    for node in gm.graph.nodes:
-        attn_mask = None
-        is_causal = False
-        if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
-            if (
-                node.target
-                == torch.ops.aten._scaled_dot_product_efficient_attention.default
-            ):
-                if len(node.args) == 7:
-                    (
-                        query,
-                        key,
-                        value,
-                        attn_mask,
-                        compute_log_sumexp,
-                        dropout_p,
-                        is_causal,
-                    ) = node.args
-                elif len(node.args) == 5:
-                    query, key, value, attn_mask, is_causal = node.args
-                    dropout_p = 0.0
-
-                else:
-                    raise ValueError(
-                        f"Unexpected number of arguments for {node.target} in the graph"
-                    )
-            elif (
-                node.target
-                == torch.ops.aten._scaled_dot_product_flash_attention.default
-            ):
-                if len(node.args) == 6:
-                    (
-                        query,
-                        key,
-                        value,
-                        dropout_p,
-                        is_causal,
-                        return_debug_mask,
-                    ) = node.args
-                if len(node.args) == 5:
-                    query, key, value, dropout_p, is_causal = node.args
-                elif len(node.args) == 3:
-                    query, key, value = node.args
-                    dropout_p = 0.0
-                    is_causal = True
-                else:
-                    raise ValueError(
-                        f"Unexpected number of arguments for {node.target} in the graph"
-                    )
+    Args:
+        model_config: The model configuration object (e.g., from transformers.AutoConfig)
+        index: Position in the lowering pass list (default: 0)
+
+    Example:
+        from transformers import AutoConfig
+        config = AutoConfig.from_pretrained("microsoft/DialoGPT-medium")
+        register_sdpa_pass_with_model_config(config)
+    """
+    from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+        _aten_lowering_pass,
+        _remove_lowering_pass,
+    )
 
+    # Create a new pass with the model configuration
+    @_aten_lowering_pass(index=index, model_config=model_config)
+    def replace_variants_of_sdpa_with_config(
+        gm: torch.fx.GraphModule, settings: CompilationSettings
+    ) -> torch.fx.GraphModule:
+        """Replace scaled_dot_product_attention with model-specific configuration"""
+
+        # Access the model configuration from the decorator parameters
+        from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+            get_lowering_pass_config,
+        )
+
+        config = get_lowering_pass_config(replace_variants_of_sdpa_with_config)
+
+        model_config = config.get("model_config", None)
+        layer_types = []
+        sliding_window = None
+        # Extract model-specific parameters
+        if model_config is not None:
+            if isinstance(model_config, Gemma3TextConfig):
+                sliding_window = getattr(model_config, "sliding_window", None)
+                layer_types = getattr(model_config, "layer_types", None)
+            logger.info(f"Model config: {sliding_window=} {layer_types=}")
+        else:
             logger.warning(
-                f"This current version of SDPA converter only supports attn_mask = {attn_mask}, dropout_p = {dropout_p} and is_causal = {is_causal} configuration. This could cause issues with accuracy for models with different configurations."
+                "No model configuration provided, using default SDPA replacement behavior"
             )
-            modified_input_args = (query, key, value, attn_mask, dropout_p, is_causal)
-            # Create a new node with torch.nn.functional.scaled_dot_product_attention
-            # The input args is (query, key, value, attn_mask, dropout_p, is_causal). kwargs has scale
-            with gm.graph.inserting_after(node):
-                new_node = gm.graph.call_function(
-                    torch.nn.functional.scaled_dot_product_attention,
-                    args=modified_input_args,
-                    kwargs={
-                        "scale": node.kwargs.get("scale", None),
-                        "use_fp32_acc": settings.use_fp32_acc,
-                    },
+        index = 0
+        for node in gm.graph.nodes:
+            if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
+                sliding_window_size = None
+                if (
+                    sliding_window is not None
+                    and sliding_window > 0
+                    and layer_types is not None
+                    and index < len(layer_types)
+                ):
+                    if layer_types[index] == "sliding_attention":
+                        sliding_window_size = sliding_window
+                index += 1
+
+                if (
+                    node.target
+                    == torch.ops.aten._scaled_dot_product_efficient_attention.default
+                ):
+                    if len(node.args) == 7:
+                        (
+                            query,
+                            key,
+                            value,
+                            attn_mask,
+                            compute_log_sumexp,
+                            dropout_p,
+                            is_causal,
+                        ) = node.args
+                    elif len(node.args) == 5:
+                        query, key, value, attn_mask, is_causal = node.args
+                        dropout_p = 0.0
+
+                    else:
+                        raise ValueError(
+                            f"Unexpected number of arguments for {node.target} in the graph"
+                        )
+                elif (
+                    node.target
+                    == torch.ops.aten._scaled_dot_product_flash_attention.default
+                ):
+                    if len(node.args) == 6:
+                        (
+                            query,
+                            key,
+                            value,
+                            dropout_p,
+                            is_causal,
+                            return_debug_mask,
+                        ) = node.args
+                    if len(node.args) == 5:
+                        query, key, value, dropout_p, is_causal = node.args
+                    elif len(node.args) == 3:
+                        query, key, value = node.args
+                        dropout_p = 0.0
+                        is_causal = True
+                    else:
+                        raise ValueError(
+                            f"Unexpected number of arguments for {node.target} in the graph"
+                        )
+
+                # always set_causal to True and generate attn_mask inside the sdpa operator, do not use the attn_mask from the transformers.
+                attn_mask = None
+                is_causal = True
+                dropout_p = 0.0
+
+                logger.warning(
+                    f"This current version of SDPA converter only supports {attn_mask=}, {dropout_p=} and {is_causal=} and {sliding_window_size=} configuration. This could cause issues with accuracy for models with different configurations."
                 )
+                modified_input_args = (
+                    query,
+                    key,
+                    value,
+                    attn_mask,
+                    dropout_p,
+                    is_causal,
+                )
+                # Create a new node with torch.nn.functional.scaled_dot_product_attention
+                # The input args is (query, key, value, attn_mask, dropout_p, is_causal). kwargs has scale
+                with gm.graph.inserting_after(node):
+                    new_node = gm.graph.call_function(
+                        torch.nn.functional.scaled_dot_product_attention,
+                        args=modified_input_args,
+                        kwargs={
+                            "scale": node.kwargs.get("scale", None),
+                            "use_fp32_acc": settings.use_fp32_acc,
+                            "sliding_window_size": sliding_window_size,
+                        },
+                    )
+
+                    # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
+                    new_node.meta = copy.copy(node.meta)
+                    # Check if there's a getitem node following this attention node
+                    for user in list(node.users):
+                        if (
+                            user.op == "call_function"
+                            and user.target == operator.getitem
+                        ):
+                            # If the getitem is extracting the first element (the output tensor)
+                            if user.args[1] == 0:
+                                # Replace all uses of the getitem with the new attention node
+                                user.replace_all_uses_with(new_node)
+                                new_node.meta["val"] = new_node.meta["val"][0]
+                    # Replace all uses of the original node with the new node
+                    node.replace_all_uses_with(new_node)
+
+                gm.graph.erase_node(node)
+
+        # Clean up the graph
+        clean_up_graph_after_modifications(gm)
+
+        if model_config:
+            logger.debug(
+                f"Replaced variants of scaled_dot_product_attention for {getattr(model_config, 'model_type', 'unknown')} model"
+            )
+        else:
+            logger.debug(
+                "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
+            )
+        add_attn_mask_as_output = False
+        if add_attn_mask_as_output:
+            add_one_attn_mask_as_output(gm)
+        return gm
 
-                # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
-                new_node.meta = copy.copy(node.meta)
-                # Check if there's a getitem node following this attention node
-                for user in list(node.users):
-                    if user.op == "call_function" and user.target == operator.getitem:
-                        # If the getitem is extracting the first element (the output tensor)
-                        if user.args[1] == 0:
-                            # Replace all uses of the getitem with the new attention node
-                            user.replace_all_uses_with(new_node)
-                            new_node.meta["val"] = new_node.meta["val"][0]
-                # Replace all uses of the original node with the new node
-                node.replace_all_uses_with(new_node)
-
-            gm.graph.erase_node(node)
-
-    # Clean up the graph
-    clean_up_graph_after_modifications(gm)
-
-    logger.debug(
-        "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
+    logger.info(
+        f"Registered SDPA pass with model config: {getattr(model_config, 'model_type', 'unknown')}"
     )
-    add_attn_mask_as_output = False
-    if add_attn_mask_as_output:
-        add_one_attn_mask_as_output(gm)
-    return gm
-
-
-# try to add one of the attn_mask as output, so that I can actually see the shape and value in the generation phase.
-def add_one_attn_mask_as_output(gm: torch.fx.GraphModule):
-    import torch.utils._pytree as pytree
-    from cache_utils import create_random_output_tensors
-
-    attn_mask_node = None
-    for node in gm.graph.nodes:
-        if (
-            node.op == "call_function"
-            and node.target == torch.nn.functional.scaled_dot_product_attention
-        ):
-            attn_mask_node = node.args[3]
-            break
-
-    output_node = next(node for node in gm.graph.nodes if node.op == "output")
-
-    current_outputs = output_node.args[0]
-    if isinstance(current_outputs, tuple):
-        new_outputs = current_outputs + (attn_mask_node,)
-    else:
-        new_outputs = (current_outputs, attn_mask_node)
-    output_node.args = new_outputs
-    gm.graph.output(new_outputs)
-    gm.graph.erase_node(output_node)
-
-    gm = clean_up_graph_after_modifications(gm)
-    new_output_tensors = create_random_output_tensors(new_outputs)
-    new_out_spec = pytree.tree_flatten(new_output_tensors)[1]
-    gm._out_spec = new_out_spec
-    return gm
+    return replace_variants_of_sdpa_with_config

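The core of the new pass is the per-layer pairing between SDPA nodes (visited in graph order) and model_config.layer_types: only layers marked "sliding_attention" get a sliding_window_size forwarded to the converter. Below is a standalone sketch of that selection logic using plain lists instead of fx nodes; the helper name is illustrative, not from the commit. Note also that since the new signature is register_sdpa_pass_with_model_config(index: int = 0, model_config=None), the config should be passed by keyword as run_llm.py does; a positional call would bind it to index.

# Standalone sketch of the per-layer selection the registered pass performs
# (simplified: plain lists instead of torch.fx nodes; helper name is illustrative).
from typing import List, Optional


def sliding_window_per_layer(
    layer_types: Optional[List[str]], sliding_window: Optional[int]
) -> List[Optional[int]]:
    """sliding_window_size handed to each attention node, in graph order."""
    if not layer_types or not sliding_window or sliding_window <= 0:
        return []
    return [
        sliding_window if layer_type == "sliding_attention" else None
        for layer_type in layer_types
    ]


# Gemma-3-like pattern: five sliding-attention layers per full-attention layer.
print(sliding_window_per_layer(["sliding_attention"] * 5 + ["full_attention"], 512))
# -> [512, 512, 512, 512, 512, None]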