@@ -18,22 +18,17 @@
 import inspect
 import sys
 import types
-import warnings
 from typing import TYPE_CHECKING, Any, Callable

 import torch
 import transformers
+from torch.onnx import symbolic_helper
 from transformers.modeling_outputs import BaseModelOutput
 from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet

 from optimum.utils import is_diffusers_version, is_torch_version, is_transformers_version, logging


-if is_torch_version("<", "2.9"):
-    from torch.onnx.symbolic_opset14 import _onnx_symbolic, jit_utils, symbolic_helper
-else:
-    from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import _onnx_symbolic, jit_utils, symbolic_helper
-
 if is_transformers_version(">=", "4.44") and is_transformers_version("<", "4.50"):
     from optimum.exporters.onnx._traceable_cache import TraceableCache
 if is_transformers_version(">=", "4.54"):
@@ -70,30 +65,21 @@
 logger = logging.get_logger(__name__)


-@_onnx_symbolic("aten::__ior_")
 @symbolic_helper.parse_args("v", "v")
-def __ior_(g: jit_utils.GraphContext, self: torch._C.Value, other: torch._C.Value) -> torch._C.Value:
+def __ior_(g, self: torch._C.Value, other: torch._C.Value) -> torch._C.Value:
     return g.op("Or", self, other)


-if is_torch_version("<", "2.9"):
-    # this wad fixed in torch in 2.9 https://github.com/pytorch/pytorch/pull/159973
-    from torch.onnx.errors import OnnxExporterWarning
-    from torch.onnx.symbolic_opset14 import (
-        _attention_scale,
-        _causal_attention_mask,
-        _onnx_symbolic,
-        _type_utils,
-        jit_utils,
-        symbolic_helper,
-    )
+torch.onnx.register_custom_op_symbolic("aten::__ior__", __ior_, 14)

-    warnings.filterwarnings("ignore", category=OnnxExporterWarning)
+if is_torch_version("<", "2.9"):
+    # this was fixed in torch in 2.9 https://github.com/pytorch/pytorch/pull/159973
+    from torch.onnx import JitScalarType
+    from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask

-    @_onnx_symbolic("aten::scaled_dot_product_attention")
     @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b")
     def scaled_dot_product_attention(
-        g: jit_utils.GraphContext,
+        g,
         query: torch._C.Value,
         key: torch._C.Value,
         value: torch._C.Value,
@@ -131,7 +117,7 @@ def scaled_dot_product_attention(
         if symbolic_helper._is_none(attn_mask):
             mul_qk_add = mul_qk
             attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
-        elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL:
+        elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL:
             # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
             const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
             const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
@@ -142,15 +128,15 @@ def scaled_dot_product_attention(
             attn_weight = g.op(
                 "Where", g.op("IsNaN", attn_weight), g.op("Constant", value_t=torch.tensor([0.0])), attn_weight
             )
-        elif _type_utils.JitScalarType.from_value(attn_mask) in (
-            _type_utils.JitScalarType.FLOAT,
-            _type_utils.JitScalarType.HALF,
-            _type_utils.JitScalarType.BFLOAT16,
+        elif JitScalarType.from_value(attn_mask) in (
+            JitScalarType.FLOAT,
+            JitScalarType.HALF,
+            JitScalarType.BFLOAT16,
         ):
             mul_qk_add = g.op("Add", mul_qk, attn_mask)
             attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
         else:
-            raise ValueError(f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}")
+            raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}")

         if dropout_p != 0:
             attn_weight = g.op(
@@ -161,7 +147,7 @@ def scaled_dot_product_attention(

         return g.op("MatMul", attn_weight, value)

-    warnings.filterwarnings("default", category=OnnxExporterWarning)
+    torch.onnx.register_custom_op_symbolic("aten::scaled_dot_product_attention", scaled_dot_product_attention, 14)


 def patch_everywhere(attribute_name: str, patch: Any, module_name_prefix: str | None = None):
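
For context, the new registration calls use torch.onnx.register_custom_op_symbolic, which hooks a symbolic function into the TorchScript-based ONNX exporter for a given aten op and opset, with the intent of taking the place of the exporter's built-in handling. A minimal sketch of how such a registration gets exercised during export (the module name and output file below are illustrative, and the TorchScript export path at opset 14 is assumed):

import torch
import torch.nn.functional as F


class TinyAttention(torch.nn.Module):
    # Illustrative module, not part of the patch: its forward hits
    # aten::scaled_dot_product_attention during tracing.
    def forward(self, query, key, value):
        return F.scaled_dot_product_attention(query, key, value)


q = k = v = torch.randn(1, 4, 8, 16)

# During export at opset 14, the TorchScript exporter dispatches
# aten::scaled_dot_product_attention to whichever symbolic function is
# registered for it, such as the one registered above on torch < 2.9.
torch.onnx.export(TinyAttention(), (q, k, v), "tiny_attention.onnx", opset_version=14)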