|
8 | 8 | from __future__ import unicode_literals
|
9 | 9 |
|
10 | 10 | import numpy as np
|
11 |
| -from onnx import helper, TensorProto |
| 11 | +from onnx import helper, TensorProto, OperatorSetIdProto |
12 | 12 | from tf2onnx import utils
|
13 | 13 | from tf2onnx.graph import GraphUtil
|
14 | 14 | from backend_test_base import Tf2OnnxBackendTestBase
|
15 |
| -from common import unittest_main, group_nodes_by_type |
| 15 | +from common import unittest_main, group_nodes_by_type, check_opset_min_version |
16 | 16 |
|
17 | 17 |
|
18 | 18 | # pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
|
@@ -629,6 +629,35 @@ def test_duplicated_duplicated_attributes(self):
|
629 | 629 | self.run_merge_duplicated_nodes_compare(["OUT"], {"X": np.random.randn(5, 5).astype(np.float32)}, model_proto,
|
630 | 630 | op_type="ReduceSum", remaining_op_num=2)
|
631 | 631 |
|
| 632 | + @check_opset_min_version(9, "Constant") |
| 633 | + def test_duplicated_duplicated_constant(self): |
| 634 | + const_val = np.array([1, 2, 3], dtype=np.float32) |
| 635 | + tensor_1 = helper.make_tensor("tensor_1", TensorProto.FLOAT, const_val.shape, const_val) |
| 636 | + tensor_2 = helper.make_tensor("tensor_2", TensorProto.FLOAT, const_val.shape, const_val) |
| 637 | + tensor_3 = helper.make_tensor("tensor_3", TensorProto.FLOAT, const_val.shape, const_val.tobytes(), raw=True) |
| 638 | + tensor_4 = helper.make_tensor("tensor_4", TensorProto.FLOAT, const_val.shape, const_val.tobytes(), raw=True) |
| 639 | + node0 = helper.make_node('Constant', inputs=[], outputs=["value0"], value=tensor_1) |
| 640 | + node1 = helper.make_node('Constant', inputs=[], outputs=["value1"], value=tensor_2) |
| 641 | + node2 = helper.make_node('Constant', inputs=[], outputs=["value2"], value=tensor_3) |
| 642 | + node3 = helper.make_node('Constant', inputs=[], outputs=["value3"], value=tensor_4) |
| 643 | + node4 = helper.make_node("Mul", ["value0", "value1"], ["output1"]) |
| 644 | + node5 = helper.make_node("Mul", ["value2", "output1"], ["output2"]) |
| 645 | + node6 = helper.make_node("Mul", ["value3", "output2"], ["OUT"]) |
| 646 | + |
| 647 | + graph = helper.make_graph( |
| 648 | + [node0, node1, node2, node3, node4, node5, node6], |
| 649 | + "test_duplicated_duplicated_constant", |
| 650 | + [], |
| 651 | + [helper.make_tensor_value_info("OUT", TensorProto.FLOAT, (3,))], |
| 652 | + ) |
| 653 | + |
| 654 | + imp = OperatorSetIdProto() |
| 655 | + imp.version = self.config.opset |
| 656 | + |
| 657 | + model_proto = helper.make_model(graph, producer_name="onnx-tests", opset_imports=[imp]) |
| 658 | + self.run_merge_duplicated_nodes_compare(["OUT"], {}, model_proto, |
| 659 | + op_type="Constant", remaining_op_num=1) |
| 660 | + |
632 | 661 | def test_duplicated_node_is_graph_output(self):
|
633 | 662 | node0 = helper.make_node('Add', inputs=["X", "X"], outputs=["value0"])
|
634 | 663 | node1 = helper.make_node('Add', inputs=["X", "X"], outputs=["value1"])
|
|
0 commit comments