Skip to content

Commit f34aa95

Browse files
Implemented MaxPool3D and AvgPool3D
1 parent 85bca92 commit f34aa95

File tree

3 files changed

+94
-17
lines changed

3 files changed

+94
-17
lines changed

tests/test_backend.py

Lines changed: 71 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -392,6 +392,77 @@ def test_conv2d_6(self):
392392
kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
393393
self._conv_test(x_val, kernel_val, strides=strides, padding="VALID", rtol=1e-05)
394394

395+
def test_conv3d_1(self):
396+
strides = [1, 1, 1, 1, 1]
397+
dilations = [1, 1, 1, 1, 1]
398+
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
399+
w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32)
400+
padding = "VALID"
401+
def func(x):
402+
kernel = tf.constant(w, dtype=tf.float32, name='k')
403+
conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations)
404+
return tf.identity(conv, name=_TFOUTPUT)
405+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)
406+
407+
def test_conv3d_2(self):
408+
strides = [1, 2, 3, 1, 1]
409+
dilations = [1, 1, 1, 1, 1]
410+
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
411+
w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32)
412+
padding = "VALID"
413+
def func(x):
414+
kernel = tf.constant(w, dtype=tf.float32, name='k')
415+
conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations)
416+
return tf.identity(conv, name=_TFOUTPUT)
417+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)
418+
419+
def test_conv3d_3(self):
420+
strides = [1, 2, 3, 1, 1]
421+
dilations = [1, 1, 1, 1, 1]
422+
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
423+
w = np.random.random_sample([2, 3, 4, 5, 6]).astype(np.float32)
424+
padding = "SAME"
425+
def func(x):
426+
kernel = tf.constant(w, dtype=tf.float32, name='k')
427+
conv = tf.nn.conv3d(x, kernel, strides=strides, padding=padding, data_format="NDHWC", dilations=dilations)
428+
return tf.identity(conv, name=_TFOUTPUT)
429+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-05)
430+
431+
def test_avgpool3d(self):
432+
strides = [1, 1, 1, 1, 1]
433+
ksize = [1, 2, 2, 3, 1]
434+
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
435+
padding = "VALID"
436+
437+
def func(x):
438+
mp = tf.nn.avg_pool3d(x, ksize, strides, padding=padding, data_format="NDHWC")
439+
return tf.identity(mp, name=_TFOUTPUT)
440+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
441+
442+
def test_maxpool3d(self):
443+
strides = [1, 1, 1, 1, 1]
444+
ksize = [1, 2, 2, 3, 1]
445+
x_val = np.random.random_sample([2, 10, 9, 8, 5]).astype(np.float32)
446+
padding = "VALID"
447+
448+
def func(x):
449+
mp = tf.nn.max_pool3d(x, ksize, strides, padding=padding, data_format="NDHWC")
450+
return tf.identity(mp, name=_TFOUTPUT)
451+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
452+
453+
@check_tf_min_version("1.14", "tf.nn.avg_pool2d doesn't exist before tf 1.14")
454+
def test_avgpool2d(self):
455+
strides = [1, 1, 1, 1]
456+
ksize = [1, 2, 3, 1]
457+
x_val = make_xval([2, 10, 12, 3])
458+
padding = "VALID"
459+
460+
def func(x):
461+
mp = tf.nn.avg_pool2d(x, ksize, strides, padding=padding, data_format="NHWC")
462+
return tf.identity(mp, name=_TFOUTPUT)
463+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
464+
465+
395466
@check_tf_min_version("1.7", "tf only support dilation is 1 for now")
396467
def test_conv2d_7(self):
397468
x_shape = [1, 35, 35, 288] # out: [1, 17, 17, 384]

tf2onnx/onnx_opset/nn.py

Lines changed: 22 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -254,6 +254,15 @@ def add_padding(ctx, node, kernel_shape, strides, dilations=None, spatial=2):
254254
else:
255255
raise ValueError("invalid padding value: {}".format(padding))
256256

257+
def parse_dims_attr(node, dims, spatial):
258+
if is_channels_last(node):
259+
# We have (N, ..., C) or (...).
260+
if len(dims) != spatial:
261+
dims = dims[1:-1]
262+
else:
263+
# We have (N, C, ...).
264+
dims = dims[2:]
265+
return dims
257266

258267
def conv_dims_attr(node, name, new_name=None, spatial=2):
259268
# Fetch attribute.
@@ -266,13 +275,7 @@ def conv_dims_attr(node, name, new_name=None, spatial=2):
266275

267276
# Get spatial part.
268277
dims = dims.ints
269-
if is_channels_last(node):
270-
# We have (N, ..., C) or (...).
271-
if len(dims) != spatial:
272-
dims = dims[1:-1]
273-
else:
274-
# We have (N, C, ...).
275-
dims = dims[2:]
278+
dims = parse_dims_attr(node, dims, spatial)
276279

277280
# Set new value and return it.
278281
node.set_attr(new_name, dims)
@@ -475,7 +478,7 @@ def version_1(cls, ctx, node, **kwargs):
475478

476479

477480
@tf_op(["AvgPool", "AvgPool3D"], onnx_op="AveragePool")
478-
@tf_op(["MaxPool", "MaxPoolV2"], onnx_op="MaxPool")
481+
@tf_op(["MaxPool", "MaxPoolV2", "MaxPool3D"], onnx_op="MaxPool")
479482
class PoolOp:
480483
@classmethod
481484
def version_1(cls, ctx, node, **kwargs):
@@ -497,6 +500,11 @@ def _convert(cls, ctx, node, **kwargs):
497500
# @AttrType.INTS strides)
498501
# above seems wrong - input[1] is ksize, input[2] is strides
499502
# stride and ksize in tf is not always NHWC, so watch out when converting into onnx's NCHW
503+
if kwargs["tf_op"] in ["AvgPool3D", "MaxPool3D"]:
504+
spatial = 3
505+
else:
506+
spatial = 2
507+
500508
if len(node.input) < 3:
501509
kernel_shape_tf = node.get_attr("ksize").ints
502510
strides_tf = node.get_attr("strides").ints
@@ -506,17 +514,14 @@ def _convert(cls, ctx, node, **kwargs):
506514
ctx.remove_input(node, node.input[2])
507515
ctx.remove_input(node, node.input[1])
508516

509-
if node.is_nhwc():
510-
kernel_shape_hw = kernel_shape_tf[1:3]
511-
strides_hw = strides_tf[1:3]
512-
else:
513-
kernel_shape_hw = kernel_shape_tf[2:4]
514-
strides_hw = strides_tf[2:4]
517+
kernel_shape_hw = parse_dims_attr(node, kernel_shape_tf, spatial)
518+
strides_hw = parse_dims_attr(node, strides_tf, spatial)
519+
515520
node.set_attr("kernel_shape", kernel_shape_hw)
516521
node.set_attr("strides", strides_hw)
517-
conv_dims_attr(node, "dilations")
518-
add_padding(ctx, node, kernel_shape_hw, strides_hw)
519-
conv_convert_inputs(ctx, node, with_kernel=False)
522+
dilations = conv_dims_attr(node, "dilations", spatial=spatial)
523+
add_padding(ctx, node, kernel_shape_hw, strides_hw, dilations=dilations, spatial=spatial)
524+
conv_convert_inputs(ctx, node, with_kernel=False, spatial=spatial)
520525

521526

522527
@tf_op(["MaxPoolWithArgmax"], onnx_op="MaxPool")

tf2onnx/tfonnx.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -235,6 +235,7 @@ def tensorflow_onnx_mapping(g, ops_mapping):
235235
# if there is an onnx_op key we'll map the old type to a new type
236236
onnx_op = kwargs.get("onnx_op")
237237
if onnx_op:
238+
kwargs["tf_op"] = op
238239
node.type = onnx_op
239240
body_graphs = node.get_body_graphs()
240241
if body_graphs:

0 commit comments

Comments
 (0)