
Commit 4a7c985

Clean up torch.onnx internal imports (#86)
Replace internal imports with proper public APIs.

Signed-off-by: Justin Chu <[email protected]>
1 parent 08bd601 commit 4a7c985

File tree

1 file changed: +15 −29 lines changed

optimum/exporters/onnx/model_patcher.py

Lines changed: 15 additions & 29 deletions
@@ -18,22 +18,17 @@
 import inspect
 import sys
 import types
-import warnings
 from typing import TYPE_CHECKING, Any, Callable
 
 import torch
 import transformers
+from torch.onnx import symbolic_helper
 from transformers.modeling_outputs import BaseModelOutput
 from transformers.models.speecht5.modeling_speecht5 import SpeechT5EncoderWithSpeechPrenet
 
 from optimum.utils import is_diffusers_version, is_torch_version, is_transformers_version, logging
 
 
-if is_torch_version("<", "2.9"):
-    from torch.onnx.symbolic_opset14 import _onnx_symbolic, jit_utils, symbolic_helper
-else:
-    from torch.onnx._internal.torchscript_exporter.symbolic_opset14 import _onnx_symbolic, jit_utils, symbolic_helper
-
 if is_transformers_version(">=", "4.44") and is_transformers_version("<", "4.50"):
     from optimum.exporters.onnx._traceable_cache import TraceableCache
 if is_transformers_version(">=", "4.54"):
@@ -70,30 +65,21 @@
 logger = logging.get_logger(__name__)
 
 
-@_onnx_symbolic("aten::__ior_")
 @symbolic_helper.parse_args("v", "v")
-def __ior_(g: jit_utils.GraphContext, self: torch._C.Value, other: torch._C.Value) -> torch._C.Value:
+def __ior_(g, self: torch._C.Value, other: torch._C.Value) -> torch._C.Value:
     return g.op("Or", self, other)
 
 
-if is_torch_version("<", "2.9"):
-    # this wad fixed in torch in 2.9 https://github.com/pytorch/pytorch/pull/159973
-    from torch.onnx.errors import OnnxExporterWarning
-    from torch.onnx.symbolic_opset14 import (
-        _attention_scale,
-        _causal_attention_mask,
-        _onnx_symbolic,
-        _type_utils,
-        jit_utils,
-        symbolic_helper,
-    )
+torch.onnx.register_custom_op_symbolic("aten::__ior__", __ior_, 14)
 
-    warnings.filterwarnings("ignore", category=OnnxExporterWarning)
+if is_torch_version("<", "2.9"):
+    # this was fixed in torch in 2.9 https://github.com/pytorch/pytorch/pull/159973
+    from torch.onnx import JitScalarType
+    from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
 
-    @_onnx_symbolic("aten::scaled_dot_product_attention")
     @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "b")
     def scaled_dot_product_attention(
-        g: jit_utils.GraphContext,
+        g,
         query: torch._C.Value,
         key: torch._C.Value,
         value: torch._C.Value,
@@ -131,7 +117,7 @@ def scaled_dot_product_attention(
         if symbolic_helper._is_none(attn_mask):
             mul_qk_add = mul_qk
             attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
-        elif _type_utils.JitScalarType.from_value(attn_mask) == _type_utils.JitScalarType.BOOL:
+        elif JitScalarType.from_value(attn_mask) == JitScalarType.BOOL:
             # Turn the Boolean mask to float: attn_mask.masked_fill(not attn_mask, -float('inf'))
             const_zero = g.op("Constant", value_t=torch.tensor([0.0]))
             const_neg_inf = g.op("Constant", value_t=torch.tensor([-float("inf")]))
@@ -142,15 +128,15 @@ def scaled_dot_product_attention(
             attn_weight = g.op(
                 "Where", g.op("IsNaN", attn_weight), g.op("Constant", value_t=torch.tensor([0.0])), attn_weight
             )
-        elif _type_utils.JitScalarType.from_value(attn_mask) in (
-            _type_utils.JitScalarType.FLOAT,
-            _type_utils.JitScalarType.HALF,
-            _type_utils.JitScalarType.BFLOAT16,
+        elif JitScalarType.from_value(attn_mask) in (
+            JitScalarType.FLOAT,
+            JitScalarType.HALF,
+            JitScalarType.BFLOAT16,
         ):
             mul_qk_add = g.op("Add", mul_qk, attn_mask)
             attn_weight = g.op("Softmax", mul_qk_add, axis_i=-1)
         else:
-            raise ValueError(f"Unsupported type for attn_mask: {_type_utils.JitScalarType.from_value(attn_mask)}")
+            raise ValueError(f"Unsupported type for attn_mask: {JitScalarType.from_value(attn_mask)}")
 
         if dropout_p != 0:
             attn_weight = g.op(
@@ -161,7 +147,7 @@ def scaled_dot_product_attention(
 
         return g.op("MatMul", attn_weight, value)
 
-    warnings.filterwarnings("default", category=OnnxExporterWarning)
+    torch.onnx.register_custom_op_symbolic("aten::scaled_dot_product_attention", scaled_dot_product_attention, 14)
 
 
 def patch_everywhere(attribute_name: str, patch: Any, module_name_prefix: str | None = None):
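
For context on the pattern the new code adopts: torch.onnx.register_custom_op_symbolic(name, symbolic_fn, opset_version) is the public entry point for attaching a symbolic function to an op, replacing the internal @_onnx_symbolic decorator whose module path the removed code had to version-gate. A minimal sketch of the same registration flow, using a hypothetical override of aten::relu purely for illustration (not part of this commit):

# Sketch only: the target op (aten::relu) and function name are illustrative;
# the registration call mirrors the two registrations in the diff above.
import torch
from torch.onnx import symbolic_helper


@symbolic_helper.parse_args("v")
def relu_symbolic(g, self):
    # Emit the corresponding ONNX node; g is the exporter's graph context.
    return g.op("Relu", self)


# Register the symbolic for opset 14 before calling torch.onnx.export(...).
torch.onnx.register_custom_op_symbolic("aten::relu", relu_symbolic, 14)

Going through the public helper keeps the module working regardless of where torch places its TorchScript exporter internals, which is what the removed version-gated imports were compensating for.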
