Skip to content

Commit 60860ef

Browse files
committed
refactor
1 parent 02fd7bd commit 60860ef

File tree

4 files changed

+4
-10
lines changed

4 files changed

+4
-10
lines changed

tests/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
"check_tf_min_version",
2424
"check_tf_max_version",
2525
"skip_tf_versions",
26-
"skip_if_tf_cpu",
26+
"skip_tf_cpu",
2727
"check_onnxruntime_min_version",
2828
"check_opset_min_version",
2929
"check_opset_max_version",
@@ -187,7 +187,7 @@ def is_tf_gpu():
187187
return tf.test.is_gpu_available()
188188

189189

190-
def skip_if_tf_cpu(message=""):
190+
def skip_tf_cpu(message=""):
191191
is_tf_cpu = not is_tf_gpu()
192192
return unittest.skipIf(is_tf_cpu, message)
193193

tests/test_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def test_maxpool(self):
258258
self.logger.debug(str(p))
259259
self._run_test_case([_OUTPUT], {_INPUT: x_val})
260260

261-
@skip_if_tf_cpu("only tf_gpu can run maxpool with NCHW format")
261+
@skip_tf_cpu("only tf_gpu can run maxpool with NCHW format")
262262
def test_maxpool_gpu(self):
263263
# make sure converter behaves well when data format is NCHW
264264
# and when data format is NCHW, only gpu version of tensorflow can run it.

tf2onnx/graph.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,6 @@ def is_nhwc(self):
139139
"""Return True if node is in NHWC format."""
140140
return self.data_format == "NHWC"
141141

142-
def is_nchw(self):
143-
"""Return True if node is in NCHW format."""
144-
return self.data_format == "NCHW"
145-
146142
def is_const(self):
147143
"""Return True if node is a constant."""
148144
return self.type in ["Const", "ConstV2"]

tf2onnx/onnx_opset/nn.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,11 +314,9 @@ def _convert(cls, ctx, node, **kwargs):
314314
if node.is_nhwc():
315315
kernel_shape_hw = kernel_shape_tf[1:3]
316316
strides_hw = strides_tf[1:3]
317-
elif node.is_nchw():
317+
else:
318318
kernel_shape_hw = kernel_shape_tf[2:4]
319319
strides_hw = strides_tf[2:4]
320-
else:
321-
logger.warning("unexpected data format, please check it")
322320
node.set_attr("kernel_shape", kernel_shape_hw)
323321
node.set_attr("strides", strides_hw)
324322
conv_dims_attr(node, "dilations")

0 commit comments

Comments
 (0)