|
33 | 33 | import qonnx.core.data_layout as DataLayout |
34 | 34 | from qonnx.core.datatype import DataType |
35 | 35 | from qonnx.core.modelwrapper import ModelWrapper |
36 | | -from qonnx.util.basic import qonnx_make_model |
| 36 | +from qonnx.util.basic import get_preferred_onnx_opset, qonnx_make_model |
37 | 37 |
|
38 | 38 |
|
39 | 39 | def test_modelwrapper(): |
@@ -231,3 +231,31 @@ def test_modelwrapper_set_tensor_shape_multiple_inputs(): |
231 | 231 | # check that order of inputs is preserved |
232 | 232 | assert model.graph.input[0].name == "in1" |
233 | 233 | assert model.graph.input[1].name == "in2" |
| 234 | + |
| 235 | + |
| 236 | +def test_modelwrapper_set_opset_import(): |
| 237 | + # Create a simple model |
| 238 | + in1 = onnx.helper.make_tensor_value_info("in1", onnx.TensorProto.FLOAT, [4, 4]) |
| 239 | + out1 = onnx.helper.make_tensor_value_info("out1", onnx.TensorProto.FLOAT, [4, 4]) |
| 240 | + node = onnx.helper.make_node("Neg", inputs=["in1"], outputs=["out1"]) |
| 241 | + graph = onnx.helper.make_graph( |
| 242 | + nodes=[node], |
| 243 | + name="single_node_graph", |
| 244 | + inputs=[in1], |
| 245 | + outputs=[out1], |
| 246 | + ) |
| 247 | + onnx_model = qonnx_make_model(graph, producer_name="opset-test-model") |
| 248 | + model = ModelWrapper(onnx_model) |
| 249 | + |
| 250 | + # Test setting new domain |
| 251 | + model.set_opset_import("qonnx.custom_op.general", 1) |
| 252 | + preferred_onnx_opset = get_preferred_onnx_opset() |
| 253 | + assert model.get_opset_imports() == {"": preferred_onnx_opset, "qonnx.custom_op.general": 1} |
| 254 | + |
| 255 | + # Test updating existing domain |
| 256 | + model.set_opset_import("qonnx.custom_op.general", 2) |
| 257 | + assert model.get_opset_imports() == {"": preferred_onnx_opset, "qonnx.custom_op.general": 2} |
| 258 | + |
| 259 | + # Test setting ONNX main domain |
| 260 | + model.set_opset_import("", 13) |
| 261 | + assert model.get_opset_imports() == {"": 13, "qonnx.custom_op.general": 2} |
0 commit comments