@@ -107,12 +107,6 @@
 import onnx
 import torch
 from packaging.version import Version
-
-if Version(torch.__version__) >= Version("2.9.0"):
-    from torch.onnx._internal.torchscript_exporter import _type_utils
-else:
-    from torch.onnx import _type_utils
-
 from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_helper as sym_help
 from torch.onnx._internal import jit_utils
@@ -297,6 +291,11 @@ def scaled_dot_product_attention( |
     enable_gqa: bool = False,
 ):
     """Perform scaled dot product attention."""
+    if Version(torch.__version__) > Version("2.8.0"):
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+    else:
+        from torch.onnx import _type_utils
+
     assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
     )
@@ -400,6 +399,11 @@ def export_fp8_mha( |
     """
     from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
 
+    if Version(torch.__version__) > Version("2.8.0"):
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+    else:
+        from torch.onnx import _type_utils
+
     # Pass all arguments, including x, to the custom ONNX operator
     assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
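For context, the change above replaces a module-level, version-gated import with a lazy import inside each symbolic function, so the module still imports cleanly even when the private exporter paths move between torch releases. Below is a minimal standalone sketch of that pattern, assuming only the module relocation shown in the diff; the helper name `_get_type_utils` is hypothetical and not part of the patch.

```python
import torch
from packaging.version import Version


def _get_type_utils():
    """Hypothetical helper showing the version-gated lazy import pattern.

    torch releases after 2.8.0 moved the TorchScript exporter internals to
    torch.onnx._internal.torchscript_exporter, so the import path for
    _type_utils depends on the installed version. Importing inside the
    function (rather than at module level) keeps the enclosing module
    importable even when these private paths change again.
    """
    if Version(torch.__version__) > Version("2.8.0"):
        from torch.onnx._internal.torchscript_exporter import _type_utils
    else:
        from torch.onnx import _type_utils
    return _type_utils
```

Deferring the import to call time means a mismatch between the installed torch and the expected path only surfaces if these exporter code paths are actually exercised, rather than breaking every import of the module.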