Skip to content

Commit 72fb208

Browse files
Implemented conversion of RaggedToSparse (#1276)
* Implemented conversion of RaggedToSparse Signed-off-by: Tom Wildenhain <[email protected]> * Fixed bug Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 8dbde42 commit 72fb208

File tree

2 files changed

+79
-20
lines changed

2 files changed

+79
-20
lines changed

tests/test_backend.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3765,6 +3765,23 @@ def func(indices, dense_shape, new_shape, shape_pad):
37653765
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: indices_val, _INPUT1: dense_shape_val,
37663766
_INPUT2: new_shape_val, _INPUT3: shape_pad_val})
37673767

3768+
@check_tf_min_version("1.14", "ragged needs tf 1.14")
3769+
@check_opset_min_version(11, "CumSum")
3770+
def test_ragged_tensor_to_sparse(self):
3771+
splits_val1 = np.array([0, 1, 1, 5], dtype=np.int32)
3772+
splits_val2 = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
3773+
dense_vals_val = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.float32)
3774+
def func(splits1, splits2, rt_dense_values):
3775+
x = tf.RaggedTensor.from_nested_row_splits(rt_dense_values, [splits1, splits2], validate=True)
3776+
s = x.to_sparse()
3777+
indices, values, shape = s.indices, s.values, s.dense_shape
3778+
indices = tf.identity(indices, name=_TFOUTPUT)
3779+
values = tf.identity(values, name=_TFOUTPUT1)
3780+
shape = tf.identity(shape, name=_TFOUTPUT2)
3781+
return indices, values, shape
3782+
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2],
3783+
{_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})
3784+
37683785
@check_tf_min_version("1.14", "ragged needs tf 1.14")
37693786
@check_opset_min_version(11, "Range")
37703787
def test_ragged_range_float(self):

tf2onnx/onnx_opset/tensor.py

Lines changed: 62 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2036,6 +2036,65 @@ def version_11(cls, ctx, node, **kwargs):
20362036
ctx.replace_inputs(node, [expand_node.output[0], sparse_indices, sparse_vals])
20372037

20382038

2039+
def ragged_lengths_to_sparse_indices(ctx, ragged_lens):
2040+
const_zero_int64 = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.int64)).output[0]
2041+
num_cols = ctx.make_node("ReduceMax", [ragged_lens], attr={'axes': [0], 'keeepdims': True}).output[0]
2042+
num_rows = ctx.make_node("Shape", [ragged_lens]).output[0]
2043+
range_len = ctx.make_node("Mul", [num_cols, num_rows]).output[0]
2044+
2045+
# ORT seems to have a shape inference bug for the Range node. Use CumSum instead.
2046+
one_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[1])
2047+
ones_of_shape = ctx.make_node("ConstantOfShape", [range_len], attr={"value": one_tensor}).output[0]
2048+
range_node = ctx.make_node("CumSum", [ones_of_shape, const_zero_int64], attr={'exclusive': True}).output[0]
2049+
#const_one_int64 = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=np.int64)).output[0]
2050+
#range_node = ctx.make_node("Range", [const_zero_int64, range_len, const_one_int64]).output[0]
2051+
2052+
col_indices_dense = ctx.make_node("Mod", [range_node, num_cols]).output[0]
2053+
row_indices_dense = ctx.make_node("Div", [range_node, num_cols]).output[0]
2054+
row_lens_dense = ctx.make_node("Gather", [ragged_lens, row_indices_dense]).output[0]
2055+
indices_to_keep = ctx.make_node("Less", [col_indices_dense, row_lens_dense]).output[0]
2056+
col_indices = ctx.make_node("Compress", [col_indices_dense, indices_to_keep]).output[0]
2057+
row_indices = ctx.make_node("Compress", [row_indices_dense, indices_to_keep]).output[0]
2058+
return num_rows, num_cols, row_indices, col_indices
2059+
2060+
2061+
@tf_op("RaggedTensorToSparse")
2062+
class RaggedTensorToSparse:
2063+
@classmethod
2064+
def version_11(cls, ctx, node, **kwargs):
2065+
# https://www.tensorflow.org/guide/ragged_tensor#multiple_ragged_dimensions
2066+
dense_values = node.inputs[-1]
2067+
nested_splits = node.inputs[:-1]
2068+
sparse_indices = None
2069+
dense_shape_dims = []
2070+
for split in nested_splits:
2071+
if ctx.get_dtype(split.output[0]) != TensorProto.INT64:
2072+
split = ctx.make_node("Cast", [split.output[0]], attr={'to': TensorProto.INT64})
2073+
max_int64 = int(utils.get_max_value(np.int64))
2074+
slice1 = GraphBuilder(ctx).make_slice(
2075+
{"data": split.output[0], "ends": [max_int64], "starts": [1], "axes": [0]})
2076+
slice2 = GraphBuilder(ctx).make_slice(
2077+
{"data": split.output[0], "ends": [-1], "starts": [0], "axes": [0]})
2078+
ragged_lens = ctx.make_node("Sub", [slice1, slice2]).output[0]
2079+
num_rows, num_cols, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, ragged_lens)
2080+
if not dense_shape_dims:
2081+
dense_shape_dims.append(num_rows)
2082+
dense_shape_dims.append(num_cols)
2083+
if sparse_indices is None:
2084+
row_indices = GraphBuilder(ctx).make_unsqueeze({"data": row_indices, "axes": [1]})
2085+
else:
2086+
row_indices = ctx.make_node("Gather", [sparse_indices, row_indices]).output[0]
2087+
col_indices = GraphBuilder(ctx).make_unsqueeze({"data": col_indices, "axes": [1]})
2088+
sparse_indices = ctx.make_node("Concat", [row_indices, col_indices], attr={'axis': 1},
2089+
op_name_scope=node.name).output[0]
2090+
dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=node.name).output[0]
2091+
2092+
ctx.replace_all_inputs(node.output[0], sparse_indices)
2093+
ctx.replace_all_inputs(node.output[1], dense_values.output[0])
2094+
ctx.replace_all_inputs(node.output[2], dense_shape)
2095+
ctx.remove_node(node.name)
2096+
2097+
20392098
@tf_op("RaggedRange")
20402099
class RaggedRange:
20412100
@classmethod
@@ -2076,34 +2135,17 @@ def version_11(cls, ctx, node, **kwargs):
20762135

20772136
const_zero_list = ctx.make_const(utils.make_name("const_zero_list"), np.array([0], dtype=np.int64)).output[0]
20782137

2079-
max_row_len = ctx.make_node("ReduceMax", [row_lens], attr={'axes': [0], 'keeepdims': False}).output[0]
2080-
inp_shape = ctx.make_node("Shape", [row_lens]).output[0]
2081-
range_len = ctx.make_node("Mul", [max_row_len, inp_shape]).output[0]
2082-
2083-
# ORT seems to have a shape inference bug for the Range node. Use CumSum instead.
2084-
one_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[1])
2085-
ones_of_shape = ctx.make_node("ConstantOfShape", [range_len], attr={"value": one_tensor}).output[0]
2086-
range_node = ctx.make_node("CumSum", [ones_of_shape, const_zero_int64], attr={'exclusive': True}).output[0]
2087-
#const_one_int64 = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=np.int64)).output[0]
2088-
#range_node = ctx.make_node("Range", [const_zero_int64, range_len, const_one_int64]).output[0]
2089-
2090-
col_indices_dense = ctx.make_node("Mod", [range_node, max_row_len]).output[0]
2091-
row_indices_dense = ctx.make_node("Div", [range_node, max_row_len]).output[0]
2092-
row_lens_dense = ctx.make_node("Gather", [row_lens, row_indices_dense]).output[0]
2093-
indices_to_keep = ctx.make_node("Less", [col_indices_dense, row_lens_dense]).output[0]
2094-
col_indices = ctx.make_node("Compress", [col_indices_dense, indices_to_keep]).output[0]
2095-
row_indices = ctx.make_node("Compress", [row_indices_dense, indices_to_keep]).output[0]
2096-
2138+
num_rows, _, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, row_lens)
20972139

20982140
split_ends = ctx.make_node("CumSum", [row_lens, const_zero_int64]).output[0]
20992141
splits_out = ctx.make_node("Concat", [const_zero_list, split_ends], attr={'axis': 0}).output[0]
21002142
col_indices_cast = ctx.make_node("Cast", [col_indices], attr={'to': data_dtype}).output[0]
21012143

21022144
if ctx.get_rank(starts) != 1:
2103-
starts = ctx.make_node("Expand", [starts, inp_shape]).output[0]
2145+
starts = ctx.make_node("Expand", [starts, num_rows]).output[0]
21042146

21052147
if ctx.get_rank(deltas) != 1:
2106-
deltas = ctx.make_node("Expand", [deltas, inp_shape]).output[0]
2148+
deltas = ctx.make_node("Expand", [deltas, num_rows]).output[0]
21072149

21082150
gather_starts = ctx.make_node("Gather", [starts, row_indices]).output[0]
21092151
gather_deltas = ctx.make_node("Gather", [deltas, row_indices]).output[0]

0 commit comments

Comments
 (0)