f"This current version of SDPA converter only supports attn_mask = {attn_mask}, dropout_p = {dropout_p} and is_causal = {is_causal} configuration. This could cause issues with accuracy for models with different configurations."
81
+
"No model configuration provided, using default SDPA replacement behavior"
f"Unexpected number of arguments for {node.target} in the graph"
            )

            # Always set is_causal to True and generate attn_mask inside the sdpa operator; do not use the attn_mask from transformers.
            attn_mask = None
            is_causal = True
            dropout_p = 0.0

            logger.warning(
                f"This current version of SDPA converter only supports {attn_mask=}, {dropout_p=} and {is_causal=} and {sliding_window_size=} configuration. This could cause issues with accuracy for models with different configurations."
            )
            modified_input_args = (
                query,
                key,
                value,
                attn_mask,
                dropout_p,
                is_causal,
            )
            # Create a new node with torch.nn.functional.scaled_dot_product_attention
            # The input args are (query, key, value, attn_mask, dropout_p, is_causal). kwargs has scale
            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_function(
                    torch.nn.functional.scaled_dot_product_attention,
                    args=modified_input_args,
                    kwargs={
                        "scale": node.kwargs.get("scale", None),
                        "use_fp32_acc": settings.use_fp32_acc,
                        "sliding_window_size": sliding_window_size,
                    },
                )

                # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
                new_node.meta = copy.copy(node.meta)
                # Check if there's a getitem node following this attention node
                for user in list(node.users):
                    if (
                        user.op == "call_function"
                        and user.target == operator.getitem
                    ):
                        # If the getitem is extracting the first element (the output tensor)
                        if user.args[1] == 0:
                            # Replace all uses of the getitem with the new attention node
                            user.replace_all_uses_with(new_node)
                            new_node.meta["val"] = new_node.meta["val"][0]
                # Replace all uses of the original node with the new node
                node.replace_all_uses_with(new_node)

            gm.graph.erase_node(node)

    # Clean up the graph
    clean_up_graph_after_modifications(gm)

    if model_config:
        logger.debug(
            f"Replaced variants of scaled_dot_product_attention for {getattr(model_config, 'model_type', 'unknown')} model"
        )
    else:
        logger.debug(
            "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
        )
    add_attn_mask_as_output = False
    if add_attn_mask_as_output:
        add_one_attn_mask_as_output(gm)
    return gm
205
111
-
# Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
112
-
new_node.meta=copy.copy(node.meta)
113
-
# Check if there's a getitem node following this attention node