Skip to content

Commit 5fb9194

Browse files
Implement RaggedTensorToTensor conversion (#1332)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent c23ef70 commit 5fb9194

File tree

2 files changed

+82
-24
lines changed

2 files changed

+82
-24
lines changed

tests/test_backend.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3857,6 +3857,30 @@ def func(splits1, splits2, rt_dense_values):
38573857
self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2],
38583858
{_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})
38593859

3860+
@check_tf_min_version("1.14", "ragged needs tf 1.14")
3861+
@check_opset_min_version(11, "CumSum")
3862+
def test_ragged_tensor_to_tensor(self):
3863+
splits_val1 = np.array([0, 1, 1, 5], dtype=np.int32)
3864+
splits_val2 = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
3865+
dense_vals_val = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.float32)
3866+
def func(splits1, splits2, rt_dense_values):
3867+
x = tf.RaggedTensor.from_nested_row_splits(rt_dense_values, [splits1, splits2], validate=True)
3868+
y = x.to_tensor(default_value=7)
3869+
return tf.identity(y, name=_TFOUTPUT)
3870+
self._run_test_case(func, [_OUTPUT], {_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})
3871+
3872+
@check_tf_min_version("2.2", "ragged to_tensor with constrained shape")
3873+
@check_opset_min_version(11, "CumSum")
3874+
def test_ragged_tensor_to_tensor_constrain_shape(self):
3875+
splits_val1 = np.array([0, 1, 1, 5], dtype=np.int32)
3876+
splits_val2 = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
3877+
dense_vals_val = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.float32)
3878+
def func(splits1, splits2, rt_dense_values):
3879+
x = tf.RaggedTensor.from_nested_row_splits(rt_dense_values, [splits1, splits2], validate=True)
3880+
y = x.to_tensor(default_value=7, shape=[20, None, 2])
3881+
return tf.identity(y, name=_TFOUTPUT)
3882+
self._run_test_case(func, [_OUTPUT], {_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})
3883+
38603884
@check_tf_min_version("1.14", "ragged needs tf 1.14")
38613885
@check_opset_min_version(11, "Range")
38623886
def test_ragged_range_float(self):

tf2onnx/onnx_opset/tensor.py

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2116,43 +2116,77 @@ def ragged_lengths_to_sparse_indices(ctx, ragged_lens):
21162116
return num_rows, num_cols, row_indices, col_indices
21172117

21182118

2119+
def ragged_nested_splits_to_sparse_indices(ctx, nested_splits, op_name_scope):
    """Build ONNX nodes turning nested ragged row-splits into sparse indices.

    Args:
        ctx: the tf2onnx graph being constructed.
        nested_splits: list of 1-D row-splits tensor names, outermost level first.
        op_name_scope: name scope applied to the Concat nodes created here.

    Returns:
        Tuple (sparse_indices, dense_shape) of output tensor names — the
        per-value index rows and the dense shape of the ragged tensor
        (one dim per nesting level plus the leading row count).
    """
    sparse_indices = None
    dense_shape_dims = []
    for split in nested_splits:
        # Downstream ops operate on int64; cast other split dtypes up front.
        if ctx.get_dtype(split) != TensorProto.INT64:
            split = ctx.make_node("Cast", [split], attr={'to': TensorProto.INT64}).output[0]
        max_int64 = int(utils.get_max_value(np.int64))
        # Row lengths for this level: splits[1:] - splits[:-1].
        slice1 = GraphBuilder(ctx).make_slice(
            {"data": split, "ends": [max_int64], "starts": [1], "axes": [0]})
        slice2 = GraphBuilder(ctx).make_slice(
            {"data": split, "ends": [-1], "starts": [0], "axes": [0]})
        ragged_lens = ctx.make_node("Sub", [slice1, slice2]).output[0]
        num_rows, num_cols, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, ragged_lens)
        if not dense_shape_dims:
            # Only the outermost level contributes the leading (row-count) dim.
            dense_shape_dims.append(num_rows)
        dense_shape_dims.append(num_cols)
        if sparse_indices is None:
            # Outermost level: index rows are plain [row, col] pairs.
            row_indices = GraphBuilder(ctx).make_unsqueeze({"data": row_indices, "axes": [1]})
        else:
            # Deeper level: each row here corresponds to a value of the previous
            # level, so gather its multi-dim index prefix from the indices built
            # so far and append this level's column index to it.
            row_indices = ctx.make_node("Gather", [sparse_indices, row_indices]).output[0]
        col_indices = GraphBuilder(ctx).make_unsqueeze({"data": col_indices, "axes": [1]})
        sparse_indices = ctx.make_node("Concat", [row_indices, col_indices], attr={'axis': 1},
                                       op_name_scope=op_name_scope).output[0]
    dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=op_name_scope).output[0]
    return sparse_indices, dense_shape
2144+
2145+
21192146
@tf_op("RaggedTensorToSparse")
class RaggedTensorToSparse:
    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        # https://www.tensorflow.org/guide/ragged_tensor#multiple_ragged_dimensions
        # The final input carries the flat values; every input before it is a
        # row-splits tensor, one per ragged nesting level (outermost first).
        *nested_splits, dense_values = node.input
        sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
        # Rewire the op's three outputs (indices, values, shape) to the
        # computed tensors, then drop the original TF node from the graph.
        replacements = (sparse_indices, dense_values, dense_shape)
        for original_out, new_out in zip(node.output, replacements):
            ctx.replace_all_inputs(original_out, new_out)
        ctx.remove_node(node.name)
21542158

21552159

2160+
@tf_op("RaggedTensorToTensor")
class RaggedTensorToTensor:
    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """Convert TF RaggedTensorToTensor: pad to a dense default-filled tensor
        and scatter the ragged values into it via ScatterND."""
        shape, values, default_value, *row_partition_tensors = node.input
        partition_types = node.get_attr_value("row_partition_types")
        error_msg = "Only ROW_SPLITS partition type is supported for RaggedTensorToTensor. types: %r"
        utils.make_sure(all(t == b'ROW_SPLITS' for t in partition_types), error_msg, partition_types)
        nested_splits = row_partition_tensors
        sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
        # A shape of rank 0 means the natural shape should be used.
        if ctx.get_rank(shape) != 0:
            if ctx.get_dtype(shape) != TensorProto.INT64:
                shape = ctx.make_node("Cast", [shape], attr={'to': TensorProto.INT64}).output[0]
            const_zero_int64 = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.int64)).output[0]
            # Negative entries in the requested shape mean "unspecified": take
            # those dims from the natural (dense) shape instead.
            unspec_dims = ctx.make_node("Less", [shape, const_zero_int64]).output[0]
            out_shape = ctx.make_node("Where", [unspec_dims, dense_shape, shape]).output[0]
            # A constrained shape may be smaller than the natural one, leaving
            # some sparse indices out of bounds. An index row is in bounds iff
            # out_shape - index > 0 in every dimension (min over axis 1); drop
            # the out-of-bounds index rows and their matching values.
            out_shape_unsq = GraphBuilder(ctx).make_unsqueeze({'data': out_shape, 'axes': [0]})
            amt_idx_in_bounds = ctx.make_node("Sub", [out_shape_unsq, sparse_indices]).output[0]
            amt_in_bounds_flat = ctx.make_node("ReduceMin", [amt_idx_in_bounds], attr={'axes': [1], 'keepdims': False})
            idx_in_bounds = ctx.make_node("Greater", [amt_in_bounds_flat.output[0], const_zero_int64]).output[0]
            sparse_indices = ctx.make_node("Compress", [sparse_indices, idx_in_bounds], attr={'axis': 0}).output[0]
            values = ctx.make_node("Compress", [values, idx_in_bounds], attr={'axis': 0}).output[0]
        else:
            out_shape = dense_shape
        # Fill a tensor of the output shape with the default value, then reuse
        # this node as a ScatterND writing the ragged values at their indices.
        expand_node = ctx.make_node("Expand", [default_value, out_shape])
        node.type = "ScatterND"
        ctx.replace_inputs(node, [expand_node.output[0], sparse_indices, values])
2188+
2189+
21562190
@tf_op("RaggedRange")
21572191
class RaggedRange:
21582192
@classmethod

0 commit comments

Comments
 (0)