Skip to content

Commit 8c8e585

Browse files
Implement RaggedTensorToTensor for tensors with dense (uniform) dims (#1676)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent bdd7617 commit 8c8e585

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

tests/test_backend.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4714,7 +4714,7 @@ def func(splits1, splits2, rt_dense_values):
47144714
def test_ragged_tensor_to_tensor_row_ids(self):
47154715
ids_val1 = np.array([0, 0, 0, 2, 2], dtype=np.int32)
47164716
ids_val2 = np.array([0, 0, 2, 2, 2, 3, 3, 4], dtype=np.int32)
4717-
dense_vals_val = np.array([10, 20, 30, 40, 50, 60, 70, 80], dtype=np.float32)
4717+
dense_vals_val = make_xval([8, 2, 3])
47184718
def func(ids1, ids2, rt_dense_values):
47194719
x = tf.RaggedTensor.from_nested_value_rowids(rt_dense_values, [ids1, ids2], [4, 5])
47204720
y = x.to_tensor(default_value=7)
@@ -4727,9 +4727,10 @@ def test_ragged_tensor_to_tensor_constrain_shape(self):
47274727
splits_val1 = np.array([0, 1, 1, 5], dtype=np.int32)
47284728
splits_val2 = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
47294729
dense_vals_val = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.float32)
4730+
dense_vals_val = make_xval([10, 2, 3])
47304731
def func(splits1, splits2, rt_dense_values):
47314732
x = tf.RaggedTensor.from_nested_row_splits(rt_dense_values, [splits1, splits2], validate=True)
4732-
y = x.to_tensor(default_value=7, shape=[20, None, 2])
4733+
y = x.to_tensor(default_value=7, shape=[20, None, 2, None, 3])
47334734
return tf.identity(y, name=_TFOUTPUT)
47344735
self._run_test_case(func, [_OUTPUT], {_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})
47354736

tf2onnx/onnx_opset/tensor.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2465,6 +2465,8 @@ def version_11(cls, ctx, node, **kwargs):
24652465
# https://www.tensorflow.org/guide/ragged_tensor#multiple_ragged_dimensions
24662466
dense_values = node.input[-1]
24672467
nested_splits = node.input[:-1]
2468+
err_msg2 = "RaggedTensorToSparse conversion only supports tensors with no dense dimensions"
2469+
utils.make_sure(ctx.get_rank(dense_values) in [None, 1], err_msg2)
24682470
sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
24692471
ctx.replace_all_inputs(node.output[0], sparse_indices)
24702472
ctx.replace_all_inputs(node.output[1], dense_values)
@@ -2477,6 +2479,7 @@ class RaggedTensorToTensor:
24772479
@classmethod
24782480
def version_11(cls, ctx, node, **kwargs):
24792481
shape, values, default_value, *row_partition_tensors = node.input
2482+
has_uniform_dims = ctx.get_rank(values) != 1
24802483
partition_types = node.get_attr_value("row_partition_types")
24812484
layout_type = None
24822485
if len(partition_types) >= 2 and partition_types[0] == b'FIRST_DIM_SIZE' and \
@@ -2488,17 +2491,25 @@ def version_11(cls, ctx, node, **kwargs):
24882491

24892492
if layout_type == 'ROW_SPLITS':
24902493
nested_splits = row_partition_tensors
2494+
n_dims = len(nested_splits) + 1
24912495
sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
24922496
else:
24932497
utils.make_sure(layout_type == 'VALUE_ROWIDS', error_msg, partition_types)
24942498
first_dim = row_partition_tensors[0]
24952499
row_ids = row_partition_tensors[1:]
2500+
n_dims = len(row_ids) + 1
24962501
sparse_indices, dense_shape = ragged_nested_row_ids_to_sparse_indices(ctx, first_dim, row_ids, node.name)
24972502

24982503
# A shape of rank 0 means the natural shape should be used.
24992504
if ctx.get_rank(shape) != 0:
25002505
if ctx.get_dtype(shape) != TensorProto.INT64:
25012506
shape = ctx.make_node("Cast", [shape], attr={'to': TensorProto.INT64}).output[0]
2507+
if has_uniform_dims:
2508+
const_zero_unsq = ctx.make_const(utils.make_name("const_zero"), np.array([0], dtype=np.int64)).output[0]
2509+
const_n_unsq = ctx.make_const(utils.make_name("const_num_dims"),
2510+
np.array([n_dims], dtype=np.int64)).output[0]
2511+
shape = GraphBuilder(ctx).make_slice(
2512+
{'data': shape, 'starts': const_zero_unsq, 'ends': const_n_unsq, 'axes': [0]})
25022513
const_zero_int64 = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.int64)).output[0]
25032514
unspec_dims = ctx.make_node("Less", [shape, const_zero_int64]).output[0]
25042515
out_shape = ctx.make_node("Where", [unspec_dims, dense_shape, shape]).output[0]
@@ -2510,6 +2521,16 @@ def version_11(cls, ctx, node, **kwargs):
25102521
values = ctx.make_node("Compress", [values, idx_in_bounds], attr={'axis': 0}).output[0]
25112522
else:
25122523
out_shape = dense_shape
2524+
2525+
if has_uniform_dims:
2526+
values_shape = ctx.make_node("Shape", [values]).output[0]
2527+
const_one_unsq = ctx.make_const(utils.make_name("const_one"), np.array([1], dtype=np.int64)).output[0]
2528+
max_int64 = np.array([utils.get_max_value(np.int64)], dtype=np.int64)
2529+
const_max_val_unsq = ctx.make_const(utils.make_name("max_int"), max_int64).output[0]
2530+
uniform_dims = GraphBuilder(ctx).make_slice(
2531+
{'data': values_shape, 'starts': const_one_unsq, 'ends': const_max_val_unsq, 'axes': [0]})
2532+
out_shape = ctx.make_node("Concat", [out_shape, uniform_dims], attr={'axis': 0}).output[0]
2533+
25132534
expand_node = ctx.make_node("Expand", [default_value, out_shape])
25142535
node.type = "ScatterND"
25152536
ctx.replace_inputs(node, [expand_node.output[0], sparse_indices, values])

0 commit comments

Comments
 (0)