@@ -107,12 +107,6 @@
 import onnx
 import torch
 from packaging.version import Version
-
-if Version(torch.__version__) >= Version("2.9.0"):
-    from torch.onnx._internal.torchscript_exporter import _type_utils
-else:
-    from torch.onnx import _type_utils
-
 from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_helper as sym_help
 from torch.onnx._internal import jit_utils
@@ -297,6 +291,11 @@ def scaled_dot_product_attention(
     enable_gqa: bool = False,
 ):
     """Perform scaled dot product attention."""
+    if Version(torch.__version__) > Version("2.8.0"):
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+    else:
+        from torch.onnx import _type_utils
+
     assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
     )
@@ -400,6 +399,11 @@ def export_fp8_mha(
     """
     from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
 
+    if Version(torch.__version__) > Version("2.8.0"):
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+    else:
+        from torch.onnx import _type_utils
+
     # Pass all arguments, including x, to the custom ONNX operator
     assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"