12
12
from tf2onnx import utils
13
13
from tf2onnx .graph import GraphUtil
14
14
from backend_test_base import Tf2OnnxBackendTestBase
15
- from common import unittest_main , group_nodes_by_type , check_opset_min_version
15
+ from common import unittest_main , group_nodes_by_type
16
16
17
17
18
18
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -758,9 +758,12 @@ def test_identity_in_subgraph_non_graph_output(self):
758
758
# Merge Duplicated Nodes Optimizer Tests Start
759
759
760
760
def run_merge_duplicated_nodes_compare (self , output_names_with_port , onnx_feed_dict , origin_proto ,
761
- op_type = None , remaining_op_num = None , debug = False , rtol = 1e-07 ):
762
- self .run_and_compare (output_names_with_port , onnx_feed_dict , origin_proto , op_type = op_type ,
763
- remaining_op_num = remaining_op_num , debug = debug , rtol = rtol )
761
+ op_type = None , remaining_op_num = None , debug = False , rtol = 1e-07 ,
762
+ graph_validator = None ):
763
+ new_proto = self .run_and_compare (output_names_with_port , onnx_feed_dict , origin_proto , op_type = op_type ,
764
+ remaining_op_num = remaining_op_num , debug = debug , rtol = rtol )
765
+ if graph_validator :
766
+ self .assertTrue (graph_validator (new_proto .graph ))
764
767
765
768
def test_duplicated_duplicated_input (self ):
766
769
# same input or not
@@ -800,7 +803,10 @@ def test_duplicated_duplicated_attributes(self):
800
803
self .run_merge_duplicated_nodes_compare (["OUT" ], {"X" : np .random .randn (5 , 5 ).astype (np .float32 )}, model_proto ,
801
804
op_type = "ReduceSum" , remaining_op_num = 2 )
802
805
803
- @check_opset_min_version (9 , "Constant" )
806
+ def _check_initializer_num (self , graph_proto , num ):
807
+ print (len (graph_proto .initializer ))
808
+ return num == len (graph_proto .initializer )
809
+
804
810
def test_duplicated_duplicated_constant (self ):
805
811
const_val = np .array ([1 , 2 , 3 ], dtype = np .float32 )
806
812
tensor_1 = helper .make_tensor ("tensor_1" , TensorProto .FLOAT , const_val .shape , const_val )
@@ -826,8 +832,35 @@ def test_duplicated_duplicated_constant(self):
826
832
imp .version = self .config .opset
827
833
828
834
model_proto = helper .make_model (graph , producer_name = "onnx-tests" , opset_imports = [imp ])
829
- self .run_merge_duplicated_nodes_compare (["OUT" ], {}, model_proto ,
830
- op_type = "Constant" , remaining_op_num = 1 )
835
+ self .run_merge_duplicated_nodes_compare (["OUT" ], {}, model_proto , op_type = "Constant" , remaining_op_num = 0 ,
836
+ graph_validator = lambda g : self ._check_initializer_num (g , 1 ))
837
+
838
+ def test_duplicated_duplicated_constant_and_initializer (self ):
839
+ const_val = np .array ([1 , 2 , 3 ], dtype = np .float32 )
840
+ tensor_1 = helper .make_tensor ("value0" , TensorProto .FLOAT , const_val .shape , const_val )
841
+ tensor_2 = helper .make_tensor ("value1" , TensorProto .FLOAT , const_val .shape , const_val )
842
+ tensor_3 = helper .make_tensor ("value2" , TensorProto .FLOAT , const_val .shape , const_val .tobytes (), raw = True )
843
+ tensor_4 = helper .make_tensor ("value3" , TensorProto .FLOAT , const_val .shape , const_val .tobytes (), raw = True )
844
+ node0 = helper .make_node ('Constant' , inputs = [], outputs = ["value0" ], value = tensor_1 )
845
+ node1 = helper .make_node ('Constant' , inputs = [], outputs = ["value1" ], value = tensor_2 )
846
+ node4 = helper .make_node ("Mul" , ["value0" , "value1" ], ["output1" ])
847
+ node5 = helper .make_node ("Mul" , ["value2" , "output1" ], ["output2" ])
848
+ node6 = helper .make_node ("Mul" , ["value3" , "output2" ], ["OUT" ])
849
+
850
+ graph = helper .make_graph (
851
+ [node0 , node1 , node4 , node5 , node6 ],
852
+ "test_duplicated_duplicated_constant" ,
853
+ [helper .make_tensor_value_info ("value2" , TensorProto .FLOAT , (3 ,))],
854
+ [helper .make_tensor_value_info ("OUT" , TensorProto .FLOAT , (3 ,))],
855
+ [tensor_3 , tensor_4 ]
856
+ )
857
+
858
+ imp = OperatorSetIdProto ()
859
+ imp .version = self .config .opset
860
+
861
+ model_proto = helper .make_model (graph , producer_name = "onnx-tests" , opset_imports = [imp ])
862
+ self .run_merge_duplicated_nodes_compare (["OUT" ], {}, model_proto , op_type = "Constant" , remaining_op_num = 0 ,
863
+ graph_validator = lambda g : self ._check_initializer_num (g , 2 ))
831
864
832
865
def test_duplicated_node_is_graph_output (self ):
833
866
node0 = helper .make_node ('Add' , inputs = ["X" , "X" ], outputs = ["value0" ])
0 commit comments