@@ -77,7 +77,8 @@ def insert_identity_op(model, op, as_first_node, approx):
7777@pytest .mark .parametrize ("op" , ["Add" , "Sub" , "Mul" , "Div" , "Identity" ])
7878@pytest .mark .parametrize ("approx" , [False , True ])
7979@pytest .mark .parametrize ("as_first_node" , [False , True ])
80- def test_remove_identity_ops (op , as_first_node , approx ):
80+ @pytest .mark .parametrize ("fork_before_id" , [False , True ])
81+ def test_remove_identity_ops (op , as_first_node , approx , fork_before_id ):
8182 # set up onnx model
8283 inp = helper .make_tensor_value_info ("inp" , TensorProto .FLOAT , [1 , 4 , 1 , 1 ])
8384 mul = helper .make_tensor_value_info ("mul" , TensorProto .FLOAT , [])
@@ -114,14 +115,16 @@ def test_remove_identity_ops(op, as_first_node, approx):
114115 model = model .transform (InferShapes ())
115116 model = model .transform (InferDataTypes ())
116117 idict = {"inp" : inp_values }
117- odict = oxe .execute_onnx (model , idict )
118- out_before = odict ["outp" ]
118+ odict_before = oxe .execute_onnx (model , idict )
119119 num_of_nodes_before = len (model .graph .node )
120-
120+ if fork_before_id and not as_first_node :
121+ divout_vi = model .get_tensor_valueinfo ("div_out" )
122+ model .graph .output .append (divout_vi )
123+ model .graph .value_info .remove (divout_vi )
121124 model = model .transform (RemoveIdentityOps ())
122125 num_of_nodes_after = len (model .graph .node )
123126 assert num_of_nodes_before - 1 == num_of_nodes_after
124127
125- odict = oxe .execute_onnx (model , idict )
126- out_after = odict [ "outp" ]
127- assert np . isclose ( out_before , out_after , atol = 1e-3 ). all ()
128+ odict_after = oxe .execute_onnx (model , idict )
129+ outputs_same = [ np . isclose ( odict_before [ tname ], odict_after [ tname ], atol = 1e-3 ). all () for tname in odict_before . keys () ]
130+ assert all (outputs_same )
0 commit comments