Skip to content

Commit 1e99bf4

Browse files
Fix bug in raggedToSparse and add casts to Unique for non-int64 ints (#1306)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 38c8835 commit 1e99bf4

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1960,9 +1960,15 @@ def version_11(cls, ctx, node, **kwargs):
19601960
node_inputs = node.input
19611961
node_outputs = node.output
19621962
ctx.remove_node(node_name)
1963+
if dtypes[0] in [TensorProto.INT32, TensorProto.INT16, TensorProto.UINT8, TensorProto.UINT16]:
1964+
inp_cast = ctx.make_node("Cast", [node_inputs[0]], attr={'to': TensorProto.INT64}).output[0]
1965+
node_inputs[0] = inp_cast
19631966
new_node = ctx.make_node("Unique", node_inputs, name=node_name, output_count=3, attr={'sorted': 0})
19641967
ctx.replace_all_inputs(node_outputs[0], new_node.output[0])
19651968
ctx.replace_all_inputs(node_outputs[1], new_node.output[2])
1969+
if ctx.get_dtype(new_node.output[0]) != dtypes[0]:
1970+
ctx.insert_new_node_on_output("Cast", new_node.output[0], name=utils.make_name(node.name) + "_cast",
1971+
to=dtypes[0])
19661972
if len(node_outputs) > 1:
19671973
# cast to int64 if needed
19681974
if dtypes[1] != onnx_pb.TensorProto.INT64:
@@ -2064,18 +2070,18 @@ class RaggedTensorToSparse:
20642070
@classmethod
20652071
def version_11(cls, ctx, node, **kwargs):
20662072
# https://www.tensorflow.org/guide/ragged_tensor#multiple_ragged_dimensions
2067-
dense_values = node.inputs[-1]
2068-
nested_splits = node.inputs[:-1]
2073+
dense_values = node.input[-1]
2074+
nested_splits = node.input[:-1]
20692075
sparse_indices = None
20702076
dense_shape_dims = []
20712077
for split in nested_splits:
2072-
if ctx.get_dtype(split.output[0]) != TensorProto.INT64:
2073-
split = ctx.make_node("Cast", [split.output[0]], attr={'to': TensorProto.INT64})
2078+
if ctx.get_dtype(split) != TensorProto.INT64:
2079+
split = ctx.make_node("Cast", [split], attr={'to': TensorProto.INT64}).output[0]
20742080
max_int64 = int(utils.get_max_value(np.int64))
20752081
slice1 = GraphBuilder(ctx).make_slice(
2076-
{"data": split.output[0], "ends": [max_int64], "starts": [1], "axes": [0]})
2082+
{"data": split, "ends": [max_int64], "starts": [1], "axes": [0]})
20772083
slice2 = GraphBuilder(ctx).make_slice(
2078-
{"data": split.output[0], "ends": [-1], "starts": [0], "axes": [0]})
2084+
{"data": split, "ends": [-1], "starts": [0], "axes": [0]})
20792085
ragged_lens = ctx.make_node("Sub", [slice1, slice2]).output[0]
20802086
num_rows, num_cols, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, ragged_lens)
20812087
if not dense_shape_dims:
@@ -2091,7 +2097,7 @@ def version_11(cls, ctx, node, **kwargs):
20912097
dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=node.name).output[0]
20922098

20932099
ctx.replace_all_inputs(node.output[0], sparse_indices)
2094-
ctx.replace_all_inputs(node.output[1], dense_values.output[0])
2100+
ctx.replace_all_inputs(node.output[1], dense_values)
20952101
ctx.replace_all_inputs(node.output[2], dense_shape)
20962102
ctx.remove_node(node.name)
20972103

0 commit comments

Comments
 (0)