Commit fa819fb

Implement conversion of RaggedGather (#1333)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 5fb9194 commit fa819fb

2 files changed: +86 -0 lines changed

tests/test_backend.py

Lines changed: 15 additions & 0 deletions
@@ -3857,6 +3857,21 @@ def func(splits1, splits2, rt_dense_values):
         self._run_test_case(func, [_OUTPUT, _OUTPUT1, _OUTPUT2],
                             {_INPUT: splits_val1, _INPUT1: splits_val2, _INPUT2: dense_vals_val})
 
+    @check_tf_min_version("1.14", "ragged needs tf 1.14")
+    @check_opset_min_version(11, "CumSum")
+    def test_ragged_gather(self):
+        splits_val = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
+        dense_vals_val = np.array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=np.float32)
+        indices_val = np.array([1, 3, 2, 0, 1, 1, 4, 3, 3], dtype=np.int32)
+        def func(splits, rt_dense_values, indices):
+            x = tf.RaggedTensor.from_nested_row_splits(rt_dense_values, [splits], validate=True)
+            g = tf.gather(x, indices)
+            rt_nested_splits = tf.identity(g.row_splits, name=_TFOUTPUT)
+            rt_dense_values = tf.identity(g.flat_values, name=_TFOUTPUT1)
+            return rt_nested_splits, rt_dense_values
+        self._run_test_case(func, [_OUTPUT, _OUTPUT1],
+                            {_INPUT: splits_val, _INPUT1: dense_vals_val, _INPUT2: indices_val})
+
     @check_tf_min_version("1.14", "ragged needs tf 1.14")
     @check_opset_min_version(11, "CumSum")
     def test_ragged_tensor_to_tensor(self):
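
For reference, the TF-level behavior the new test pins down: row splits encode where each row of the flat values begins and ends (row i spans values[splits[i]:splits[i+1]]), and gathering from a ragged tensor selects whole rows. Below is a minimal standalone sketch with the test's data; it assumes eager execution and is illustrative only, not part of the commit.

import tensorflow as tf

# splits [0, 3, 3, 5, 9, 10] partition the ten dense values into the rows
# [10, 11, 12], [], [13, 14], [15, 16, 17, 18], [19].
x = tf.RaggedTensor.from_row_splits(
    values=tf.constant([10., 11., 12., 13., 14., 15., 16., 17., 18., 19.]),
    row_splits=tf.constant([0, 3, 3, 5, 9, 10], dtype=tf.int64))
g = tf.gather(x, [1, 3, 2, 0, 1, 1, 4, 3, 3])  # gathers whole rows
print(g.row_splits.numpy())   # [ 0  0  4  6  9  9  9 10 14 18]
print(g.flat_values.numpy())  # [15. 16. 17. 18. 13. 14. 10. 11. 12. 19.
                              #  15. 16. 17. 18. 15. 16. 17. 18.]

These row_splits and flat_values are exactly the two outputs the ONNX graph below has to reproduce.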

tf2onnx/onnx_opset/tensor.py

Lines changed: 71 additions & 0 deletions
@@ -2250,6 +2250,77 @@ def version_11(cls, ctx, node, **kwargs):
         ctx.remove_node(node.name)
 
 
+@tf_op("RaggedGather")
+class RaggedGather:
+    @classmethod
+    def version_11(cls, ctx, node, **kwargs):
+        *params_nested_splits, params_dense_values, indices = node.input
+        inp_ragged_rank = node.get_attr_value("PARAMS_RAGGED_RANK")
+        out_ragged_rank = node.get_attr_value("OUTPUT_RAGGED_RANK")
+        err_msg = "RaggedGather conversion only supports ragged rank of 1"
+        utils.make_sure(inp_ragged_rank == 1 and out_ragged_rank == 1 and len(params_nested_splits) == 1, err_msg)
+        splits = params_nested_splits[0]
+        err_msg2 = "RaggedGather conversion only supports tensors with no dense dimensions"
+        utils.make_sure(ctx.get_rank(splits) in [None, 1] and ctx.get_rank(params_dense_values) in [None, 1], err_msg2)
+        splits_dtype = ctx.get_dtype(splits)
+
+        if splits_dtype != TensorProto.INT64:
+            splits_64 = ctx.make_node("Cast", [splits], attr={'to': TensorProto.INT64}).output[0]
+        else:
+            splits_64 = splits
+
+        max_int64 = int(utils.get_max_value(np.int64))
+        slice1 = GraphBuilder(ctx).make_slice(
+            {"data": splits_64, "ends": [max_int64], "starts": [1], "axes": [0]})
+        slice2 = GraphBuilder(ctx).make_slice(
+            {"data": splits_64, "ends": [-1], "starts": [0], "axes": [0]})
+        ragged_lens = ctx.make_node("Sub", [slice1, slice2]).output[0]
+
+        gathered_lens = ctx.make_node("Gather", [ragged_lens, indices], op_name_scope=node.name).output[0]
+
+        const_zero_unsq = ctx.make_const(utils.make_name("const_zero"), np.array([0], dtype=np.int64)).output[0]
+        const_one_unsq = ctx.make_const(utils.make_name("const_one"), np.array([1], dtype=np.int64)).output[0]
+        gathered_lens_w_zero = ctx.make_node("Concat", [const_zero_unsq, gathered_lens], attr={'axis': 0}).output[0]
+
+        const_zero_int64 = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.int64)).output[0]
+        const_one_int64 = ctx.make_const(utils.make_name("const_one"), np.array(1, dtype=np.int64)).output[0]
+
+        gathered_splits = ctx.make_node("CumSum", [gathered_lens_w_zero, const_zero_int64]).output[0]
+        if splits_dtype != TensorProto.INT64:
+            output_splits = ctx.make_node("Cast", [gathered_splits], attr={'to': splits_dtype}).output[0]
+        else:
+            output_splits = gathered_splits
+
+        # Now that we have the splits, we just need to make the list of values.
+        total_length = GraphBuilder(ctx).make_slice(
+            {"data": gathered_splits, "ends": [max_int64], "starts": [-1], "axes": [0]})
+        gathered_starts = ctx.make_node("Gather", [splits_64, indices], op_name_scope=node.name).output[0]
+        # We disregard any length 0 segments
+        non_zero_pos = ctx.make_node("Greater", [gathered_lens, const_zero_int64]).output[0]
+        non_zero_lens = ctx.make_node("Compress", [gathered_lens, non_zero_pos]).output[0]
+        non_zero_lens_shifted = ctx.make_node("Concat", [const_zero_unsq, non_zero_lens], attr={'axis': 0}).output[0]
+        non_zero_prev_lens = GraphBuilder(ctx).make_slice(
+            {"data": non_zero_lens_shifted, "ends": [-1], "starts": [0], "axes": [0]})
+        non_zero_starts = ctx.make_node("Compress", [gathered_starts, non_zero_pos]).output[0]
+        non_zero_splits = ctx.make_node("Compress", [gathered_splits, non_zero_pos]).output[0]
+
+        prev_starts = GraphBuilder(ctx).make_slice(
+            {"data": non_zero_starts, "ends": [-1], "starts": [0], "axes": [0]})
+        prev_starts_concat = ctx.make_node("Concat", [const_one_unsq, prev_starts], attr={'axis': 0}).output[0]
+        deltas = ctx.make_node("Sub", [non_zero_starts, prev_starts_concat]).output[0]
+        deltas2 = ctx.make_node("Sub", [deltas, non_zero_prev_lens]).output[0]
+        deltas3 = ctx.make_node("Add", [deltas2, const_one_int64]).output[0]
+        one_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[1])
+        ones_of_shape = ctx.make_node("ConstantOfShape", [total_length], attr={"value": one_tensor}).output[0]
+        full_deltas = ctx.make_node("ScatterElements", [ones_of_shape, non_zero_splits, deltas3], attr={'axis': 0})
+        full_indices = ctx.make_node("CumSum", [full_deltas.output[0], const_zero_int64]).output[0]
+        output_values = ctx.make_node("Gather", [params_dense_values, full_indices], op_name_scope=node.name).output[0]
+
+        ctx.replace_all_inputs(node.output[0], output_splits)
+        ctx.replace_all_inputs(node.output[1], output_values)
+        ctx.remove_node(node.name)
+
+
 @tf_op("SparseReshape")
 class SparseReshape:
     @classmethod