Skip to content

Commit e55ea87

Browse files
committed
add UT
1 parent e712d65 commit e55ea87

File tree

2 files changed

+63
-12
lines changed

2 files changed

+63
-12
lines changed

tests/test_backend.py

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _run_test_case(self, output_names_with_port, feed_dict, **kwargs):
111111
kwargs["convert_var_to_const"] = False
112112
kwargs["constant_fold"] = False
113113
return self.run_test_case(feed_dict, [], output_names_with_port, **kwargs)
114-
'''
114+
115115
def _test_expand_dims_known_rank(self, idx):
116116
tf.reset_default_graph()
117117
x_val = make_xval([3, 4])
@@ -2902,23 +2902,68 @@ def test_unique(self):
29022902
# FIXME: indices in onnx are not the same as in tensorflow so don't check for now
29032903
# self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val})
29042904
self._run_test_case([_OUTPUT], {_INPUT: x_val})
2905-
'''
29062905

2906+
@check_opset_min_version(10, "Conv2DBackpropInput")
29072907
def test_Conv2DBackpropInput_const(self):
29082908
input_sizes_val = np.array([1, 10, 10, 3], dtype=np.int32)
29092909
filter_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
29102910
out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
2911-
_ = tf.nn.conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val, out_backprop=out_backprop_val, strides=[1,1,1,1], padding='SAME', name=_TFOUTPUT)
2911+
_ = tf.nn.conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val, out_backprop=out_backprop_val,
2912+
strides=[1, 1, 1, 1], padding='SAME', name=_TFOUTPUT)
2913+
self._run_test_case([_OUTPUT], {})
2914+
2915+
@check_opset_min_version(10, "Conv2DBackpropInput")
2916+
def test_Conv2DBackpropInput_const_strided(self):
2917+
input_sizes_val = np.array([1, 10, 10, 3], dtype=np.int32)
2918+
filter_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
2919+
out_backprop_val = np.random.randint(low=0, high=256, size=[1, 5, 5, 5]).astype(np.float32)
2920+
_ = tf.nn.conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val, out_backprop=out_backprop_val,
2921+
strides=[1, 2, 2, 1], padding='SAME', name=_TFOUTPUT)
2922+
self._run_test_case([_OUTPUT], {})
2923+
2924+
@check_opset_min_version(10, "Conv2DBackpropInput")
2925+
def test_Conv2DBackpropInput_const_valid(self):
2926+
input_sizes_val = np.array([1, 12, 12, 3], dtype=np.int32)
2927+
filter_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
2928+
out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
2929+
_ = tf.nn.conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val, out_backprop=out_backprop_val,
2930+
strides=[1, 1, 1, 1], padding='VALID', name=_TFOUTPUT)
29122931
self._run_test_case([_OUTPUT], {})
29132932

2933+
@check_opset_min_version(10, "Conv2DBackpropInput")
29142934
def test_Conv2DBackpropInput(self):
29152935
input_sizes_val = np.array([1, 10, 10, 3], dtype=np.int32)
29162936
input_sizes = tf.placeholder(tf.int32, input_sizes_val.shape, name=_TFINPUT)
29172937
filter_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
29182938
filter = tf.placeholder(tf.float32, filter_val.shape, name=_TFINPUT1)
29192939
out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
29202940
out_backprop = tf.placeholder(tf.float32, out_backprop_val.shape, name=_TFINPUT2)
2921-
_ = tf.nn.conv2d_backprop_input(input_sizes, filter, out_backprop, strides=[1,1,1,1], padding='SAME', name=_TFOUTPUT)
2941+
_ = tf.nn.conv2d_backprop_input(input_sizes, filter, out_backprop, strides=[1, 1, 1, 1], padding='SAME',
2942+
name=_TFOUTPUT)
2943+
self._run_test_case([_OUTPUT], {_INPUT: input_sizes_val, _INPUT1: filter_val, _INPUT2: out_backprop_val})
2944+
2945+
@check_opset_min_version(10, "Conv2DBackpropInput")
2946+
def test_Conv2DBackpropInput_strided(self):
2947+
input_sizes_val = np.array([1, 10, 10, 3], dtype=np.int32)
2948+
input_sizes = tf.placeholder(tf.int32, input_sizes_val.shape, name=_TFINPUT)
2949+
filter_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
2950+
filter = tf.placeholder(tf.float32, filter_val.shape, name=_TFINPUT1)
2951+
out_backprop_val = np.random.randint(low=0, high=256, size=[1, 5, 5, 5]).astype(np.float32)
2952+
out_backprop = tf.placeholder(tf.float32, out_backprop_val.shape, name=_TFINPUT2)
2953+
_ = tf.nn.conv2d_backprop_input(input_sizes, filter, out_backprop, strides=[1, 2, 2, 1], padding='SAME',
2954+
name=_TFOUTPUT)
2955+
self._run_test_case([_OUTPUT], {_INPUT: input_sizes_val, _INPUT1: filter_val, _INPUT2: out_backprop_val})
2956+
2957+
@check_opset_min_version(10, "Conv2DBackpropInput")
2958+
def test_Conv2DBackpropInput_valid(self):
2959+
input_sizes_val = np.array([1, 12, 12, 3], dtype=np.int32)
2960+
input_sizes = tf.placeholder(tf.int32, input_sizes_val.shape, name=_TFINPUT)
2961+
filter_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
2962+
filter = tf.placeholder(tf.float32, filter_val.shape, name=_TFINPUT1)
2963+
out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
2964+
out_backprop = tf.placeholder(tf.float32, out_backprop_val.shape, name=_TFINPUT2)
2965+
_ = tf.nn.conv2d_backprop_input(input_sizes, filter, out_backprop, strides=[1, 1, 1, 1], padding='VALID',
2966+
name=_TFOUTPUT)
29222967
self._run_test_case([_OUTPUT], {_INPUT: input_sizes_val, _INPUT1: filter_val, _INPUT2: out_backprop_val})
29232968

29242969

tf2onnx/onnx_opset/nn.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -244,21 +244,27 @@ def version_1(cls, ctx, node, **kwargs):
244244
else:
245245
input_shape = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.INT64})
246246
output_shape = ctx.make_node("Shape", [node.output[0]])
247-
output_h = GraphBuilder(ctx).make_slice({"data": output_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
248-
output_w = GraphBuilder(ctx).make_slice({"data": output_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
249-
expect_h = GraphBuilder(ctx).make_slice({"data": input_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
250-
expect_w = GraphBuilder(ctx).make_slice({"data": input_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
247+
output_h = GraphBuilder(ctx).make_slice(
248+
{"data": output_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
249+
output_w = GraphBuilder(ctx).make_slice(
250+
{"data": output_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
251+
expect_h = GraphBuilder(ctx).make_slice(
252+
{"data": input_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
253+
expect_w = GraphBuilder(ctx).make_slice(
254+
{"data": input_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
251255
diff_h = ctx.make_node("Sub", [output_h, expect_h])
252256
diff_w = ctx.make_node("Sub", [output_w, expect_w])
253257
const_two = ctx.make_const(utils.make_name(node.name + "_const_two"), np.array([2], dtype=np.int64))
254258
start_h = ctx.make_node("Div", [diff_h.output[0], const_two.output[0]])
255259
start_w = ctx.make_node("Div", [diff_w.output[0], const_two.output[0]])
256260
end_h = ctx.make_node("Add", [start_h.output[0], expect_h])
257261
end_w = ctx.make_node("Add", [start_w.output[0], expect_w])
258-
starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0]], attr={"axis":0})
259-
ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0]], attr={"axis":0}, name="concat_efgh")
260-
const_one_two = ctx.make_const(utils.make_name(node.name + "_const_one_two"), np.array([1,2], dtype=np.int64))
261-
slice_node = ctx.make_node("Slice", [node.output[0], starts.output[0], ends.output[0], const_one_two.output[0]])
262+
starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0]], attr={"axis": 0})
263+
ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0]], attr={"axis": 0}, name="concat_efgh")
264+
const_one_two = ctx.make_const(utils.make_name(node.name + "_const_one_two"),
265+
np.array([1, 2], dtype=np.int64))
266+
slice_node = ctx.make_node("Slice",
267+
[node.output[0], starts.output[0], ends.output[0], const_one_two.output[0]])
262268
downstream_nodes = ctx.find_output_consumers(node.output[0])
263269
downstream_nodes.remove(output_shape)
264270
downstream_nodes.remove(slice_node)

0 commit comments

Comments (0)