
Commit 3228c24

Torch 2.9 support (#82)
Fixes #83 and #84.
1 parent: 92c5675 · commit: 3228c24

2 files changed: +90 -76 lines changed


optimum/exporters/onnx/convert.py
Lines changed: 1 addition & 0 deletions

@@ -567,6 +567,7 @@ def remap(value):
             dynamic_axes=dynamix_axes,
             do_constant_folding=do_constant_folding,
             opset_version=opset,
+            dynamo=False,  # torch dynamo not yet supported
         )

     # check if external data was exported
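For context, `dynamo=False` pins `torch.onnx.export` to the legacy TorchScript-based exporter rather than the dynamo-based one. A minimal sketch of that call shape, with a hypothetical `TinyModel` standing in for the exported model and assuming a torch release recent enough to accept the `dynamo` keyword:

import torch


class TinyModel(torch.nn.Module):  # hypothetical stand-in for the real model
    def forward(self, x):
        return torch.nn.functional.relu(x)


# dynamo=False selects the legacy TorchScript exporter, mirroring the flag
# added to the onnx_export call in convert.py above.
torch.onnx.export(
    TinyModel(),
    (torch.randn(2, 4),),
    "tiny_model.onnx",
    input_names=["x"],
    output_names=["y"],
    dynamic_axes={"x": {0: "batch"}, "y": {0: "batch"}},
    opset_version=17,
    do_constant_folding=True,
    dynamo=False,  # torch dynamo not yet supported by this export path
)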

optimum/exporters/onnx/model_patcher.py
Lines changed: 89 additions & 76 deletions

@@ -18,29 +18,26 @@
 import inspect
 import sys
 import types
+import warnings
 from typing import TYPE_CHECKING, Any, Callable

 import torch
 import transformers
-from torch.onnx.symbolic_opset14 import (
-    _attention_scale,
-    _causal_attention_mask,
-    _onnx_symbolic,
-    _type_utils,
-    jit_utils,
-    symbolic_helper,
-)
 from transformers.modeling_outputs import BaseModelOutput
 from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet

-from optimum.utils import is_diffusers_version, is_transformers_version, logging
+from optimum.utils import is_diffusers_version, is_torch_version, is_transformers_version, logging


+if is_torch_version("<", "2.9"):
+    from torch.onnx.symbolic_opset14 import _onnx_symbolic, jit_utils, symbolic_helper
+else:
+    from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import _onnx_symbolic, jit_utils, symbolic_helper
+
 if is_transformers_version(">=", "4.44") and is_transformers_version("<", "4.50"):
     from optimum.exporters.onnx._traceable_cache import TraceableCache
 if is_transformers_version(">=", "4.54"):
     from optimum.exporters.onnx._traceable_decorator import traceable_check_model_inputs
-
 if is_transformers_version(">=", "4.43") and is_transformers_version("<", "4.48"):
     from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention
 if is_transformers_version(">=", "4.48"):
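The import gate above is the core of the torch 2.9 compatibility fix: the TorchScript exporter's opset-14 helpers moved under `torch.onnx._internal.torchscript_exporter` in torch 2.9. A rough standalone equivalent of what the `is_torch_version("<", "2.9")` check resolves to, sketched with `packaging` (assumption: comparing the parsed `torch.__version__` release tuple is close enough to optimum's own helper):

import torch
from packaging import version

# Pick the import location of the opset-14 symbolic helpers based on the
# installed torch: their pre-2.9 location, or the private
# torchscript_exporter package they live under from 2.9 onwards.
if version.parse(torch.__version__).release < (2, 9):
    from torch.onnx.symbolic_opset14 import _onnx_symbolic, jit_utils, symbolic_helper
else:
    from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import (
        _onnx_symbolic,
        jit_utils,
        symbolic_helper,
    )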
@@ -79,76 +76,92 @@ def __ior_(g: jit_utils.GraphContext, self: torch._C.Value, other: torch._C.Valu
     return g.op("Or", self, other)


-@_onnx_symbolic("aten::scaled_dot_product_attention")
-@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b")
-def scaled_dot_product_attention(
-    g: jit_utils.GraphContext,
-    query: torch._C.Value,
-    key: torch._C.Value,
-    value: torch._C.Value,
-    attn_mask: torch._C.Value | None = None,
-    dropout_p: float = 0.0,
-    is_causal: bool = False,
-    scale: torch._C.Value | None = None,
-    enable_gqa: bool = False,
-):
-    assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
-        "is_causal and attn_mask cannot be set at the same time"
+if is_torch_version("<", "2.9"):
+    # this was fixed in torch in 2.9 https://github.com/pytorch/pytorch/pull/159973
+    from torch.onnx.errors import OnnxExporterWarning
+    from torch.onnx.symbolic_opset14 import (
+        _attention_scale,
+        _causal_attention_mask,
+        _onnx_symbolic,
+        _type_utils,
+        jit_utils,
+        symbolic_helper,
     )
-    assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
-
-    if symbolic_helper._is_none(scale):
-        scale = _attention_scale(g, query)
-
-    if is_causal:
-        attn_mask = _causal_attention_mask(g, query, key)
-
-    # Swap the last two axes of key
-    # NOTE: onnx-script has different logic here, because the attribute perms in
-    # transpose needs list of ints
-    key_shape_builtin = symbolic_helper._get_tensor_rank(key)
-    key_transposed_axes = list(range(key_shape_builtin))
-    key_transposed_axes[-1], key_transposed_axes[-2] = (key_transposed_axes[-2], key_transposed_axes[-1])
-    key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes)
-
-    # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
-    # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
-    query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
-    key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale))
-    mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)
-
-    if symbolic_helper._is_none(attn_mask):
-        mul_qk_add = mul_qk
-        attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
-    elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL:
-        # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
-        const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
-        const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
-        attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
-        mul_qk_add = g.op("Add", mul_qk, attn_mask)
-        attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
-        # when using scaled dot product attention with a boolean mask, we replace NaN values in attn_weight with 0.0
-        attn_weight = g.op(
-            "Where", g.op("IsNaN", attn_weight), g.op("Constant", value_t=torch.tensor([0.0])), attn_weight
-        )
-    elif _type_utils.JitScalarType.from_value(attn_mask) in (
-        _type_utils.JitScalarType.FLOAT,
-        _type_utils.JitScalarType.HALF,
-        _type_utils.JitScalarType.BFLOAT16,
-    ):
-        mul_qk_add = g.op("Add", mul_qk, attn_mask)
-        attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
-    else:
-        raise ValueError(f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}")

-    if dropout_p != 0:
-        attn_weight = g.op(
-            "Dropout",
-            attn_weight,
-            g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
+    warnings.filterwarnings("ignore", category=OnnxExporterWarning)
+
+    @_onnx_symbolic("aten::scaled_dot_product_attention")
+    @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b")
+    def scaled_dot_product_attention(
+        g: jit_utils.GraphContext,
+        query: torch._C.Value,
+        key: torch._C.Value,
+        value: torch._C.Value,
+        attn_mask: torch._C.Value | None = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+        scale: torch._C.Value | None = None,
+        enable_gqa: bool = False,
+    ):
+        assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
+            "is_causal and attn_mask cannot be set at the same time"
         )
+        assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True"
+
+        if symbolic_helper._is_none(scale):
+            scale = _attention_scale(g, query)
+
+        if is_causal:
+            attn_mask = _causal_attention_mask(g, query, key)
+
+        # Swap the last two axes of key
+        # NOTE: onnx-script has different logic here, because the attribute perms in
+        # transpose needs list of ints
+        key_shape_builtin = symbolic_helper._get_tensor_rank(key)
+        key_transposed_axes = list(range(key_shape_builtin))
+        key_transposed_axes[-1], key_transposed_axes[-2] = (key_transposed_axes[-2], key_transposed_axes[-1])
+        key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes)
+
+        # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653
+        # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math
+        query_scaled = g.op("Mul", query, g.op("Sqrt", scale))
+        key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale))
+        mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled)
+
+        if symbolic_helper._is_none(attn_mask):
+            mul_qk_add = mul_qk
+            attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
+        elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL:
+            # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
+            const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
+            const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
+            attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf)
+            mul_qk_add = g.op("Add", mul_qk, attn_mask)
+            attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
+            # when using scaled dot product attention with a boolean mask, we replace NaN values in attn_weight with 0.0
+            attn_weight = g.op(
+                "Where", g.op("IsNaN", attn_weight), g.op("Constant", value_t=torch.tensor([0.0])), attn_weight
+            )
+        elif _type_utils.JitScalarType.from_value(attn_mask) in (
+            _type_utils.JitScalarType.FLOAT,
+            _type_utils.JitScalarType.HALF,
+            _type_utils.JitScalarType.BFLOAT16,
+        ):
+            mul_qk_add = g.op("Add", mul_qk, attn_mask)
+            attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
+        else:
+            raise ValueError(f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}")
+
+        if dropout_p != 0:
+            attn_weight = g.op(
+                "Dropout",
+                attn_weight,
+                g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)),
+            )
+
+        return g.op("MatMul", attn_weight, value)

-    return g.op("MatMul", attn_weight, value)
+    warnings.filterwarnings("default", category=OnnxExporterWarning)


 def patch_everywhere(attribute_name: str, patch: Any, module_name_prefix: str | None = None):
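The overridden symbolic is now registered only for torch < 2.9, since the PyTorch PR linked in the diff indicates the behaviour was fixed upstream in 2.9. The part that differs from the stock symbolic is the boolean-mask handling: a mask row that is entirely False turns every attention logit into -inf, softmax over that row yields NaN, and the patched graph maps those NaNs back to 0.0. A small eager-mode illustration of that failure mode and of the cleanup the exported `Where`/`IsNaN` nodes mirror:

import torch

# True = attend, False = masked out, matching the boolean attn_mask handling
# in the symbolic above.
scores = torch.zeros(1, 4)
mask = torch.tensor([[False, False, False, False]])  # a fully masked row

masked_scores = scores.masked_fill(~mask, float("-inf"))
weights = masked_scores.softmax(dim=-1)
print(weights)  # tensor([[nan, nan, nan, nan]])

# Eager equivalent of the exported Where(IsNaN(w), 0.0, w) node.
cleaned = torch.where(weights.isnan(), torch.zeros_like(weights), weights)
print(cleaned)  # tensor([[0., 0., 0., 0.]])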
