
Commit d42dcc5

Implement VALUE_ROWIDS format for ragged to tensor (#1664)

* Implement RandomShuffle op

  Signed-off-by: Tom Wildenhain <[email protected]>

* Implement VALUE_ROWIDS format for ragged to tensor

  Signed-off-by: Tom Wildenhain <[email protected]>

1 parent fe2a433 · commit d42dcc5
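For context (not part of the commit): TensorFlow's VALUE_ROWIDS row partitioning tags every flat value with the index of the row it belongs to, and pairs that with an explicit first-dimension size (FIRST_DIM_SIZE); this is the layout the converter now handles in addition to ROW_SPLITS. A minimal sketch of the encoding in plain TensorFlow:

import tensorflow as tf

# VALUE_ROWIDS layout: value_rowids[i] names the row that values[i] falls into,
# and nrows fixes the first dimension (rows 1 and 3 end up empty here).
rt = tf.RaggedTensor.from_value_rowids(
    values=[10, 20, 30, 40, 50],
    value_rowids=[0, 0, 0, 2, 2],
    nrows=4)
print(rt)                             # [[10, 20, 30], [], [40, 50], []]
print(rt.to_tensor(default_value=7))  # padded to shape [4, 3] with 7s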

File tree

2 files changed: 71 additions & 4 deletions

    tests/test_backend.py
    tf2onnx/onnx_opset/tensor.py

tests/test_backend.py

Lines changed: 13 additions & 0 deletions
@@ -4694,6 +4694,19 @@ def func(splits1, splits2, rt_dense_values):
             return tf.identity(y, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})
 
+    @check_tf_min_version("1.14", "ragged needs tf 1.14")
+    @check_opset_min_version(11, "CumSum")
+    @skip_tflite("unknown rank")
+    def test_ragged_tensor_to_tensor_row_ids(self):
+        ids_val1 = np.array([0, 0, 0, 2, 2], dtype=np.int32)
+        ids_val2 = np.array([0, 0, 2, 2, 2, 3, 3, 4], dtype=np.int32)
+        dense_vals_val = np.array([10, 20, 30, 40, 50, 60, 70, 80], dtype=np.float32)
+        def func(ids1, ids2, rt_dense_values):
+            x = tf.RaggedTensor.from_nested_value_rowids(rt_dense_values, [ids1, ids2], [4, 5])
+            y = x.to_tensor(default_value=7)
+            return tf.identity(y, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: ids_val1, _INPUT1: ids_val2, _INPUT2: dense_vals_val})
+
     @check_tf_min_version("2.2", "ragged to_tensor with constrained shape")
     @check_opset_min_version(11, "CumSum")
     def test_ragged_tensor_to_tensor_constrain_shape(self):
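For reference, the nested ragged tensor this test builds and the dense result the ONNX graph has to match, worked out with plain TensorFlow (the commented values are derived by hand from the test inputs, not taken from the diff):

import numpy as np
import tensorflow as tf

ids_val1 = np.array([0, 0, 0, 2, 2], dtype=np.int32)
ids_val2 = np.array([0, 0, 2, 2, 2, 3, 3, 4], dtype=np.int32)
dense_vals_val = np.array([10, 20, 30, 40, 50, 60, 70, 80], dtype=np.float32)

x = tf.RaggedTensor.from_nested_value_rowids(dense_vals_val, [ids_val1, ids_val2], [4, 5])
# x == [[[10, 20], [], [30, 40, 50]], [], [[60, 70], [80]], []]
print(x.to_tensor(default_value=7).shape)   # (4, 3, 3), padded with 7s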

tf2onnx/onnx_opset/tensor.py

Lines changed: 58 additions & 4 deletions
@@ -2385,6 +2385,25 @@ def ragged_lengths_to_sparse_indices(ctx, ragged_lens):
     return num_rows, num_cols, row_indices, col_indices
 
 
+def ragged_row_ids_to_sparse_indices(ctx, row_ids):
+    _, indices, _, counts = ctx.make_node("Unique", [row_ids], attr={'axis': 0}, output_count=4).output
+    num_cols = ctx.make_node("ReduceMax", [counts], attr={'axes': [0], 'keepdims': True}).output[0]
+    const_one = ctx.make_const(utils.make_name("const_one"), np.array(1, np.int64)).output[0]
+    const_zero_unsq = ctx.make_const(utils.make_name("const_zero"), np.array([0], np.int64)).output[0]
+    const_zero = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.int64)).output[0]
+    const_neg_one_unsq = ctx.make_const(utils.make_name("const_neg_one"), np.array([-1], np.int64)).output[0]
+    one_minus_cnt = ctx.make_node("Sub", [const_one, counts]).output[0]
+    cnts_prefixed = ctx.make_node("Concat", [const_zero_unsq, one_minus_cnt], attr={'axis': 0}).output[0]
+    cnts_shifted = GraphBuilder(ctx).make_slice(
+        {'data': cnts_prefixed, 'starts': const_zero_unsq, 'ends': const_neg_one_unsq, 'axes': [0]})
+    ids_shape = ctx.make_node("Shape", [row_ids]).output[0]
+    one_tensor = helper.make_tensor("value", onnx_pb.TensorProto.INT64, dims=[1], vals=[1])
+    ones_of_shape = ctx.make_node("ConstantOfShape", [ids_shape], attr={'value': one_tensor}).output[0]
+    deltas = ctx.make_node("ScatterElements", [ones_of_shape, indices, cnts_shifted], attr={'axis': 0}).output[0]
+    col_indices = ctx.make_node("CumSum", [deltas, const_zero]).output[0]
+    return num_cols, col_indices
+
+
 def ragged_nested_splits_to_sparse_indices(ctx, nested_splits, op_name_scope):
     sparse_indices = None
     dense_shape_dims = []
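The new ragged_row_ids_to_sparse_indices helper above computes, for every flat value, its column index within its own row using only Unique, ScatterElements and CumSum. A plain-NumPy re-enactment of that graph on the inner row ids from the new test (variable names are illustrative, not from the source):

import numpy as np

# Inner value_rowids from the new test: each entry names the row its value belongs to.
row_ids = np.array([0, 0, 2, 2, 2, 3, 3, 4], dtype=np.int64)

# Unique (sorted) gives the first position and the length of every row that has values.
_, first_pos, counts = np.unique(row_ids, return_index=True, return_counts=True)
num_cols = counts.max()                              # ReduceMax: widest row -> 3

# Start from all ones, then scatter "1 - previous row's length" at each row start
# (0 at the very first position) so the running sum resets to 0 on every row change.
deltas = np.ones_like(row_ids)
deltas[first_pos] = np.concatenate(([0], 1 - counts))[:-1]   # Concat + Slice
col_indices = np.cumsum(deltas)                      # CumSum

print(num_cols, col_indices)   # 3 [0 1 0 1 2 0 1 0]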
@@ -2412,6 +2431,28 @@ def ragged_nested_splits_to_sparse_indices(ctx, nested_splits, op_name_scope):
     return sparse_indices, dense_shape
 
 
+def ragged_nested_row_ids_to_sparse_indices(ctx, num_rows, nested_row_ids, op_name_scope):
+    sparse_indices = None
+    if ctx.get_dtype(num_rows) != TensorProto.INT64:
+        num_rows = ctx.make_node("Cast", [num_rows], attr={'to': TensorProto.INT64}).output[0]
+    num_rows = GraphBuilder(ctx).make_unsqueeze({"data": num_rows, "axes": [0]})
+    dense_shape_dims = [num_rows]
+    for row_ids in nested_row_ids:
+        if ctx.get_dtype(row_ids) != TensorProto.INT64:
+            row_ids = ctx.make_node("Cast", [row_ids], attr={'to': TensorProto.INT64}).output[0]
+        num_cols, col_indices = ragged_row_ids_to_sparse_indices(ctx, row_ids)
+        dense_shape_dims.append(num_cols)
+        if sparse_indices is None:
+            row_indices = GraphBuilder(ctx).make_unsqueeze({"data": row_ids, "axes": [1]})
+        else:
+            row_indices = ctx.make_node("Gather", [sparse_indices, row_ids]).output[0]
+        col_indices = GraphBuilder(ctx).make_unsqueeze({"data": col_indices, "axes": [1]})
+        sparse_indices = ctx.make_node("Concat", [row_indices, col_indices], attr={'axis': 1},
+                                       op_name_scope=op_name_scope).output[0]
+    dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=op_name_scope).output[0]
+    return sparse_indices, dense_shape
+
+
 @tf_op("RaggedTensorToSparse")
 class RaggedTensorToSparse:
     @classmethod
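ragged_nested_row_ids_to_sparse_indices chains the levels: the previous level's [row, col] pairs are gathered by the next level's value_rowids, so every flat value inherits the index prefix of the row it lives in before its own column index is concatenated. A small NumPy illustration with the two levels from the new test (the level-1 indices are worked out by hand, not produced by the converter):

import numpy as np

# Level 1 (outer value_rowids [0, 0, 0, 2, 2]) yields one [outer_row, outer_col]
# pair per inner row, computed with the same Unique/CumSum trick as above:
level1 = np.array([[0, 0], [0, 1], [0, 2], [2, 0], [2, 1]])

row_ids2 = np.array([0, 0, 2, 2, 2, 3, 3, 4])   # inner row of each of the 8 values
col_ids2 = np.array([0, 1, 0, 1, 2, 0, 1, 0])   # position of each value in its row

# Gather picks the prefix of the enclosing row, Concat appends the new column,
# giving full 3-D sparse indices of shape [8, 3]; dense_shape is [4, 3, 3].
sparse_indices = np.concatenate([level1[row_ids2], col_ids2[:, None]], axis=1)
print(sparse_indices)
# [[0 0 0] [0 0 1] [0 2 0] [0 2 1] [0 2 2] [2 0 0] [2 0 1] [2 1 0]]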
@@ -2432,10 +2473,23 @@ class RaggedTensorToTensor:
     def version_11(cls, ctx, node, **kwargs):
         shape, values, default_value, *row_partition_tensors = node.input
         partition_types = node.get_attr_value("row_partition_types")
-        error_msg = "Only ROW_SPLITS partition type is supported for RaggedTensorToTensor. types: %r"
-        utils.make_sure(all(t == b'ROW_SPLITS' for t in partition_types), error_msg, partition_types)
-        nested_splits = row_partition_tensors
-        sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
+        layout_type = None
+        if len(partition_types) >= 2 and partition_types[0] == b'FIRST_DIM_SIZE' and \
+                all(t == b'VALUE_ROWIDS' for t in partition_types[1:]):
+            layout_type = 'VALUE_ROWIDS'
+        elif all(t == b'ROW_SPLITS' for t in partition_types):
+            layout_type = 'ROW_SPLITS'
+        error_msg = "Only ROW_SPLITS and VALUE_ROWIDS partition types are supported for RaggedTensorToTensor. types: %r"
+
+        if layout_type == 'ROW_SPLITS':
+            nested_splits = row_partition_tensors
+            sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
+        else:
+            utils.make_sure(layout_type == 'VALUE_ROWIDS', error_msg, partition_types)
+            first_dim = row_partition_tensors[0]
+            row_ids = row_partition_tensors[1:]
+            sparse_indices, dense_shape = ragged_nested_row_ids_to_sparse_indices(ctx, first_dim, row_ids, node.name)
+
         # A shape of rank 0 means the natural shape should be used.
         if ctx.get_rank(shape) != 0:
             if ctx.get_dtype(shape) != TensorProto.INT64:
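The last hunk only decides which layout it is looking at: TF encodes value_rowids partitioning as a FIRST_DIM_SIZE tensor followed by one VALUE_ROWIDS tensor per ragged dimension, while the pre-existing path expects ROW_SPLITS throughout. A standalone restatement of that dispatch (pick_layout is a hypothetical helper, not tf2onnx code):

# Minimal re-statement of the layout dispatch added above.
def pick_layout(partition_types):
    if (len(partition_types) >= 2 and partition_types[0] == b'FIRST_DIM_SIZE'
            and all(t == b'VALUE_ROWIDS' for t in partition_types[1:])):
        return 'VALUE_ROWIDS'   # handled by ragged_nested_row_ids_to_sparse_indices
    if all(t == b'ROW_SPLITS' for t in partition_types):
        return 'ROW_SPLITS'     # handled by ragged_nested_splits_to_sparse_indices
    return None                 # anything else still fails the make_sure check

print(pick_layout([b'FIRST_DIM_SIZE', b'VALUE_ROWIDS', b'VALUE_ROWIDS']))  # VALUE_ROWIDS
print(pick_layout([b'ROW_SPLITS', b'ROW_SPLITS']))                         # ROW_SPLITS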
