Skip to content

Commit e00f91d

Browse files
JP-Amboagenickfraser
authored andcommitted
Fix (export/onnx): import GLOBALS from correct location depending on torch version (#1398)
1 parent 4653449 commit e00f91d

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

src/brevitas/export/onnx/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
22
# SPDX-License-Identifier: BSD-3-Clause
3+
from packaging import version
4+
5+
from brevitas import torch_version
36

47

58
def onnx_export_opset():
@@ -9,7 +12,11 @@ def onnx_export_opset():
912
opset = getattr(cfg, ATR_NAME)
1013

1114
except:
12-
from torch.onnx._globals import GLOBALS as cfg
15+
if torch_version < version.parse('2.9.0'):
16+
from torch.onnx._globals import GLOBALS as cfg
17+
else:
18+
from torch.onnx._internal.torchscript_exporter._globals import GLOBALS as cfg
19+
1320
ATR_NAME = 'export_onnx_opset_version'
1421
opset = getattr(cfg, ATR_NAME)
1522

0 commit comments

Comments
 (0)