Skip to content

Commit d65b7d2

Browse files
committed
[Core] opset ver. fallback for ModelWrapper.get_customop_wrapper
1 parent 8dbbe05 commit d65b7d2

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

src/qonnx/core/modelwrapper.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -743,9 +743,19 @@ def get_opset_imports(self):
743743
"""Returns a list of imported opsets as a {domain, version} dictionary."""
744744
return {opset.domain: opset.version for opset in self._model_proto.opset_import}
745745

746-
def get_customop_wrapper(self, node):
746+
def get_customop_wrapper(self, node, fallback_customop_version=1):
747747
"""Return CustomOp instance for given node, respecting the
748-
imported opset version in the model protobuf."""
748+
imported opset version in the model protobuf. If the node's domain
749+
is not found in the model's opset imports, fallback_customop_version
750+
will be used."""
749751
opset_imports = self.get_opset_imports()
750-
opset_import = opset_imports[node.domain]
751-
return getCustomOp(node, onnx_opset_version=opset_import)
752+
try:
753+
opset_import = opset_imports[node.domain]
754+
return getCustomOp(node, onnx_opset_version=opset_import)
755+
except KeyError:
756+
# domain not found in imports, use fallback version
757+
warnings.warn(
758+
f"Domain {node.domain} not found in model opset imports, "
759+
f"using fallback_customop_version={fallback_customop_version}"
760+
)
761+
return getCustomOp(node, onnx_opset_version=fallback_customop_version)

0 commit comments

Comments
 (0)