@@ -51,28 +51,34 @@ def insert_identity_op(model, op, as_first_node, approx):
5151 val = np .asarray ([zero_val ], dtype = np .float32 )
5252 elif op in ["Mul" , "Div" ]:
5353 val = np .asarray ([one_val ], dtype = np .float32 )
54+ elif op in ["Identity" ]:
55+ val = None
5456 else :
5557 return
5658
5759 graph = model .graph
60+ if val is None :
61+ inplist = ["inp" if as_first_node else "div_out" ]
62+ else :
63+ model .set_initializer ("value" , val )
64+ inplist = ["inp" if as_first_node else "div_out" , "value" ]
65+ identity_node = helper .make_node (op , inplist , ["ident_out" ])
5866 if as_first_node :
59- identity_node = helper .make_node (op , ["inp" , "value" ], ["ident_out" ])
6067 graph .node .insert (0 , identity_node )
6168 graph .node [1 ].input [0 ] = "ident_out"
6269 else :
63- identity_node = helper .make_node (op , ["div_out" , "value" ], ["ident_out" ])
6470 graph .node .insert (3 , identity_node )
6571 graph .node [- 1 ].input [0 ] = "ident_out"
66- model .set_initializer ("value" , val )
6772
6873 return model
6974
7075
7176# identity operations to be inserted
72- @pytest .mark .parametrize ("op" , ["Add" , "Sub" , "Mul" , "Div" ])
77+ @pytest .mark .parametrize ("op" , ["Add" , "Sub" , "Mul" , "Div" , "Identity" ])
7378@pytest .mark .parametrize ("approx" , [False , True ])
7479@pytest .mark .parametrize ("as_first_node" , [False , True ])
75- def test_remove_identity_ops (op , as_first_node , approx ):
80+ @pytest .mark .parametrize ("fork_before_id" , [False , True ])
81+ def test_remove_identity_ops (op , as_first_node , approx , fork_before_id ):
7682 # set up onnx model
7783 inp = helper .make_tensor_value_info ("inp" , TensorProto .FLOAT , [1 , 4 , 1 , 1 ])
7884 mul = helper .make_tensor_value_info ("mul" , TensorProto .FLOAT , [])
@@ -109,14 +115,16 @@ def test_remove_identity_ops(op, as_first_node, approx):
109115 model = model .transform (InferShapes ())
110116 model = model .transform (InferDataTypes ())
111117 idict = {"inp" : inp_values }
112- odict = oxe .execute_onnx (model , idict )
113- out_before = odict ["outp" ]
118+ odict_before = oxe .execute_onnx (model , idict )
114119 num_of_nodes_before = len (model .graph .node )
115-
120+ if fork_before_id and not as_first_node :
121+ divout_vi = model .get_tensor_valueinfo ("div_out" )
122+ model .graph .output .append (divout_vi )
123+ model .graph .value_info .remove (divout_vi )
116124 model = model .transform (RemoveIdentityOps ())
117125 num_of_nodes_after = len (model .graph .node )
118126 assert num_of_nodes_before - 1 == num_of_nodes_after
119127
120- odict = oxe .execute_onnx (model , idict )
121- out_after = odict [ "outp" ]
122- assert np . isclose ( out_before , out_after , atol = 1e-3 ). all ()
128+ odict_after = oxe .execute_onnx (model , idict )
129+ outputs_same = [ np . isclose ( odict_before [ tname ], odict_after [ tname ], atol = 1e-3 ). all () for tname in odict_before . keys () ]
130+ assert all (outputs_same )
0 commit comments