Commit 55e6d96

Merge pull request #1041 from onnx/tom/Conv3DBackpropInputV2

Added support for Conv3DBackpropInputV2

2 parents: 54ad341 + 1da0e41
3 files changed: +128 additions, −36 deletions

tests/test_backend.py (81 additions, 19 deletions)
```diff
@@ -54,6 +54,7 @@

 if is_tf2():
     conv2d_backprop_input = tf.compat.v1.nn.conv2d_backprop_input
+    conv3d_transpose = tf.compat.v1.nn.conv3d_transpose
     multinomial = tf.compat.v1.random.multinomial
     space_to_batch_nd = tf.compat.v1.space_to_batch_nd
     batch_to_space_nd = tf.compat.v1.batch_to_space_nd
@@ -73,6 +74,7 @@
     fake_quant_with_min_max_args = tf.quantization.fake_quant_with_min_max_args
 elif LooseVersion(tf.__version__) >= "1.13":
     conv2d_backprop_input = tf.compat.v1.nn.conv2d_backprop_input
+    conv3d_transpose = tf.compat.v1.nn.conv3d_transpose
     multinomial = tf.compat.v1.random.multinomial
     space_to_batch_nd = tf.compat.v1.space_to_batch_nd
     batch_to_space_nd = tf.compat.v1.batch_to_space_nd
@@ -93,6 +95,7 @@
     fake_quant_with_min_max_args = tf.compat.v1.quantization.fake_quant_with_min_max_args
 else:
     conv2d_backprop_input = tf.nn.conv2d_backprop_input
+    conv3d_transpose = tf.nn.conv3d_transpose
     multinomial = tf.multinomial
     space_to_batch_nd = tf.space_to_batch_nd
     batch_to_space_nd = tf.batch_to_space_nd
```
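Note: `tf.nn.conv3d_transpose` is the public wrapper that lowers to the `Conv3DBackpropInputV2` op, which is why the tests alias it here. A minimal sketch of the version gating, not part of the diff, assuming only that TF 1.13+ ships the `tf.compat.v1` namespace:

```python
# Sketch of the alias selection used by the tests (hypothetical standalone form).
import tensorflow as tf

if hasattr(tf, "compat") and hasattr(tf.compat, "v1"):
    # TF 2.x and TF >= 1.13: the v1 API carries the transpose wrapper.
    conv3d_transpose = tf.compat.v1.nn.conv3d_transpose
else:
    # Older TF 1.x exposed it at the top level.
    conv3d_transpose = tf.nn.conv3d_transpose
```

The test-body changes follow.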
```diff
@@ -3136,45 +3139,38 @@ def func(x):
     @check_opset_min_version(10, "Conv2DBackpropInput")
     def test_Conv2DBackpropInput_const(self):
         input_sizes_val_ = np.array([1, 10, 10, 3], dtype=np.int32)
-        filter_val_ = np.random.randint(low=0, high=256, size=[3, 3, 3, 5])
-        out_backprop_val_ = np.random.randint(low=0, high=256, size=[1, 10, 10, 5])
-        def func():
+        def func(filter_val, out_backprop_val):
             input_sizes_val = tf.constant(input_sizes_val_, dtype=tf.int32)
-            filter_val = tf.constant(filter_val_, dtype=tf.float32)
-            out_backprop_val = tf.constant(out_backprop_val_, dtype=tf.float32)
             return conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val,
                                          out_backprop=out_backprop_val, strides=[1, 1, 1, 1],
                                          padding='SAME', name=_TFOUTPUT)
-        self._run_test_case(func, [_OUTPUT], {})
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
+        out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: filters_val, _INPUT1: out_backprop_val})

     @check_opset_min_version(10, "Conv2DBackpropInput")
     def test_Conv2DBackpropInput_const_strided(self):
         input_sizes_val_ = np.array([1, 10, 10, 3], dtype=np.int32)
-        filter_val_ = np.random.randint(low=0, high=256, size=[3, 3, 3, 5])
-        out_backprop_val_ = np.random.randint(low=0, high=256, size=[1, 5, 5, 5])
-
-        def func():
+        def func(filter_val, out_backprop_val):
             input_sizes_val = tf.constant(input_sizes_val_, dtype=tf.int32)
-            filter_val = tf.constant(filter_val_, dtype=tf.float32)
-            out_backprop_val = tf.constant(out_backprop_val_, dtype=tf.float32)
             return conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val,
                                          out_backprop=out_backprop_val, strides=[1, 2, 2, 1],
                                          padding='SAME', name=_TFOUTPUT)
-        self._run_test_case(func, [_OUTPUT], {})
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
+        out_backprop_val = np.random.randint(low=0, high=256, size=[1, 5, 5, 5]).astype(np.float32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: filters_val, _INPUT1: out_backprop_val})

     @check_opset_min_version(10, "Conv2DBackpropInput")
     def test_Conv2DBackpropInput_const_valid(self):
         input_sizes_val_ = np.array([1, 12, 12, 3], dtype=np.int32)
-        filter_val_ = np.random.randint(low=0, high=256, size=[3, 3, 3, 5])
-        out_backprop_val_ = np.random.randint(low=0, high=256, size=[1, 10, 10, 5])
-        def func():
+        def func(filter_val, out_backprop_val):
             input_sizes_val = tf.constant(input_sizes_val_, dtype=tf.int32)
-            filter_val = tf.constant(filter_val_, dtype=tf.float32)
-            out_backprop_val = tf.constant(out_backprop_val_, dtype=tf.float32)
             return conv2d_backprop_input(input_sizes=input_sizes_val, filter=filter_val,
                                          out_backprop=out_backprop_val, strides=[1, 1, 1, 1],
                                          padding='VALID', name=_TFOUTPUT)
-        self._run_test_case(func, [_OUTPUT], {})
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 5]).astype(np.float32)
+        out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: filters_val, _INPUT1: out_backprop_val})

     @check_opset_min_version(10, "Conv2DBackpropInput")
     def test_Conv2DBackpropInput(self):
@@ -3206,6 +3202,72 @@ def func(input_sizes, filters, out_backprop):
         out_backprop_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 5]).astype(np.float32)
         self._run_test_case(func, [_OUTPUT], {_INPUT: input_sizes_val, _INPUT1: filters_val, _INPUT2: out_backprop_val})

+    @check_opset_min_version(10, "Conv3DBackpropInputV2")
+    def test_Conv3DBackpropInputV2_const(self):
+        output_shape_val_ = np.array([1, 10, 10, 10, 3], dtype=np.int32)
+        def func(value, filters):
+            output_shape_val = tf.constant(output_shape_val_, dtype=tf.int32)
+            return conv3d_transpose(value, filters, output_shape_val, strides=[1, 1, 1, 1, 1],
+                                    padding='SAME', data_format="NDHWC", name=_TFOUTPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 3, 5]).astype(np.float32)
+        value_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 10, 5]).astype(np.float32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: value_val, _INPUT1: filters_val}, rtol=1e-6)
+
+    @check_opset_min_version(10, "Conv3DBackpropInputV2")
+    def test_Conv3DBackpropInputV2_const_strided(self):
+        output_shape_val_ = np.array([1, 10, 10, 10, 3], dtype=np.int32)
+        def func(value, filters):
+            output_shape_val = tf.constant(output_shape_val_, dtype=tf.int32)
+            return conv3d_transpose(value, filters, output_shape_val, strides=[1, 2, 2, 2, 1],
+                                    padding='SAME', data_format="NDHWC", name=_TFOUTPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 3, 5]).astype(np.float32)
+        value_val = np.random.randint(low=0, high=256, size=[1, 5, 5, 5, 5]).astype(np.float32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: value_val, _INPUT1: filters_val}, rtol=1e-6)
+
+    @check_opset_min_version(10, "Conv3DBackpropInputV2")
+    def test_Conv3DBackpropInputV2_const_valid(self):
+        output_shape_val_ = np.array([1, 12, 12, 12, 3], dtype=np.int32)
+        def func(value, filters):
+            output_shape_val = tf.constant(output_shape_val_, dtype=tf.int32)
+            return conv3d_transpose(value, filters, output_shape_val, strides=[1, 1, 1, 1, 1],
+                                    padding='VALID', data_format="NDHWC", name=_TFOUTPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 3, 5]).astype(np.float32)
+        value_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 10, 5]).astype(np.float32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: value_val, _INPUT1: filters_val}, rtol=1e-6)
+
+    @check_opset_min_version(10, "Conv3DBackpropInputV2")
+    def test_Conv3DBackpropInputV2(self):
+        def func(value, filters, output_shape):
+            return conv3d_transpose(value, filters, output_shape, strides=[1, 1, 1, 1, 1],
+                                    padding='SAME', data_format="NDHWC", name=_TFOUTPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[2, 3, 4, 4, 5]).astype(np.float32)
+        value_val = np.random.randint(low=0, high=256, size=[2, 7, 8, 9, 5]).astype(np.float32)
+        output_shape_val = np.array([2, 7, 8, 9, 4], dtype=np.int32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: value_val, _INPUT1: filters_val, _INPUT2: output_shape_val},
+                            rtol=1e-6)
+
+    @check_opset_min_version(10, "Conv3DBackpropInputV2")
+    def test_Conv3DBackpropInputV2_strided(self):
+        def func(value, filters, output_shape):
+            return conv3d_transpose(value, filters, output_shape, strides=[1, 2, 2, 2, 1],
+                                    padding='SAME', data_format="NDHWC", name=_TFOUTPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 3, 5]).astype(np.float32)
+        value_val = np.random.randint(low=0, high=256, size=[1, 5, 5, 5, 5]).astype(np.float32)
+        output_shape_val = np.array([1, 10, 10, 10, 3], dtype=np.int32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: value_val, _INPUT1: filters_val, _INPUT2: output_shape_val},
+                            rtol=1e-6)
+
+    @check_opset_min_version(10, "Conv3DBackpropInputV2")
+    def test_Conv3DBackpropInputV2_valid(self):
+        def func(value, filters, output_shape):
+            return conv3d_transpose(value, filters, output_shape, strides=[1, 1, 1, 1, 1],
+                                    padding='VALID', data_format="NDHWC", name=_TFOUTPUT)
+        filters_val = np.random.randint(low=0, high=256, size=[3, 3, 3, 3, 5]).astype(np.float32)
+        value_val = np.random.randint(low=0, high=256, size=[1, 10, 10, 10, 5]).astype(np.float32)
+        output_shape_val = np.array([1, 12, 12, 12, 3], dtype=np.int32)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: value_val, _INPUT1: filters_val, _INPUT2: output_shape_val},
+                            rtol=1e-6)
+
     @check_opset_min_version(8, "CategoryMapper")
     @skip_tf2()
     def test_hashtable_lookup(self):
```
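The fixture shapes in the new tests follow the standard transposed-convolution size relations; a quick sanity check of the constant-shape cases (a sketch, not part of the diff):

```python
# Output size per spatial axis of a transposed convolution:
#   SAME padding:  out = in * stride
#   VALID padding: out = (in - 1) * stride + kernel
def deconv_out_dim(in_dim, kernel, stride, padding):
    if padding == "SAME":
        return in_dim * stride
    return (in_dim - 1) * stride + kernel

assert deconv_out_dim(10, 3, 1, "SAME") == 10   # ..._const: 10^3 input -> 10^3 output
assert deconv_out_dim(5, 3, 2, "SAME") == 10    # ..._const_strided: 5^3 -> 10^3
assert deconv_out_dim(10, 3, 1, "VALID") == 12  # ..._const_valid: 10^3 -> 12^3
```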

tf2onnx/graph.py (2 additions, 0 deletions)
```diff
@@ -146,6 +146,8 @@ def data_format(self, val):

     def is_nhwc(self):
         """Return True if node is in NHWC format."""
+        utils.make_sure('D' not in self.data_format, "is_nhwc called on %s with spatial=2 but data_format=%s",
+                        self.name, self.data_format)
         return self.data_format == "NHWC"

     def is_const(self):
```
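The added guard makes any stale 2D-only call site fail loudly when it hits a 3D node, instead of silently treating NDHWC as channels-first. A sketch of the effect, assuming `utils.make_sure` raises a `ValueError` with the %-formatted message when its condition is false:

```python
# Assumed behavior of utils.make_sure (hypothetical reimplementation).
def make_sure(cond, msg, *args):
    if not cond:
        raise ValueError(msg % args)

make_sure('D' not in "NHWC", "unused %s %s", "n", "NHWC")  # passes: 2D layout
make_sure('D' not in "NDHWC",
          "is_nhwc called on %s with spatial=2 but data_format=%s",
          "conv3d_node", "NDHWC")  # raises ValueError for a 3D node
```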

tf2onnx/onnx_opset/nn.py (45 additions, 17 deletions)
```diff
@@ -362,8 +362,7 @@ def version_11(cls, ctx, node, **kwargs):
         # No change.
         cls.version_1(ctx, node, **kwargs)

-
-@tf_op("Conv2DBackpropInput")
+@tf_op(["Conv2DBackpropInput", "Conv3DBackpropInputV2"])
 class ConvTranspose:
     @classmethod
     def version_1(cls, ctx, node, **kwargs):
@@ -372,24 +371,36 @@ def version_1(cls, ctx, node, **kwargs):
         # T Y = ConvTranspose(T X, T W, T B, @STRING auto_pad, @INTS dilations,
         #    @INT group, @INTS kernel_shape, @INTS output_shape, @INTS pads, @INTS strides)

+        if node.type == "Conv3DBackpropInputV2":
+            spatial = 3
+        else:
+            spatial = 2
         node.type = "ConvTranspose"
         # Note: inputs are reversed from what one would expect.
-        conv_kernel_shape(ctx, node, 1)
+        conv_kernel_shape(ctx, node, 1, spatial=spatial)
         input_shape = ctx.get_shape(node.input[2])
         output_shape_orig = node.output_shapes

         # ouput_shape is explicitly specified here, in this case pads values are auto generated/calculated.
         if node.inputs[0].is_const():
             output_shape = ctx.get_shape(node.output[0])
-            if node.is_nhwc():
+            if is_channels_last(node):
                 new_output_shape = [output_shape[1], output_shape[2]]
-                input_hw = [input_shape[1], input_shape[2]]
+                input_dims = [input_shape[1], input_shape[2]]
+                if spatial == 3:
+                    new_output_shape.append(output_shape[3])
+                    input_dims.append(input_shape[3])
             else:
                 new_output_shape = [output_shape[2], output_shape[3]]
-                input_hw = [input_shape[2], input_shape[3]]
-            utils.make_sure(new_output_shape.count(-1) <= 0, "output h and w need to be known")
-            utils.make_sure(new_output_shape[0] >= input_hw[0] and new_output_shape[1] >= input_hw[1],
-                            "output h and w cannot be smaller than input h and w.")
+                input_dims = [input_shape[2], input_shape[3]]
+                if spatial == 3:
+                    new_output_shape.append(output_shape[4])
+                    input_dims.append(input_shape[4])
+
+            utils.make_sure(new_output_shape.count(-1) <= 0, "output dims need to be known")
+            utils.make_sure(all(new_output_shape[i] >= input_dims[i] for i in range(spatial)),
+                            "output dims cannot be smaller than input dims.")
+
             node.set_attr("output_shape", new_output_shape)
         else:
             input_shape = ctx.make_node("Cast", [node.input[0]], attr={'to': TensorProto.INT64})
```
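In the constant branch, channels-last layouts (NHWC/NDHWC) keep their spatial dims starting at index 1 and channels-first at index 2, so extending to 3D just appends one more index. A small sketch with hypothetical values:

```python
# Picking spatial dims out of a full NDHWC shape (illustrative only).
output_shape = [1, 10, 10, 10, 3]               # N, D, H, W, C
spatial = 3
new_output_shape = output_shape[1:1 + spatial]  # channels-last: [10, 10, 10]
# a channels-first (NCDHW) shape would use output_shape[2:2 + spatial]
assert new_output_shape == [10, 10, 10]
```

The dynamic-shape branch continues below.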
```diff
@@ -409,20 +420,37 @@ def version_1(cls, ctx, node, **kwargs):
             start_w = ctx.make_node("Div", [diff_w.output[0], const_two.output[0]])
             end_h = ctx.make_node("Add", [start_h.output[0], expect_h])
             end_w = ctx.make_node("Add", [start_w.output[0], expect_w])
-            starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0]], attr={"axis": 0})
-            ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0]], attr={"axis": 0})
-            const_one_two = ctx.make_const(utils.make_name(node.name + "_const_one_two"),
-                                           np.array([1, 2], dtype=np.int64))
+            if spatial == 3:
+                output_d = GraphBuilder(ctx).make_slice(
+                    {"data": output_shape.output[0], "ends": [4], "starts": [3], "axes": [0]})
+                expect_d = GraphBuilder(ctx).make_slice(
+                    {"data": input_shape.output[0], "ends": [4], "starts": [3], "axes": [0]})
+                diff_d = ctx.make_node("Sub", [output_d, expect_d])
+                start_d = ctx.make_node("Div", [diff_d.output[0], const_two.output[0]])
+                end_d = ctx.make_node("Add", [start_d.output[0], expect_d])
+
+                starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0], start_d.output[0]],
+                                       attr={"axis": 0})
+                ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0], end_d.output[0]], attr={"axis": 0})
+                slice_axes = ctx.make_const(utils.make_name(node.name + "_const_slice_axes"),
+                                            np.array([1, 2, 3], dtype=np.int64))
+            else:
+                starts = ctx.make_node("Concat", [start_h.output[0], start_w.output[0]], attr={"axis": 0})
+                ends = ctx.make_node("Concat", [end_h.output[0], end_w.output[0]], attr={"axis": 0})
+                slice_axes = ctx.make_const(utils.make_name(node.name + "_const_slice_axes"),
+                                            np.array([1, 2], dtype=np.int64))
+
             slice_node = ctx.make_node("Slice",
-                                       [node.output[0], starts.output[0], ends.output[0], const_one_two.output[0]],
+                                       [node.output[0], starts.output[0], ends.output[0], slice_axes.output[0]],
                                        shapes=output_shape_orig)
+
             downstream_nodes = ctx.find_output_consumers(node.output[0])
             downstream_nodes.remove(output_shape)
             downstream_nodes.remove(slice_node)
             ctx.replace_all_inputs(downstream_nodes, node.output[0], slice_node.output[0])

-        conv_dims_attr(node, "strides")
-        conv_dims_attr(node, "dilations")
+        conv_dims_attr(node, "strides", spatial=spatial)
+        conv_dims_attr(node, "dilations", spatial=spatial)

         # remove output_shapes input
         ctx.remove_input(node, node.input[0])
@@ -431,7 +459,7 @@ def version_1(cls, ctx, node, **kwargs):
         node.input[0] = node.input[1]
         node.input[1] = t

-        conv_convert_inputs(ctx, node, with_kernel=True)
+        conv_convert_inputs(ctx, node, with_kernel=True, spatial=spatial)

     @classmethod
     def version_11(cls, ctx, node, **kwargs):
```
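When the requested output shape is only known at runtime, the converter lets ConvTranspose produce its natural output and then crops a centered window of the requested size with a Slice over the spatial axes, now [1, 2, 3] in the 3D case. The arithmetic in numpy form (a sketch with hypothetical values, mirroring the Sub/Div/Add chain above):

```python
import numpy as np

produced = np.array([1, 12, 12, 12, 3])  # what ConvTranspose emitted (NDHWC)
wanted = np.array([1, 10, 10, 10, 3])    # requested output_shape

starts = (produced[1:4] - wanted[1:4]) // 2  # centered offset per spatial axis
ends = starts + wanted[1:4]
print(starts.tolist(), ends.tolist())        # [1, 1, 1] [11, 11, 11]
# -> Slice(data, starts, ends, axes=[1, 2, 3]) yields the requested dims
```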
