Skip to content

Commit 66c3cd7

Browse files
authored
Merge pull request #1164 from onnx/tom/ImproveTfUtilsSparseTensors
Fix parsing of Sparse ops
2 parents a1e0e09 + 157d180 commit 66c3cd7

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

tf2onnx/tf_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,10 +321,13 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
321321
dtype = get_tf_node_attr(node, a)
322322
if dtype and not isinstance(dtype, list):
323323
dtypes[node.name] = map_tf_dtype(dtype)
324-
elif a in {"output_type", "output_dtype", "out_type", "Tidx", "out_idx"}:
324+
elif a in {"output_type", "output_dtype", "out_type", "Tidx", "out_idx", "out_type", "internal_type",
325+
"Tsegmentids"}:
325326
# Tidx is used by Range
326327
# out_idx is used by ListDiff
327328
attr[a] = map_tf_dtype(get_tf_node_attr(node, a))
329+
elif a == "sparse_types":
330+
attr[a] = [map_tf_dtype(d) for d in get_tf_node_attr(node, a)]
328331
elif a == "shape":
329332
shape = get_tf_shape_attr(node)
330333
if shape is not None:

0 commit comments

Comments
 (0)