Skip to content

Commit 5192a1b

Browse files
Fix make_const for strings (#1335)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent fa819fb commit 5192a1b

File tree

2 files changed

+20
-2
lines changed

2 files changed

+20
-2
lines changed

tests/test_internals.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,22 @@ def test_insert_node2(self):
108108
'n5_raw_output___3:0 -> n6 n5_raw_output___3:0 -> n5_graph_outputs_Identity__4 }'
109109
self.assertEqual(expected, result)
110110

111+
def test_make_const_string(self):
112+
graph_proto = self.sample_net()
113+
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)
114+
arr1 = np.array("test", np.object)
115+
arr2 = np.array([["A", "B"], ["C", "D"]], np.object)
116+
arr3 = np.array(b"test", np.object)
117+
arr4 = np.array([[b"A", b"B"], [b"C", b"D"]], np.object)
118+
const1 = g.make_const("const1", arr1)
119+
const2 = g.make_const("const2", arr2)
120+
const3 = g.make_const("const3", arr3)
121+
const4 = g.make_const("const4", arr4)
122+
np.testing.assert_equal(const1.get_tensor_value(False), arr1)
123+
np.testing.assert_equal(const2.get_tensor_value(False), arr2)
124+
np.testing.assert_equal(const3.get_tensor_value(False), arr1)
125+
np.testing.assert_equal(const4.get_tensor_value(False), arr2)
126+
111127
def test_remove_input(self):
112128
graph_proto = self.sample_net()
113129
g = GraphUtil.create_graph_from_onnx_graph(graph_proto)

tf2onnx/graph.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -561,11 +561,13 @@ def make_const(self, name, np_val, skip_conversion=False, raw=True):
561561
skip_conversion: bool, indicate whether this created node would be mapped during conversion.
562562
raw: whether to store data at field of raw_data or the specific field according to its dtype
563563
"""
564-
if raw and np_val.dtype != np.object:
564+
np_val_flat = np_val.flatten()
565+
is_bytes = np_val.dtype == np.object and len(np_val_flat) > 0 and isinstance(np_val_flat[0], bytes)
566+
if raw and not is_bytes:
565567
onnx_tensor = numpy_helper.from_array(np_val, name)
566568
else:
567569
onnx_tensor = helper.make_tensor(name, utils.map_numpy_to_onnx_dtype(np_val.dtype),
568-
np_val.shape, np_val, raw=False)
570+
np_val.shape, np_val_flat, raw=False)
569571
dtype = onnx_tensor.data_type
570572
node = self.make_node("Const", [], outputs=[name], name=name, attr={"value": onnx_tensor},
571573
skip_conversion=skip_conversion, dtypes=[dtype], infer_shape_dtype=False)

0 commit comments

Comments
 (0)