@@ -106,13 +106,6 @@
 
 import onnx
 import torch
-from packaging.version import Version
-
-if Version(torch.__version__) >= Version("2.9.0"):
-    from torch.onnx._internal.torchscript_exporter import _type_utils
-else:
-    from torch.onnx import _type_utils
-
 from torch.onnx import symbolic_helper
 from torch.onnx import symbolic_helper as sym_help
 from torch.onnx._internal import jit_utils
@@ -297,6 +290,11 @@ def scaled_dot_product_attention(
     enable_gqa: bool = False,
 ):
     """Perform scaled dot product attention."""
+    if hasattr(torch.onnx, "_type_utils"):
+        from torch.onnx import _type_utils
+    else:
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+
     assert (not is_causal) or (is_causal and symbolic_helper._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
     )
@@ -400,6 +398,11 @@ def export_fp8_mha(
     """
     from torch.onnx.symbolic_opset14 import _attention_scale, _causal_attention_mask
 
+    if hasattr(torch.onnx, "_type_utils"):
+        from torch.onnx import _type_utils
+    else:
+        from torch.onnx._internal.torchscript_exporter import _type_utils
+
     # Pass all arguments, including x, to the custom ONNX operator
     assert (not is_causal) or (is_causal and sym_help._is_none(attn_mask)), (
         "is_causal and attn_mask cannot be set at the same time"
     )
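In both functions the change swaps the old module-level, version-parsed import for a lazy feature check: if `torch.onnx` still exposes `_type_utils` (torch < 2.9.0) it is imported from there, otherwise from the `torch.onnx._internal.torchscript_exporter` module it moved to in 2.9.0. A minimal sketch of the same pattern in isolation (the `_get_type_utils` helper is illustrative, not part of this diff):

```python
import torch


def _get_type_utils():
    """Locate _type_utils across torch versions at call time.

    Illustrative helper, not part of the diff: torch < 2.9.0 exposes
    torch.onnx._type_utils directly, while 2.9.0 moved the TorchScript
    exporter internals under torch.onnx._internal.torchscript_exporter.
    """
    if hasattr(torch.onnx, "_type_utils"):
        # Old layout (torch < 2.9.0)
        from torch.onnx import _type_utils
    else:
        # New layout (torch >= 2.9.0)
        from torch.onnx._internal.torchscript_exporter import _type_utils
    return _type_utils
```

Probing the attribute instead of comparing parsed version strings drops the `packaging` dependency, and deferring the check into each symbolic function means it only runs when the function is actually invoked during export.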