@@ -1,7 +1,7 @@
import copy
import logging
import operator
- from typing import Callable, Sequence, Tuple
+ from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -13,7 +13,7 @@
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)
- from transformers import Gemma3TextConfig
+ from transformers import AutoConfig, Gemma3TextConfig

from .sdpa_converter import *

@@ -34,52 +34,130 @@
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}

+ from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+     get_lowering_pass_config,
+ )
+

- def register_sdpa_pass_with_model_config(index: int = 0, model_config=None):
-     """
-     Register the SDPA replacement pass with a specific model configuration.
+ def _process_sdpa_node(
+     gm: torch.fx.GraphModule,
+     node: torch.fx.Node,
+     settings: CompilationSettings,
+     sliding_window_size: Optional[int] = None,
+     use_gqa: bool = False,
+ ) -> torch.fx.GraphModule:
+     """Helper function to process SDPA nodes with common logic."""
+
+     if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default:
+         if len(node.args) == 7:
+             (
+                 query,
+                 key,
+                 value,
+                 attn_mask,
+                 compute_log_sumexp,
+                 dropout_p,
+                 is_causal,
+             ) = node.args
+         elif len(node.args) == 5:
+             query, key, value, attn_mask, is_causal = node.args
+             dropout_p = 0.0
+         else:
+             raise ValueError(
+                 f"Unexpected number of arguments for {node.target} in the graph"
+             )
+     elif node.target == torch.ops.aten._scaled_dot_product_flash_attention.default:
+         if len(node.args) == 6:
+             (
+                 query,
+                 key,
+                 value,
+                 dropout_p,
+                 is_causal,
+                 return_debug_mask,
+             ) = node.args
+         elif len(node.args) == 5:
+             query, key, value, dropout_p, is_causal = node.args
+         elif len(node.args) == 3:
+             query, key, value = node.args
+             dropout_p = 0.0
+             is_causal = True
+         else:
+             raise ValueError(
+                 f"Unexpected number of arguments for {node.target} in the graph"
+             )
+     else:
+         return gm

-     Args:
-         model_config: The model configuration object (e.g., from transformers.AutoConfig)
-         index: Position in the lowering pass list (default: 0)
+     # Always set causal to True and generate attn_mask inside the sdpa operator
+     attn_mask = None
+     is_causal = True
+     dropout_p = 0.0

-     Example:
-         from transformers import AutoConfig
-         config = AutoConfig.from_pretrained("microsoft/DialoGPT-medium")
-         register_sdpa_pass_with_model_config(config)
-     """
-     from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
-         _aten_lowering_pass,
-         _remove_lowering_pass,
+     logger.warning(
+         f"SDPA converter configuration: attn_mask={attn_mask}, dropout_p={dropout_p}, "
+         f"is_causal={is_causal}, sliding_window_size={sliding_window_size}, use_gqa={use_gqa}"
    )

-     # Create a new pass with the model configuration
-     @_aten_lowering_pass(index=index, model_config=model_config)
-     def replace_variants_of_sdpa_with_config(
-         gm: torch.fx.GraphModule, settings: CompilationSettings
-     ) -> torch.fx.GraphModule:
-         """Replace scaled_dot_product_attention with model-specific configuration"""
+     modified_input_args = (
+         query,
+         key,
+         value,
+         attn_mask,
+         dropout_p,
+         is_causal,
+     )

-         # Access the model configuration from the decorator parameters
-         from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
-             get_lowering_pass_config,
+     # Create a new node with torch.nn.functional.scaled_dot_product_attention
+     with gm.graph.inserting_after(node):
+         new_node = gm.graph.call_function(
+             torch.nn.functional.scaled_dot_product_attention,
+             args=modified_input_args,
+             kwargs={
+                 "scale": node.kwargs.get("scale", None),
+                 "use_fp32_acc": settings.use_fp32_acc,
+                 "sliding_window_size": sliding_window_size,
+             },
        )

-         config = get_lowering_pass_config(replace_variants_of_sdpa_with_config)
+     # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
+     new_node.meta = copy.copy(node.meta)
+     # Check if there's a getitem node following this attention node
+     for user in list(node.users):
+         if user.op == "call_function" and user.target == operator.getitem:
+             # If the getitem is extracting the first element (the output tensor)
+             if user.args[1] == 0:
+                 # Replace all uses of the getitem with the new attention node
+                 user.replace_all_uses_with(new_node)
+                 new_node.meta["val"] = new_node.meta["val"][0]
+     # Replace all uses of the original node with the new node
+     node.replace_all_uses_with(new_node)

-         model_config = config.get("model_config", None)
-         layer_types = []
+     gm.graph.erase_node(node)
+     return gm
+
+
+ def register_gemma3_sdpa_pass(index: int = 0, model_config: Any = None) -> None:
+     @_aten_lowering_pass(index=index, model_config=model_config)
+     def gemma3_sdpa_pass(
+         gm: torch.fx.GraphModule, settings: CompilationSettings
+     ) -> torch.fx.GraphModule:
+         """SDPA pass specifically for Gemma3 models with sliding window attention."""
+         config = get_lowering_pass_config(gemma3_sdpa_pass)
        sliding_window = None
-         # Extract model-specific parameters
-         if model_config is not None:
-             if isinstance(model_config, Gemma3TextConfig):
-                 sliding_window = getattr(model_config, "sliding_window", None)
-                 layer_types = getattr(model_config, "layer_types", None)
-                 logger.info(f"Model config: {sliding_window=} {layer_types=}")
-             else:
+         layer_types = None
+         model_config = config.get("model_config", None)
+         if not isinstance(model_config, Gemma3TextConfig):
            logger.warning(
-                 "No model configuration provided, using default SDPA replacement behavior"
+                 f"Expected Gemma3TextConfig, got {type(model_config)}, will use default SDPA replacement instead"
+             )
+         else:
+             sliding_window = getattr(model_config, "sliding_window", None)
+             layer_types = getattr(model_config, "layer_types", None)
+             logger.debug(
+                 f"got Gemma3 config: sliding_window={sliding_window}, layer_types={layer_types}"
            )
+
        index = 0
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
@@ -94,116 +172,37 @@ def replace_variants_of_sdpa_with_config(
                    sliding_window_size = sliding_window
                index += 1

-                 if (
-                     node.target
-                     == torch.ops.aten._scaled_dot_product_efficient_attention.default
-                 ):
-                     if len(node.args) == 7:
-                         (
-                             query,
-                             key,
-                             value,
-                             attn_mask,
-                             compute_log_sumexp,
-                             dropout_p,
-                             is_causal,
-                         ) = node.args
-                     elif len(node.args) == 5:
-                         query, key, value, attn_mask, is_causal = node.args
-                         dropout_p = 0.0
-
-                     else:
-                         raise ValueError(
-                             f"Unexpected number of arguments for {node.target} in the graph"
-                         )
-                 elif (
-                     node.target
-                     == torch.ops.aten._scaled_dot_product_flash_attention.default
-                 ):
-                     if len(node.args) == 6:
-                         (
-                             query,
-                             key,
-                             value,
-                             dropout_p,
-                             is_causal,
-                             return_debug_mask,
-                         ) = node.args
-                     if len(node.args) == 5:
-                         query, key, value, dropout_p, is_causal = node.args
-                     elif len(node.args) == 3:
-                         query, key, value = node.args
-                         dropout_p = 0.0
-                         is_causal = True
-                     else:
-                         raise ValueError(
-                             f"Unexpected number of arguments for {node.target} in the graph"
-                         )
-
-                 # always set_causal to True and generate attn_mask inside the sdpa operator, do not use the attn_mask from the transformers.
-                 attn_mask = None
-                 is_causal = True
-                 dropout_p = 0.0
-
-                 logger.warning(
-                     f"This current version of SDPA converter only supports {attn_mask=}, {dropout_p=} and {is_causal=} and {sliding_window_size=} configuration. This could cause issues with accuracy for models with different configurations."
-                 )
-                 modified_input_args = (
-                     query,
-                     key,
-                     value,
-                     attn_mask,
-                     dropout_p,
-                     is_causal,
+ # Process the node
176
+ logger .debug (
177
+ f"Applying Gemma3-specific SDPA replacement with { node .name = } , { node .target = } , { sliding_window_size = } "
158
178
)
159
- # Create a new node with torch.nn.functional.scaled_dot_product_attention
160
- # The input args is (query, key, value, attn_mask, dropout_p, is_causal). kwargs has scale
161
- with gm .graph .inserting_after (node ):
162
- new_node = gm .graph .call_function (
163
- torch .nn .functional .scaled_dot_product_attention ,
164
- args = modified_input_args ,
165
- kwargs = {
166
- "scale" : node .kwargs .get ("scale" , None ),
167
- "use_fp32_acc" : settings .use_fp32_acc ,
168
- "sliding_window_size" : sliding_window_size ,
169
- },
170
- )
171
-
172
- # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
173
- new_node .meta = copy .copy (node .meta )
174
- # Check if there's a getitem node following this attention node
175
- for user in list (node .users ):
176
- if (
177
- user .op == "call_function"
178
- and user .target == operator .getitem
179
- ):
180
- # If the getitem is extracting the first element (the output tensor)
181
- if user .args [1 ] == 0 :
182
- # Replace all uses of the getitem with the new attention node
183
- user .replace_all_uses_with (new_node )
184
- new_node .meta ["val" ] = new_node .meta ["val" ][0 ]
185
- # Replace all uses of the original node with the new node
186
- node .replace_all_uses_with (new_node )
187
-
188
- gm .graph .erase_node (node )
189
-
190
- # Clean up the graph
+                 gm = _process_sdpa_node(gm, node, settings, sliding_window_size)
+
        clean_up_graph_after_modifications(gm)
+         logger.debug("Applied Gemma3-specific SDPA replacement")
+         return gm

-         if model_config:
-             logger.debug(
-                 f"Replaced variants of scaled_dot_product_attention for {getattr(model_config, 'model_type', 'unknown')} model"
-             )
-         else:
-             logger.debug(
-                 "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
-             )
-         add_attn_mask_as_output = False
-         if add_attn_mask_as_output:
-             add_one_attn_mask_as_output(gm)
+
+ def register_default_sdpa_pass(index: int = 0, model_config: Any = None) -> None:
+     @_aten_lowering_pass(index=index, model_config=model_config)
+     def default_sdpa_pass(
+         gm: torch.fx.GraphModule,
+         settings: CompilationSettings,
+     ) -> torch.fx.GraphModule:
+         """Default SDPA pass for models without specific implementations."""
+
+         for node in gm.graph.nodes:
+             if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
+                 # Process the node with default logic
+                 gm = _process_sdpa_node(gm, node, settings)
+
+         clean_up_graph_after_modifications(gm)
+         logger.debug("Applied default SDPA replacement")
        return gm

-     logger.info(
-         f"Registered SDPA pass with model config: {getattr(model_config, 'model_type', 'unknown')}"
-     )
-     return replace_variants_of_sdpa_with_config
+
+ # Global registry for SDPA passes
+ _SDPA_MAPPING: Dict[str, Callable] = {
+     "google/gemma-3-1b-it": register_gemma3_sdpa_pass,
+     "default": register_default_sdpa_pass,
+ }
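
Usage note: the removed register_sdpa_pass_with_model_config helper carried its own docstring example, and with this change registration is instead driven by a model name through _SDPA_MAPPING. Below is a minimal caller-side sketch of how the registry might be used; it assumes this file is importable as register_sdpa, that falling back to the "default" entry is the intended behavior for unlisted models, and that AutoConfig resolves the Gemma3 checkpoint to a Gemma3TextConfig.

# Illustrative sketch only, not part of the diff. The import path "register_sdpa"
# and the fallback-to-"default" lookup are assumptions.
from transformers import AutoConfig

import register_sdpa  # the module shown in this diff (path assumed)

model_name = "google/gemma-3-1b-it"
# For this checkpoint AutoConfig is expected to return a Gemma3TextConfig, which
# gemma3_sdpa_pass checks before reading sliding_window and layer_types.
config = AutoConfig.from_pretrained(model_name)

# Pick the model-specific registration function if one exists, else the default.
register_fn = register_sdpa._SDPA_MAPPING.get(
    model_name, register_sdpa._SDPA_MAPPING["default"]
)
register_fn(model_config=config)  # registers the SDPA lowering pass at index 0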