Commit b25b9d4

Merge pull request #749 from RandySheriffH/rashuai/ConvTransposeDynamicHW
ConvBackpropInput with dynamic hw
2 parents: ef5522d + d99a1b1
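In short, tf2onnx maps TF's Conv2DBackpropInput onto ONNX ConvTranspose. Previously the output height/width had to be statically known at conversion time (the old code asserts "output h and w need to be known"); this change keeps that path when input_sizes is a constant and, when input_sizes is computed at runtime, appends a subgraph that crops the ConvTranspose output to the requested height and width. A minimal sketch of that crop arithmetic, assuming the standard zero-padding ConvTranspose size formula out = stride * (in - 1) + kernel, with numbers chosen to match the strided tests below:

    # Hypothetical values: 5x5 out_backprop, 3x3 kernel, stride 2, requested 10x10.
    grad_hw, kernel, stride, requested = 5, 3, 2, 10
    actual = stride * (grad_hw - 1) + kernel    # raw ConvTranspose output H/W: 11
    start = (actual - requested) // 2           # centered crop start: 0
    end = start + requested                     # crop end: 10
    assert (actual, start, end) == (11, 0, 10)  # runtime Slice window on H and W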

File tree: 2 files changed (+104 / -12 lines)


tests/test_backend.py

Lines changed: 64 additions & 0 deletions
@@ -2901,5 +2901,69 @@ def test_unique(self):
         # self._run_test_case([_OUTPUT, _OUTPUT1], {_INPUT: x_val})
         self._run_test_case([_OUTPUT], {_INPUT: x_val})
 
+    @check_opset_min_version(10, "Conv2DBackpropInput")
+    def test_Conv2DBackpropInput_const(self):
+        input_sizes_val = np.array([1, 10, 10, 3], dtype=np.int32)
+        filter_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
+        out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
+        _ = tf.nn.conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val, out_backprop=out_backprop_val,
+                                        strides=[1, 1, 1, 1], padding='SAME', name=_TFOUTPUT)
+        self._run_test_case([_OUTPUT], {})
+
+    @check_opset_min_version(10, "Conv2DBackpropInput")
+    def test_Conv2DBackpropInput_const_strided(self):
+        input_sizes_val = np.array([1, 10, 10, 3], dtype=np.int32)
+        filter_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
+        out_backprop_val = np.random.randint(low=0, high=256, size=[1, 5, 5, 5]).astype(np.float32)
+        _ = tf.nn.conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val, out_backprop=out_backprop_val,
+                                        strides=[1, 2, 2, 1], padding='SAME', name=_TFOUTPUT)
+        self._run_test_case([_OUTPUT], {})
+
+    @check_opset_min_version(10, "Conv2DBackpropInput")
+    def test_Conv2DBackpropInput_const_valid(self):
+        input_sizes_val = np.array([1, 12, 12, 3], dtype=np.int32)
+        filter_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
+        out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
+        _ = tf.nn.conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val, out_backprop=out_backprop_val,
+                                        strides=[1, 1, 1, 1], padding='VALID', name=_TFOUTPUT)
+        self._run_test_case([_OUTPUT], {})
+
+    @check_opset_min_version(10, "Conv2DBackpropInput")
+    def test_Conv2DBackpropInput(self):
+        input_sizes_val = np.array([1, 10, 10, 3], dtype=np.int32)
+        input_sizes = tf.placeholder(tf.int32, input_sizes_val.shape, name=_TFINPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
+        filters = tf.placeholder(tf.float32, filters_val.shape, name=_TFINPUT1)
+        out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
+        out_backprop = tf.placeholder(tf.float32, out_backprop_val.shape, name=_TFINPUT2)
+        _ = tf.nn.conv2d_backprop_input(input_sizes, filters, out_backprop, strides=[1, 1, 1, 1], padding='SAME',
+                                        name=_TFOUTPUT)
+        self._run_test_case([_OUTPUT], {_INPUT: input_sizes_val, _INPUT1: filters_val, _INPUT2: out_backprop_val})
+
+    @check_opset_min_version(10, "Conv2DBackpropInput")
+    def test_Conv2DBackpropInput_strided(self):
+        input_sizes_val = np.array([1, 10, 10, 3], dtype=np.int32)
+        input_sizes = tf.placeholder(tf.int32, input_sizes_val.shape, name=_TFINPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
+        filters = tf.placeholder(tf.float32, filters_val.shape, name=_TFINPUT1)
+        out_backprop_val = np.random.randint(low=0, high=256, size=[1, 5, 5, 5]).astype(np.float32)
+        out_backprop = tf.placeholder(tf.float32, out_backprop_val.shape, name=_TFINPUT2)
+        _ = tf.nn.conv2d_backprop_input(input_sizes, filters, out_backprop, strides=[1, 2, 2, 1], padding='SAME',
+                                        name=_TFOUTPUT)
+        self._run_test_case([_OUTPUT], {_INPUT: input_sizes_val, _INPUT1: filters_val, _INPUT2: out_backprop_val})
+
+    @check_opset_min_version(10, "Conv2DBackpropInput")
+    def test_Conv2DBackpropInput_valid(self):
+        input_sizes_val = np.array([1, 12, 12, 3], dtype=np.int32)
+        input_sizes = tf.placeholder(tf.int32, input_sizes_val.shape, name=_TFINPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
+        filters = tf.placeholder(tf.float32, filters_val.shape, name=_TFINPUT1)
+        out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
+        out_backprop = tf.placeholder(tf.float32, out_backprop_val.shape, name=_TFINPUT2)
+        _ = tf.nn.conv2d_backprop_input(input_sizes, filters, out_backprop, strides=[1, 1, 1, 1], padding='VALID',
+                                        name=_TFOUTPUT)
+        self._run_test_case([_OUTPUT], {_INPUT: input_sizes_val, _INPUT1: filters_val, _INPUT2: out_backprop_val})
+
+
 if __name__ == '__main__':
     unittest_main()
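The shapes in these tests follow TensorFlow's padding rules, which is why each out_backprop pairs with its input_sizes: with SAME padding the forward output (here, the gradient's spatial size) is ceil(in / stride), and with VALID it is ceil((in - kernel + 1) / stride). A quick standalone check of the pairings above (helper names are mine):

    import math

    def same_out(in_size, stride):
        # TF 'SAME' padding: spatial output size is ceil(in / stride)
        return math.ceil(in_size / stride)

    def valid_out(in_size, kernel, stride=1):
        # TF 'VALID' padding: spatial output size is ceil((in - kernel + 1) / stride)
        return math.ceil((in_size - kernel + 1) / stride)

    assert same_out(10, 1) == 10   # SAME tests: 10x10 input <-> 10x10 out_backprop
    assert same_out(10, 2) == 5    # strided tests: 10x10 input <-> 5x5 out_backprop
    assert valid_out(12, 3) == 10  # VALID tests: 12x12 input <-> 10x10 out_backprop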

tf2onnx/onnx_opset/nn.py

Lines changed: 40 additions & 12 deletions
@@ -226,21 +226,49 @@ def version_1(cls, ctx, node, **kwargs):
         # Note: inputs are reversed from what one would expect.
         kernel_shape = conv_kernel_shape(ctx, node, 1)
         input_shape = ctx.get_shape(node.input[2])
+        append_slice = False
 
         # output_shape is explicitly specified here, in this case pads values are auto generated/calculated.
-        output_shape = ctx.get_shape(node.output[0])
-        if node.is_nhwc():
-            new_output_shape = [output_shape[1], output_shape[2]]
-            input_hw = [input_shape[1], input_shape[2]]
+        if node.inputs[0].is_const():
+            output_shape = ctx.get_shape(node.output[0])
+            if node.is_nhwc():
+                new_output_shape = [output_shape[1], output_shape[2]]
+                input_hw = [input_shape[1], input_shape[2]]
+            else:
+                new_output_shape = [output_shape[2], output_shape[3]]
+                input_hw = [input_shape[2], input_shape[3]]
+            utils.make_sure(new_output_shape.count(-1) <= 0, "output h and w need to be known")
+            utils.make_sure(new_output_shape[0] >= input_hw[0] and new_output_shape[1] >= input_hw[1],
+                            "output h and w cannot be smaller than input h and w.")
+            node.set_attr("output_shape", new_output_shape)
         else:
-            new_output_shape = [output_shape[2], output_shape[3]]
-            input_hw = [input_shape[2], input_shape[3]]
-
-        utils.make_sure(new_output_shape.count(-1) <= 0, "output h and w need to be known")
-        utils.make_sure(new_output_shape[0] >= input_hw[0] and new_output_shape[1] >= input_hw[1],
-                        "output h and w cannot be smaller than input h and w.")
-
-        node.set_attr("output_shape", new_output_shape)
+            input_shape = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.INT64})
+            output_shape = ctx.make_node("Shape", [node.output[0]])
+            output_h = GraphBuilder(ctx).make_slice(
+                {"data": output_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
+            output_w = GraphBuilder(ctx).make_slice(
+                {"data": output_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
+            expect_h = GraphBuilder(ctx).make_slice(
+                {"data": input_shape.output[0], "ends": [2], "starts": [1], "axes": [0]})
+            expect_w = GraphBuilder(ctx).make_slice(
+                {"data": input_shape.output[0], "ends": [3], "starts": [2], "axes": [0]})
+            diff_h = ctx.make_node("Sub", [output_h, expect_h])
+            diff_w = ctx.make_node("Sub", [output_w, expect_w])
+            const_two = ctx.make_const(utils.make_name(node.name + "_const_two"), np.array([2], dtype=np.int64))
+            start_h = ctx.make_node("Div", [diff_h.output[0], const_two.output[0]])
+            start_w = ctx.make_node("Div", [diff_w.output[0], const_two.output[0]])
+            end_h = ctx.make_node("Add", [start_h.output[0], expect_h])
+            end_w = ctx.make_node("Add", [start_w.output[0], expect_w])
+            starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0]], attr={"axis": 0})
+            ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0]], attr={"axis": 0})
+            const_one_two = ctx.make_const(utils.make_name(node.name + "_const_one_two"),
+                                           np.array([1, 2], dtype=np.int64))
+            slice_node = ctx.make_node("Slice",
+                                       [node.output[0], starts.output[0], ends.output[0], const_one_two.output[0]])
+            downstream_nodes = ctx.find_output_consumers(node.output[0])
+            downstream_nodes.remove(output_shape)
+            downstream_nodes.remove(slice_node)
+            ctx.replace_all_inputs(downstream_nodes, node.output[0], slice_node.output[0])
 
         strides = conv_dims_attr(node, "strides")
         conv_dims_attr(node, "dilations")
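The dynamic branch above wires up Cast/Shape/Sub/Div/Add/Concat/Slice nodes, but the computation they encode is simply a centered crop of the ConvTranspose output down to the runtime-requested input_sizes. A plain-NumPy sketch of the same arithmetic (NHWC layout assumed, matching the [1:3) index ranges used in the shape slices above; the function name is mine):

    import numpy as np

    def crop_to_requested_hw(conv_out, input_sizes):
        # Mimics the inserted subgraph: take H and W from both shapes,
        # then slice axes 1 and 2 down to the requested sizes, centered.
        actual = np.array(conv_out.shape[1:3], dtype=np.int64)    # Shape + slices
        expected = input_sizes.astype(np.int64)[1:3]              # Cast + slices
        start = (actual - expected) // 2                          # Sub, Div
        end = start + expected                                    # Add
        return conv_out[:, start[0]:end[0], start[1]:end[1], :]   # Slice

    # An 11x11 raw ConvTranspose output cropped to the requested 10x10
    # (the shapes the strided test would produce at runtime).
    y = crop_to_requested_hw(np.zeros((1, 11, 11, 5)), np.array([1, 10, 10, 3]))
    assert y.shape == (1, 10, 10, 5)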
