Skip to content

Commit 2259c3d

Browse files
Added support for converting RaggedRange (#1256)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 33ee88f commit 2259c3d

File tree

3 files changed

+129
-0
lines changed

3 files changed

+129
-0
lines changed

tests/test_backend.py

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3740,6 +3740,48 @@ def func(indices, dense_shape, new_shape, shape_pad):
37403740
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: indices_val, _INPUT1: dense_shape_val,
37413741
_INPUT2: new_shape_val, _INPUT3: shape_pad_val})
37423742

3743+
@check_tf_min_version("1.14", "ragged needs tf 1.14")
@check_opset_min_version(11, "Range")
def test_ragged_range_float(self):
    """RaggedRange conversion with float32 inputs, including negative and fractional deltas."""
    starts = np.array([0, 0, 1, 10, 0.5, 0.5], dtype=np.float32)
    limits = np.array([-5, -2, 7, 100, 1, 1], dtype=np.float32)
    deltas = np.array([-1, 1, 2, 20, 1, 1.1], dtype=np.float32)

    def graph(starts, limits, deltas):
        ragged = tf.ragged.range(starts, limits, deltas)
        splits = tf.identity(ragged.row_splits, name=_TFOUTPUT)
        values = tf.identity(ragged.flat_values, name=_TFOUTPUT1)
        return splits, values

    feed = {_INPUT: starts, _INPUT1: limits, _INPUT2: deltas}
    self._run_test_case(graph, [_OUTPUT, _OUTPUT1], feed)
@check_tf_min_version("1.14", "ragged needs tf 1.14")
@check_opset_min_version(11, "Range")
def test_ragged_range_int(self):
    """RaggedRange conversion with int32 inputs covering empty, ascending, and descending rows."""
    starts = np.array([0, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.int32)
    limits = np.array([-6, -5, -4, -1, 0, 1, 4, 5, 6, 2, -2], dtype=np.int32)
    deltas = np.array([-5, -5, -5, -5, 5, 5, 5, 5, 5, 1, -1], dtype=np.int32)

    def graph(starts, limits, deltas):
        ragged = tf.ragged.range(starts, limits, deltas)
        splits = tf.identity(ragged.row_splits, name=_TFOUTPUT)
        values = tf.identity(ragged.flat_values, name=_TFOUTPUT1)
        return splits, values

    feed = {_INPUT: starts, _INPUT1: limits, _INPUT2: deltas}
    self._run_test_case(graph, [_OUTPUT, _OUTPUT1], feed)
@check_tf_min_version("1.14", "ragged needs tf 1.14")
@check_opset_min_version(11, "Range")
def test_ragged_range_scalar(self):
    """RaggedRange conversion where starts and deltas are scalars broadcast against 1-D limits."""
    starts = np.array(0, dtype=np.int32)
    limits = np.array([5, -1, -1, 2, 7, 100, 4, 5, 6], dtype=np.int32)
    deltas = np.array(1, dtype=np.int32)

    def graph(starts, limits, deltas):
        ragged = tf.ragged.range(starts, limits, deltas)
        splits = tf.identity(ragged.row_splits, name=_TFOUTPUT)
        values = tf.identity(ragged.flat_values, name=_TFOUTPUT1)
        return splits, values

    feed = {_INPUT: starts, _INPUT1: limits, _INPUT2: deltas}
    self._run_test_case(graph, [_OUTPUT, _OUTPUT1], feed)
37433785
@check_opset_min_version(9, "Compress")
37443786
def test_dynamic_partition_both_vector(self):
37453787
data_val = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float32)

tf2onnx/graph.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -951,6 +951,13 @@ def get_shape(self, name):
951951
return shape
952952
return shape
953953

954+
def get_rank(self, name):
    """Return the rank (number of dimensions) of tensor `name`, or None if its shape is unknown."""
    shape = self.get_shape(name)
    return None if shape is None else len(shape)
954961
def set_shape(self, name, val):
955962
"""Set new shape of node."""
956963
if isinstance(val, np.ndarray):

tf2onnx/onnx_opset/tensor.py

Lines changed: 80 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2034,6 +2034,86 @@ def version_11(cls, ctx, node, **kwargs):
20342034
ctx.replace_inputs(node, [expand_node.output[0], sparse_indices, sparse_vals])
20352035

20362036

2037+
@tf_op("RaggedRange")
class RaggedRange:
    @classmethod
    def version_11(cls, ctx, node, **kwargs):
        """Convert TF RaggedRange to ONNX ops.

        RaggedRange(starts, limits, deltas) builds one range per row; the op has two
        outputs: rt_nested_splits (row_splits, int64) and rt_dense_values (the
        flattened range values). starts/limits/deltas may be 1-D or scalar —
        scalars are broadcast against the 1-D inputs.
        """
        starts, limits, deltas = node.input
        data_dtype = ctx.get_dtype(starts)
        data_np_dtype = utils.map_onnx_to_numpy_type(data_dtype)
        data_is_float = np.dtype(data_np_dtype).kind == 'f'

        # Row i has ceil((limits[i] - starts[i]) / deltas[i]) elements (clamped to >= 0 below).
        if data_is_float:
            sub_node = ctx.make_node("Sub", [limits, starts]).output[0]
            div_node = ctx.make_node("Div", [sub_node, deltas]).output[0]
            ceil_node = ctx.make_node("Ceil", [div_node]).output[0]
            row_lens = ctx.make_node("Cast", [ceil_node], attr={'to': TensorProto.INT64}).output[0]
        else:
            # Integer ceil(a/b): ONNX Div truncates, so add 1 when the division isn't exact.
            starts_cast = ctx.make_node("Cast", [starts], attr={'to': TensorProto.INT64}).output[0]
            limits_cast = ctx.make_node("Cast", [limits], attr={'to': TensorProto.INT64}).output[0]
            deltas_cast = ctx.make_node("Cast", [deltas], attr={'to': TensorProto.INT64}).output[0]
            sub_node = ctx.make_node("Sub", [limits_cast, starts_cast]).output[0]
            div_node = ctx.make_node("Div", [sub_node, deltas_cast]).output[0]
            mul_node = ctx.make_node("Mul", [div_node, deltas_cast]).output[0]
            eq_node = ctx.make_node("Equal", [mul_node, sub_node]).output[0]
            ne_node = ctx.make_node("Not", [eq_node]).output[0]
            # we want to round up if it isn't evenly divisible
            offset = ctx.make_node("Cast", [ne_node], attr={'to': TensorProto.INT64}).output[0]
            row_lens = ctx.make_node("Add", [div_node, offset]).output[0]

        const_zero_int64 = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.int64)).output[0]
        # Clamp negative lengths (limit already passed in the delta direction) to 0.
        if ctx.opset <= 11:
            # Max doesn't support int64 until opset 12; round-trip through double instead.
            const_zero_double = ctx.make_const(utils.make_name("const_zero"), np.array(0, dtype=np.float64)).output[0]
            row_lens = ctx.make_node("Cast", [row_lens], attr={'to': TensorProto.DOUBLE}).output[0]
            row_lens = ctx.make_node("Max", [row_lens, const_zero_double]).output[0]
            row_lens = ctx.make_node("Cast", [row_lens], attr={'to': TensorProto.INT64}).output[0]
        else:
            row_lens = ctx.make_node("Max", [row_lens, const_zero_int64]).output[0]

        const_zero_list = ctx.make_const(utils.make_name("const_zero_list"), np.array([0], dtype=np.int64)).output[0]

        # FIX: the attribute was misspelled 'keeepdims', so it was silently ignored and
        # keepdims fell back to the ONNX default of 1. Spell it correctly and pin
        # keepdims=1 explicitly — that matches the prior runtime behavior and keeps
        # max_row_len 1-D for the Mul with inp_shape below.
        max_row_len = ctx.make_node("ReduceMax", [row_lens], attr={'axes': [0], 'keepdims': 1}).output[0]
        inp_shape = ctx.make_node("Shape", [row_lens]).output[0]
        range_len = ctx.make_node("Mul", [max_row_len, inp_shape]).output[0]

        # ORT seems to have a shape inference bug for the Range node. Synthesize
        # [0, 1, ..., range_len - 1] with ConstantOfShape + exclusive CumSum instead.
        one_tensor = helper.make_tensor("value", TensorProto.INT64, dims=[1], vals=[1])
        ones_of_shape = ctx.make_node("ConstantOfShape", [range_len], attr={"value": one_tensor}).output[0]
        range_node = ctx.make_node("CumSum", [ones_of_shape, const_zero_int64], attr={'exclusive': True}).output[0]

        # Index a dense [num_rows, max_row_len] grid, then Compress away positions that
        # fall beyond each row's actual length.
        col_indices_dense = ctx.make_node("Mod", [range_node, max_row_len]).output[0]
        row_indices_dense = ctx.make_node("Div", [range_node, max_row_len]).output[0]
        row_lens_dense = ctx.make_node("Gather", [row_lens, row_indices_dense]).output[0]
        indices_to_keep = ctx.make_node("Less", [col_indices_dense, row_lens_dense]).output[0]
        col_indices = ctx.make_node("Compress", [col_indices_dense, indices_to_keep]).output[0]
        row_indices = ctx.make_node("Compress", [row_indices_dense, indices_to_keep]).output[0]

        # row_splits = concat([0], cumsum(row_lens))
        split_ends = ctx.make_node("CumSum", [row_lens, const_zero_int64]).output[0]
        splits_out = ctx.make_node("Concat", [const_zero_list, split_ends], attr={'axis': 0}).output[0]
        col_indices_cast = ctx.make_node("Cast", [col_indices], attr={'to': data_dtype}).output[0]

        # Broadcast scalar (or unknown-rank) starts/deltas to 1-D so Gather can index rows.
        if ctx.get_rank(starts) != 1:
            starts = ctx.make_node("Expand", [starts, inp_shape]).output[0]

        if ctx.get_rank(deltas) != 1:
            deltas = ctx.make_node("Expand", [deltas, inp_shape]).output[0]

        gather_starts = ctx.make_node("Gather", [starts, row_indices]).output[0]
        gather_deltas = ctx.make_node("Gather", [deltas, row_indices]).output[0]

        # values[k] = starts[row(k)] + col(k) * deltas[row(k)]
        mul_node = ctx.make_node("Mul", [col_indices_cast, gather_deltas], op_name_scope=node.name).output[0]
        dense_vals_out = ctx.make_node("Add", [gather_starts, mul_node], op_name_scope=node.name).output[0]

        ctx.replace_all_inputs(node.output[0], splits_out)
        ctx.replace_all_inputs(node.output[1], dense_vals_out)
        ctx.remove_node(node.name)
20372117
@tf_op("SparseReshape")
20382118
class SparseReshape:
20392119
@classmethod

0 commit comments

Comments
 (0)