Skip to content

Commit 480d30a

Browse files
committed
add test for data format
1 parent 65ae24a commit 480d30a

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

tests/test_internals.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import tensorflow as tf
1919
from tf2onnx import utils
2020
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
21-
from tf2onnx.graph import GraphUtil
21+
from tf2onnx.graph import Node, GraphUtil
2222
from common import unittest_main
2323

2424

@@ -203,6 +203,21 @@ def test_shape_utils(self):
203203
self.assertFalse(utils.are_shapes_equal(None, []))
204204
self.assertTrue(utils.are_shapes_equal([1, 2, 3], (1, 2, 3)))
205205

206+
def test_data_format(self):
207+
n1 = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", data_format="NHWC")
208+
graph_proto = helper.make_graph(
209+
nodes=[n1],
210+
name="test",
211+
inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2]),
212+
helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 2])],
213+
outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2])],
214+
initializer=[]
215+
)
216+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
217+
n = g.get_node_by_name("n1")
218+
self.assertEqual(n.data_format, "NHWC")
219+
self.assertTrue(n.is_nhwc())
220+
206221

207222
if __name__ == '__main__':
208223
unittest_main()

0 commit comments

Comments
 (0)