|
18 | 18 | import inspect |
19 | 19 | import sys |
20 | 20 | import types |
| 21 | +import warnings |
21 | 22 | from typing import TYPE_CHECKING, Any, Callable |
22 | 23 |
|
23 | 24 | import torch |
24 | 25 | import transformers |
25 | | -from torch.onnx.symbolic_opset14 import ( |
26 | | - _attention_scale, |
27 | | - _causal_attention_mask, |
28 | | - _onnx_symbolic, |
29 | | - _type_utils, |
30 | | - jit_utils, |
31 | | - symbolic_helper, |
32 | | -) |
33 | 26 | from transformers.modeling_outputs import BaseModelOutput |
34 | 27 | from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet |
35 | 28 |
|
36 | | -from optimum.utils import is_diffusers_version, is_transformers_version, logging |
| 29 | +from optimum.utils import is_diffusers_version, is_torch_version, is_transformers_version, logging |
37 | 30 |
|
38 | 31 |
|
| 32 | +if is_torch_version("<", "2.9"): |
| 33 | + from torch.onnx.symbolic_opset14 import _onnx_symbolic, jit_utils, symbolic_helper |
| 34 | +else: |
| 35 | + from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import _onnx_symbolic, jit_utils, symbolic_helper |
| 36 | + |
39 | 37 | if is_transformers_version(">=", "4.44") and is_transformers_version("<", "4.50"): |
40 | 38 | from optimum.exporters.onnx._traceable_cache import TraceableCache |
41 | 39 | if is_transformers_version(">=", "4.54"): |
42 | 40 | from optimum.exporters.onnx._traceable_decorator import traceable_check_model_inputs |
43 | | - |
44 | 41 | if is_transformers_version(">=", "4.43") and is_transformers_version("<", "4.48"): |
45 | 42 | from transformers.models.clip.modeling_clip import CLIPAttention, CLIPSdpaAttention |
46 | 43 | if is_transformers_version(">=", "4.48"): |
@@ -79,76 +76,92 @@ def __ior_(g: jit_utils.GraphContext, self: torch._C.Value, other: torch._C.Valu |
79 | 76 | return g.op("Or", self, other) |
80 | 77 |
|
81 | 78 |
|
82 | | -@_onnx_symbolic("aten::scaled_dot_product_attention") |
83 | | -@symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b") |
84 | | -def scaled_dot_product_attention( |
85 | | - g: jit_utils.GraphContext, |
86 | | - query: torch._C.Value, |
87 | | - key: torch._C.Value, |
88 | | - value: torch._C.Value, |
89 | | - attn_mask: torch._C.Value | None = None, |
90 | | - dropout_p: float = 0.0, |
91 | | - is_causal: bool = False, |
92 | | - scale: torch._C.Value | None = None, |
93 | | - enable_gqa: bool = False, |
94 | | -): |
95 | | - assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), ( |
96 | | - "is_causal and attn_mask cannot be set at the same time" |
| 79 | +if is_torch_version("<", "2.9"): |
| 80 | +    # this was fixed in torch 2.9: https://github.com/pytorch/pytorch/pull/159973 |
| 81 | + from torch.onnx.errors import OnnxExporterWarning |
| 82 | + from torch.onnx.symbolic_opset14 import ( |
| 83 | + _attention_scale, |
| 84 | + _causal_attention_mask, |
| 85 | + _onnx_symbolic, |
| 86 | + _type_utils, |
| 87 | + jit_utils, |
| 88 | + symbolic_helper, |
97 | 89 | ) |
98 | | - assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" |
99 | | - |
100 | | - if symbolic_helper._is_none(scale): |
101 | | - scale = _attention_scale(g, query) |
102 | | - |
103 | | - if is_causal: |
104 | | - attn_mask = _causal_attention_mask(g, query, key) |
105 | | - |
106 | | - # Swap the last two axes of key |
107 | | - # NOTE: onnx-script has different logic here, because the attribute perms in |
108 | | - # transpose needs list of ints |
109 | | - key_shape_builtin = symbolic_helper._get_tensor_rank(key) |
110 | | - key_transposed_axes = list(range(key_shape_builtin)) |
111 | | - key_transposed_axes[-1], key_transposed_axes[-2] = (key_transposed_axes[-2], key_transposed_axes[-1]) |
112 | | - key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes) |
113 | | - |
114 | | - # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653 |
115 | | - # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math |
116 | | - query_scaled = g.op("Mul", query, g.op("Sqrt", scale)) |
117 | | - key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale)) |
118 | | - mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled) |
119 | | - |
120 | | - if symbolic_helper._is_none(attn_mask): |
121 | | - mul_qk_add = mul_qk |
122 | | - attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) |
123 | | - elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL: |
124 | | - # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) |
125 | | - const_zero = g.op("Constant", value_t=torch.tensor([0.0])) |
126 | | - const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) |
127 | | - attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) |
128 | | - mul_qk_add = g.op("Add", mul_qk, attn_mask) |
129 | | - attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) |
130 | | - # when using scaled dot product attention with a boolean mask, we replace NaN values in attn_weight with 0.0 |
131 | | - attn_weight = g.op( |
132 | | - "Where", g.op("IsNaN", attn_weight), g.op("Constant", value_t=torch.tensor([0.0])), attn_weight |
133 | | - ) |
134 | | - elif _type_utils.JitScalarType.from_value(attn_mask) in ( |
135 | | - _type_utils.JitScalarType.FLOAT, |
136 | | - _type_utils.JitScalarType.HALF, |
137 | | - _type_utils.JitScalarType.BFLOAT16, |
138 | | - ): |
139 | | - mul_qk_add = g.op("Add", mul_qk, attn_mask) |
140 | | - attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) |
141 | | - else: |
142 | | - raise ValueError(f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}") |
143 | 90 |
|
144 | | - if dropout_p != 0: |
145 | | - attn_weight = g.op( |
146 | | - "Dropout", |
147 | | - attn_weight, |
148 | | - g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)), |
| 91 | + warnings.filterwarnings("ignore", category=OnnxExporterWarning) |
| 92 | + |
| 93 | + @_onnx_symbolic("aten::scaled_dot_product_attention") |
| 94 | + @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b") |
| 95 | + def scaled_dot_product_attention( |
| 96 | + g: jit_utils.GraphContext, |
| 97 | + query: torch._C.Value, |
| 98 | + key: torch._C.Value, |
| 99 | + value: torch._C.Value, |
| 100 | + attn_mask: torch._C.Value | None = None, |
| 101 | + dropout_p: float = 0.0, |
| 102 | + is_causal: bool = False, |
| 103 | + scale: torch._C.Value | None = None, |
| 104 | + enable_gqa: bool = False, |
| 105 | + ): |
| 106 | + assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), ( |
| 107 | + "is_causal and attn_mask cannot be set at the same time" |
149 | 108 | ) |
| 109 | + assert not enable_gqa, "conversion of scaled_dot_product_attention not implemented if enable_gqa is True" |
| 110 | + |
| 111 | + if symbolic_helper._is_none(scale): |
| 112 | + scale = _attention_scale(g, query) |
| 113 | + |
| 114 | + if is_causal: |
| 115 | + attn_mask = _causal_attention_mask(g, query, key) |
| 116 | + |
| 117 | + # Swap the last two axes of key |
| 118 | + # NOTE: onnx-script has different logic here, because the attribute perms in |
| 119 | + # transpose needs list of ints |
| 120 | + key_shape_builtin = symbolic_helper._get_tensor_rank(key) |
| 121 | + key_transposed_axes = list(range(key_shape_builtin)) |
| 122 | + key_transposed_axes[-1], key_transposed_axes[-2] = (key_transposed_axes[-2], key_transposed_axes[-1]) |
| 123 | + key_transposed = g.op("Transpose", key, perm_i=key_transposed_axes) |
| 124 | + |
| 125 | + # https://github.com/pytorch/pytorch/blob/12da0c70378b5be9135c6fda62a9863bce4a4818/aten/src/ATen/native/transformers/attention.cpp#L653 |
| 126 | + # Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for math |
| 127 | + query_scaled = g.op("Mul", query, g.op("Sqrt", scale)) |
| 128 | + key_transposed_scaled = g.op("Mul", key_transposed, g.op("Sqrt", scale)) |
| 129 | + mul_qk = g.op("MatMul", query_scaled, key_transposed_scaled) |
| 130 | + |
| 131 | + if symbolic_helper._is_none(attn_mask): |
| 132 | + mul_qk_add = mul_qk |
| 133 | + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) |
| 134 | + elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL: |
| 135 | + # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf')) |
| 136 | + const_zero = g.op("Constant", value_t=torch.tensor([0.0])) |
| 137 | + const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")])) |
| 138 | + attn_mask = g.op("Where", attn_mask, const_zero, const_neg_inf) |
| 139 | + mul_qk_add = g.op("Add", mul_qk, attn_mask) |
| 140 | + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) |
| 141 | + # when using scaled dot product attention with a boolean mask, we replace NaN values in attn_weight with 0.0 |
| 142 | + attn_weight = g.op( |
| 143 | + "Where", g.op("IsNaN", attn_weight), g.op("Constant", value_t=torch.tensor([0.0])), attn_weight |
| 144 | + ) |
| 145 | + elif _type_utils.JitScalarType.from_value(attn_mask) in ( |
| 146 | + _type_utils.JitScalarType.FLOAT, |
| 147 | + _type_utils.JitScalarType.HALF, |
| 148 | + _type_utils.JitScalarType.BFLOAT16, |
| 149 | + ): |
| 150 | + mul_qk_add = g.op("Add", mul_qk, attn_mask) |
| 151 | + attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1) |
| 152 | + else: |
| 153 | + raise ValueError(f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}") |
| 154 | + |
| 155 | + if dropout_p != 0: |
| 156 | + attn_weight = g.op( |
| 157 | + "Dropout", |
| 158 | + attn_weight, |
| 159 | + g.op("Constant", value_t=torch.tensor(dropout_p, dtype=torch.float)), |
| 160 | + ) |
| 161 | + |
| 162 | + return g.op("MatMul", attn_weight, value) |
150 | 163 |
|
151 | | - return g.op("MatMul", attn_weight, value) |
| 164 | + warnings.filterwarnings("default", category=OnnxExporterWarning) |
152 | 165 |
|
153 | 166 |
|
154 | 167 | def patch_everywhere(attribute_name: str, patch: Any, module_name_prefix: str | None = None): |