Skip to content

Commit a097b41

Browse files
committed
[Core] introduce set_opset_import
1 parent 38b9532 commit a097b41

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

src/qonnx/core/modelwrapper.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,3 +759,23 @@ def get_customop_wrapper(self, node, fallback_customop_version=util.get_preferre
759759
f"using fallback_customop_version={fallback_customop_version}"
760760
)
761761
return getCustomOp(node, onnx_opset_version=fallback_customop_version)
762+
763+
def set_opset_import(self, domain, version):
764+
"""Sets the opset version for a given domain in the model's opset imports.
765+
If the domain already exists, its version will be updated. If not, a new
766+
opset import will be added.
767+
768+
Args:
769+
domain (str): The domain name (e.g. "qonnx.custom_op.general")
770+
version (int): The opset version number
771+
"""
772+
# find if domain already exists in opset imports
773+
for opset in self._model_proto.opset_import:
774+
if opset.domain == domain:
775+
opset.version = version
776+
return
777+
# domain not found, add new opset import
778+
new_opset = onnx.OperatorSetIdProto()
779+
new_opset.domain = domain
780+
new_opset.version = version
781+
self._model_proto.opset_import.append(new_opset)

tests/core/test_modelwrapper.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import qonnx.core.data_layout as DataLayout
3434
from qonnx.core.datatype import DataType
3535
from qonnx.core.modelwrapper import ModelWrapper
36-
from qonnx.util.basic import qonnx_make_model
36+
from qonnx.util.basic import get_preferred_onnx_opset, qonnx_make_model
3737

3838

3939
def test_modelwrapper():
@@ -231,3 +231,31 @@ def test_modelwrapper_set_tensor_shape_multiple_inputs():
231231
# check that order of inputs is preserved
232232
assert model.graph.input[0].name == "in1"
233233
assert model.graph.input[1].name == "in2"
234+
235+
236+
def test_modelwrapper_set_opset_import():
237+
# Create a simple model
238+
in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4])
239+
out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4])
240+
node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["out1"])
241+
graph = onnx.helper.make_graph(
242+
nodes=[node],
243+
name="single_node_graph",
244+
inputs=[in1],
245+
outputs=[out1],
246+
)
247+
onnx_model = qonnx_make_model(graph, producer_name="opset-test-model")
248+
model = ModelWrapper(onnx_model)
249+
250+
# Test setting new domain
251+
model.set_opset_import("qonnx.custom_op.general", 1)
252+
preferred_onnx_opset = get_preferred_onnx_opset()
253+
assert model.get_opset_imports() == {"": preferred_onnx_opset, "qonnx.custom_op.general": 1}
254+
255+
# Test updating existing domain
256+
model.set_opset_import("qonnx.custom_op.general", 2)
257+
assert model.get_opset_imports() == {"": preferred_onnx_opset, "qonnx.custom_op.general": 2}
258+
259+
# Test setting ONNX main domain
260+
model.set_opset_import("", 13)
261+
assert model.get_opset_imports() == {"": 13, "qonnx.custom_op.general": 2}

0 commit comments

Comments
 (0)