|
18 | 18 | import tensorflow as tf
|
19 | 19 | from tf2onnx import utils
|
20 | 20 | from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
|
21 |
| -from tf2onnx.graph import GraphUtil |
| 21 | +from tf2onnx.graph import Node, GraphUtil |
22 | 22 | from common import unittest_main
|
23 | 23 |
|
24 | 24 |
|
@@ -203,6 +203,21 @@ def test_shape_utils(self):
|
203 | 203 | self.assertFalse(utils.are_shapes_equal(None, []))
|
204 | 204 | self.assertTrue(utils.are_shapes_equal([1, 2, 3], (1, 2, 3)))
|
205 | 205 |
|
| 206 | + def test_data_format(self): |
| 207 | + n1 = helper.make_node("Conv", ["X", "W"], ["Y"], name="n1", data_format="NHWC") |
| 208 | + graph_proto = helper.make_graph( |
| 209 | + nodes=[n1], |
| 210 | + name="test", |
| 211 | + inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 2]), |
| 212 | + helper.make_tensor_value_info("W", TensorProto.FLOAT, [2, 2])], |
| 213 | + outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 2])], |
| 214 | + initializer=[] |
| 215 | + ) |
| 216 | + g = GraphUtil.create_graph_from_onnx_graph(graph_proto) |
| 217 | + n = g.get_node_by_name("n1") |
| 218 | + self.assertEqual(n.data_format, "NHWC") |
| 219 | + self.assertTrue(n.is_nhwc()) |
| 220 | + |
206 | 221 |
|
207 | 222 | if __name__ == '__main__':
|
208 | 223 | unittest_main()
|
0 commit comments