Skip to content

Commit 097e28c

Browse files
Merge pull request #1020 from onnx/tom/MaxPool3D
Implemented MaxPool3D and AvgPool3D
2 parents a8b02aa + f34aa95 commit 097e28c

File tree

3 files changed

+94
-17
lines changed

3 files changed

+94
-17
lines changed

tests/test_backend.py

Lines changed: 71 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -395,6 +395,77 @@ def test_conv2d_6(self):
395395
kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
396396
self._conv_test(x_val, kernel_val, strides=strides, padding="VALID", rtol=1e-05)
397397

398+
def test_conv3d_1(self):
# Exercises tf.nn.conv3d with unit strides and unit dilations, "VALID"
# padding and data_format="NDHWC". The input is a random float32 array of
# shape [2, 10, 9, 8, 5]; the kernel is a random float32 array of shape
# [2, 3, 4, 5, 6] embedded in the graph as a constant.
# The graph function, output names and feed dict are handed to
# self._run_test_case with rtol=1e-05 (presumably converts the graph and
# compares results against TensorFlow -- helper not visible here).
399+
strides = [1, 1, 1, 1, 1]
400+
dilations = [1, 1, 1, 1, 1]
401+
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
402+
w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32)
403+
padding = "VALID"
404+
def func(x):
405+
kernel = tf.constant(w, dtype=tf.float32, name='k')
406+
conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations)
407+
return tf.identity(conv, name=_TFOUTPUT)
408+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)
409+
410+
def test_conv3d_2(self):
# Same setup as a plain "VALID" conv3d test, but with non-unit strides:
# strides[1] == 2 and strides[2] == 3 (the first two spatial axes, given
# the data_format="NDHWC" passed to tf.nn.conv3d). Dilations stay at 1.
# Input is random float32 of shape [2, 10, 9, 8, 5]; kernel is a constant
# random float32 of shape [2, 3, 4, 5, 6]. Results are checked through
# self._run_test_case with rtol=1e-05.
411+
strides = [1, 2, 3, 1, 1]
412+
dilations = [1, 1, 1, 1, 1]
413+
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
414+
w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32)
415+
padding = "VALID"
416+
def func(x):
417+
kernel = tf.constant(w, dtype=tf.float32, name='k')
418+
conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations)
419+
return tf.identity(conv, name=_TFOUTPUT)
420+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)
421+
422+
def test_conv3d_3(self):
# Strided tf.nn.conv3d (strides [1, 2, 3, 1, 1], unit dilations) with
# "SAME" padding instead of "VALID", so the padding path is exercised
# together with non-unit strides. Input is random float32 of shape
# [2, 10, 9, 8, 5]; kernel is a constant random float32 of shape
# [2, 3, 4, 5, 6]; data_format is "NDHWC". Results are checked through
# self._run_test_case with rtol=1e-05.
423+
strides = [1, 2, 3, 1, 1]
424+
dilations = [1, 1, 1, 1, 1]
425+
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
426+
w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32)
427+
padding = "SAME"
428+
def func(x):
429+
kernel = tf.constant(w, dtype=tf.float32, name='k')
430+
conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations)
431+
return tf.identity(conv, name=_TFOUTPUT)
432+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)
433+
434+
def test_avgpool3d(self):
# Exercises tf.nn.avg_pool3d with ksize [1, 2, 2, 3, 1], unit strides,
# "VALID" padding and data_format="NDHWC" on a random float32 input of
# shape [2, 10, 9, 8, 5]. The graph is handed to self._run_test_case with
# that helper's default tolerances (no rtol is passed here).
435+
strides = [1, 1, 1, 1, 1]
436+
ksize = [1, 2, 2, 3, 1]
437+
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
438+
padding = "VALID"
439+
440+
def func(x):
441+
# NOTE: the local is called `mp` although it holds the average-pool result.
mp = tf.nn.avg_pool3d(x, ksize, strides, padding=padding, data_format="NDHWC")
442+
return tf.identity(mp, name=_TFOUTPUT)
443+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
444+
445+
def test_maxpool3d(self):
# Exercises tf.nn.max_pool3d with ksize [1, 2, 2, 3, 1], unit strides,
# "VALID" padding and data_format="NDHWC" on a random float32 input of
# shape [2, 10, 9, 8, 5]. The graph is handed to self._run_test_case with
# that helper's default tolerances (no rtol is passed here).
446+
strides = [1, 1, 1, 1, 1]
447+
ksize = [1, 2, 2, 3, 1]
448+
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
449+
padding = "VALID"
450+
451+
def func(x):
452+
mp = tf.nn.max_pool3d(x, ksize, strides, padding=padding, data_format="NDHWC")
453+
return tf.identity(mp, name=_TFOUTPUT)
454+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
455+
456+
@check_tf_min_version("1.14", "tf.nn.avg_pool2d doesn't exist before tf 1.14")
457+
def test_avgpool2d(self):
# Exercises tf.nn.avg_pool2d with ksize [1, 2, 3, 1], unit strides,
# "VALID" padding and data_format="NHWC". The input comes from
# make_xval([2, 10, 12, 3]) (helper defined elsewhere; presumably builds
# an array of that shape). Gated by the check_tf_min_version decorator
# above because, per its message, tf.nn.avg_pool2d is unavailable before
# TF 1.14.
458+
strides = [1, 1, 1, 1]
459+
ksize = [1, 2, 3, 1]
460+
x_val = make_xval([2, 10, 12, 3])
461+
padding = "VALID"
462+
463+
def func(x):
464+
# NOTE: the local is called `mp` although it holds the average-pool result.
mp = tf.nn.avg_pool2d(x, ksize, strides, padding=padding, data_format="NHWC")
465+
return tf.identity(mp, name=_TFOUTPUT)
466+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
467+
468+
398469
@check_tf_min_version("1.7", "tf only support dilation is 1 for now")
399470
def test_conv2d_7(self):
400471
x_shape = [1, 35, 35, 288] # out: [1, 17, 17, 384]

tf2onnx/onnx_opset/nn.py

Lines changed: 22 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -254,6 +254,15 @@ def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
254254
else:
255255
raise ValueError("invalid padding value: {}".format(padding))
256256

257+
def parse_dims_attr(node, dims, spatial):
# Reduce a per-axis attribute list (e.g. ksize / strides / dilations) to
# its spatial entries only.
#
# node:    graph node whose layout is queried via is_channels_last(node).
# dims:    sequence of per-axis values; sliced, never mutated in place.
# spatial: number of spatial dimensions expected (2 or 3 at the call
#          sites visible in this diff).
# Returns the (possibly sliced) dims.
258+
if is_channels_last(node):
259+
# We have (N, ..., C) or (...).
260+
# If dims already has exactly `spatial` entries it is returned untouched;
# otherwise the leading (N) and trailing (C) entries are dropped.
if len(dims) != spatial:
261+
dims = dims[1:-1]
262+
else:
263+
# We have (N, C, ...).
264+
# NOTE(review): unlike the channels-last branch, this slices
# unconditionally, so a dims that is already spatial-only would lose its
# first two entries -- confirm callers always pass the full-rank list here.
dims = dims[2:]
265+
return dims
257266

258267
def conv_dims_attr(node, name, new_name=None, spatial=2):
259268
# Fetch attribute.
@@ -266,13 +275,7 @@ def conv_dims_attr(node, name, new_name=None, spatial=2):
266275

267276
# Get spatial part.
268277
dims = dims.ints
269-
if is_channels_last(node):
270-
# We have (N, ..., C) or (...).
271-
if len(dims) != spatial:
272-
dims = dims[1:-1]
273-
else:
274-
# We have (N, C, ...).
275-
dims = dims[2:]
278+
dims = parse_dims_attr(node, dims, spatial)
276279

277280
# Set new value and return it.
278281
node.set_attr(new_name, dims)
@@ -475,7 +478,7 @@ def version_1(cls, ctx, node, **kwargs):
475478

476479

477480
@tf_op(["AvgPool", "AvgPool3D"], onnx_op="AveragePool")
478-
@tf_op(["MaxPool", "MaxPoolV2"], onnx_op="MaxPool")
481+
@tf_op(["MaxPool", "MaxPoolV2", "MaxPool3D"], onnx_op="MaxPool")
479482
class PoolOp:
480483
@classmethod
481484
def version_1(cls, ctx, node, **kwargs):
@@ -497,6 +500,11 @@ def _convert(cls, ctx, node, **kwargs):
497500
# @AttrType.INTS strides)
498501
# above seems wrong - input[1] is ksize, input[2] is strides
499502
# stride and ksize in tf is not always NHWC, so watch out when converting into onnx's NCHW
503+
if kwargs["tf_op"] in ["AvgPool3D", "MaxPool3D"]:
504+
spatial = 3
505+
else:
506+
spatial = 2
507+
500508
if len(node.input) < 3:
501509
kernel_shape_tf = node.get_attr("ksize").ints
502510
strides_tf = node.get_attr("strides").ints
@@ -506,17 +514,14 @@ def _convert(cls, ctx, node, **kwargs):
506514
ctx.remove_input(node, node.input[2])
507515
ctx.remove_input(node, node.input[1])
508516

509-
if node.is_nhwc():
510-
kernel_shape_hw = kernel_shape_tf[1:3]
511-
strides_hw = strides_tf[1:3]
512-
else:
513-
kernel_shape_hw = kernel_shape_tf[2:4]
514-
strides_hw = strides_tf[2:4]
517+
kernel_shape_hw = parse_dims_attr(node, kernel_shape_tf, spatial)
518+
strides_hw = parse_dims_attr(node, strides_tf, spatial)
519+
515520
node.set_attr("kernel_shape", kernel_shape_hw)
516521
node.set_attr("strides", strides_hw)
517-
conv_dims_attr(node, "dilations")
518-
add_padding(ctx, node, kernel_shape_hw, strides_hw)
519-
conv_convert_inputs(ctx, node, with_kernel=False)
522+
dilations = conv_dims_attr(node, "dilations", spatial=spatial)
523+
add_padding(ctx, node, kernel_shape_hw, strides_hw, dilations=dilations, spatial=spatial)
524+
conv_convert_inputs(ctx, node, with_kernel=False, spatial=spatial)
520525

521526

522527
@tf_op(["MaxPoolWithArgmax"], onnx_op="MaxPool")

tf2onnx/tfonnx.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -235,6 +235,7 @@ def tensorflow_onnx_mapping(g, ops_mapping):
235235
# if there is a onnx_op key we'll map the old type to a new type
236236
onnx_op = kwargs.get("onnx_op")
237237
if onnx_op:
238+
kwargs["tf_op"] = op
238239
node.type = onnx_op
239240
body_graphs = node.get_body_graphs()
240241
if body_graphs:

0 commit comments

Comments
 (0)