@@ -106,7 +106,7 @@
 
 import onnx
 import torch
-from torch.onnx import _type_utils, symbolic_helper
+from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_helper as sym_help
 from torch.onnx._internal import jit_utils
 from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
@@ -290,6 +290,11 @@ def scaled_dot_product_attention( |
     enable_gqa: bool = False,
 ):
     """Perform scaled dot product attention."""
+    if hasattr(torch.onnx, "_type_utils"):
+        from torch.onnx import _type_utils
+    else:
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+
     assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
     )
@@ -393,6 +398,11 @@ def export_fp8_mha( |
     """
     from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
 
+    if hasattr(torch.onnx, "_type_utils"):
+        from torch.onnx import _type_utils
+    else:
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+
     # Pass all arguments, including x, to the custom ONNX operator
     assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
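
For context: the guarded import handles PyTorch builds where `_type_utils` is no longer exposed at `torch.onnx._type_utils` but instead lives under `torch.onnx._internal.torchscript_exporter`, which is why the unconditional top-level import is dropped. A minimal sketch of the same fallback hoisted into one shared helper, so both symbolic functions could reuse it (the helper name `_import_type_utils` is hypothetical and not part of this patch):

```python
# Minimal sketch of the version-guarded import used in this patch.
# _import_type_utils is a hypothetical helper name, not part of the diff.
import torch


def _import_type_utils():
    """Return _type_utils from wherever this PyTorch build exposes it."""
    if hasattr(torch.onnx, "_type_utils"):
        # Older PyTorch: module lives directly under torch.onnx.
        from torch.onnx import _type_utils
    else:
        # Newer PyTorch: moved under the TorchScript exporter internals.
        from torch.onnx._internal.torchscript_exporter import _type_utils
    return _type_utils
```

An equivalent `try: ... except ImportError:` fallback would also work; the `hasattr` check mirrors what the patch itself does.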
|