Skip to content

Commit 8114f4b

Browse files
Fix bug in depthwise conv dilations with padding=same (#1320)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 67da5e1 commit 8114f4b

File tree

3 files changed

+19
-6
lines changed

3 files changed

+19
-6
lines changed

tests/backend_test_base.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,10 +182,9 @@ def freeze_and_run_tf(self, func, feed_dict, outputs, as_session, premade_placeh
182182
tf.import_graph_def(graph_def, name='')
183183
graph_def = tf_optimize(list(feed_dict.keys()), outputs, graph_def, fold_constant=constant_fold)
184184

185-
if True or self.config.is_debug_mode:
186-
model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
187-
utils.save_protobuf(model_path, graph_def)
188-
self.logger.debug("created file %s", model_path)
185+
model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
186+
utils.save_protobuf(model_path, graph_def)
187+
self.logger.debug("created file %s", model_path)
189188
return result, graph_def, initialized_tables
190189

191190
def convert_to_tflite(self, graph_def, feed_dict, outputs):

tests/test_backend.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,20 @@ def func(x):
609609
# rtol is a bit high, 2 values have a bit high error. Maybe use different input data.
610610
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=0.01)
611611

612+
@check_tf_min_version("1.14", "tf depthwise_conv2d dilations")
613+
@check_opset_min_version(11, "non-const pads")
614+
def test_depthwiseconv_dilations(self):
615+
x_shape = [1, 32, 32, 1]
616+
kernel_shape = [5, 5, 1, 1]
617+
x_val = np.arange(1, 1 + np.prod(x_shape)).astype("float32").reshape(x_shape)
618+
kernel_val = np.arange(1, 1 + np.prod(kernel_shape)).astype("float32").reshape(kernel_shape)
619+
def func(x):
620+
kernel = tf.constant(kernel_val, dtype=tf.float32, name='k')
621+
conv = tf.nn.depthwise_conv2d(x, kernel, strides=[1, 1, 1, 1], padding='SAME', dilations=[3, 4])
622+
return tf.identity(conv, name=_TFOUTPUT)
623+
# rtol is a bit high, 2 values have a bit high error. Maybe use different input data.
624+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=0.01)
625+
612626
@check_tf_max_version("1.15", "not supported in tf-2.0")
613627
def test_dropout(self):
614628
x_val = np.ones([1, 24, 24, 3], dtype=np.float32)

tf2onnx/onnx_opset/nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -587,9 +587,9 @@ def version_1(cls, ctx, node, **kwargs):
587587

588588
node.set_attr("kernel_shape", [k_h, k_w])
589589
strides = conv_dims_attr(node, "strides")
590-
conv_dims_attr(node, "dilations")
590+
dilations = conv_dims_attr(node, "dilations")
591591
node.set_attr("group", k_input_channels)
592-
add_padding(ctx, node, kernel_shape, strides)
592+
add_padding(ctx, node, kernel_shape, strides, dilations)
593593

594594
new_kernel_shape = [k_h, k_w, 1, k_output_channels]
595595
conv_convert_inputs(ctx, node, with_kernel=True, new_kernel_shape=new_kernel_shape)

0 commit comments

Comments
 (0)