@@ -106,7 +106,7 @@
 
 import onnx
 import torch
-from torch.onnx import _type_utils, symbolic_helper
+from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_helper as sym_help
 from torch.onnx._internal import jit_utils
 from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
@@ -290,6 +290,11 @@ def scaled_dot_product_attention(
     enable_gqa: bool = False,
 ):
     """Perform scaled dot product attention."""
+    if hasattr(torch.onnx, "_type_utils"):
+        from torch.onnx import _type_utils
+    else:
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+
     assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
     )
@@ -393,6 +398,11 @@ def export_fp8_mha(
     """
     from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
 
+    if hasattr(torch.onnx, "_type_utils"):
+        from torch.onnx import _type_utils
+    else:
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+
     # Pass all arguments, including x, to the custom ONNX operator
     assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
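Both hunks apply the same compatibility shim: `_type_utils` is imported from `torch.onnx` when that attribute exists (the layout in older PyTorch releases), and otherwise from `torch.onnx._internal.torchscript_exporter`, where newer PyTorch releases relocate the TorchScript-exporter internals. A minimal standalone sketch of the pattern follows; the `JitScalarType` note at the end is purely illustrative of how the resolved module is typically consumed and is not part of this diff:

import torch

# Resolve _type_utils across PyTorch versions. Assumption (mirroring the diff
# above): older releases expose torch.onnx._type_utils directly, while newer
# ones move it under torch.onnx._internal.torchscript_exporter. One of the two
# paths is expected to exist; otherwise the second import raises ImportError.
if hasattr(torch.onnx, "_type_utils"):
    from torch.onnx import _type_utils
else:
    from torch.onnx._internal.torchscript_exporter import _type_utils

# Illustrative use inside a symbolic function: map a TorchScript graph value
# to an ONNX scalar type, e.g. _type_utils.JitScalarType.from_value(query).

The diff performs this lookup inside each function body rather than at module scope, which defers the version probe until the symbolic function is actually invoked instead of paying it (and risking an import-time failure) when the module is first loaded.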