Commit 122556e

Implement conversion of RaggedToVariant and RaggedFromVariant in loops (#1503)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 229985e · commit 122556e

File tree

9 files changed: +250 -10 lines changed
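For orientation, the TensorFlow pattern this commit handles is `tf.map_fn` over a `tf.RaggedTensor`, which TF lowers to a while loop that passes rows through `RaggedTensorToVariant` and `RaggedTensorFromVariant`. A minimal sketch of that pattern, adapted from the new tests below (values are illustrative):

```python
import numpy as np
import tensorflow as tf

# Row i of the ragged tensor is values[splits[i]:splits[i + 1]].
splits = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
values = np.arange(10 * 3 * 2, dtype=np.float32).reshape([10, 3, 2])
x = tf.RaggedTensor.from_nested_row_splits(values, [splits], validate=True)

# tf.map_fn over a ragged tensor is what produces the
# RaggedTensorToVariant -> while loop -> RaggedTensorFromVariant graph.
y = tf.map_fn(lambda elem: elem + elem * elem, x)
print(y.row_splits.numpy())         # unchanged: [ 0  3  3  5  9 10]
print(y.flat_values.numpy().shape)  # (10, 3, 2)
```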

tests/backend_test_base.py

Lines changed: 11 additions & 1 deletion

@@ -272,6 +272,16 @@ def get_shape(info):
             if not info.type.tensor_type.HasField("shape"):
                 return None
             return [d.dim_value if d.HasField('dim_value') else -1 for d in info.type.tensor_type.shape.dim]
+        def get_dtype(info):
+            tensor_type = info.type.tensor_type
+            is_seq = False
+            result = None
+            if info.type.HasField("sequence_type"):
+                tensor_type = info.type.sequence_type.elem_type.tensor_type
+                is_seq = True
+            if tensor_type.HasField("elem_type"):
+                result = tensor_type.elem_type
+            return utils.SeqType(result) if is_seq else result
         for info in model_shapes.graph.value_info:
             if info.name == "":
                 continue
@@ -289,7 +299,7 @@ def get_shape(info):
                     self.assertEqual(d1, d2)
                 else:
                     self.assertEqual(onnx_shape, tf2onnx_shape)
-            self.assertEqual(info.type.tensor_type.elem_type, graph.get_dtype(info.name))
+            self.assertEqual(get_dtype(info), graph.get_dtype(info.name))

     def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port,
                       rtol=1e-07, atol=1e-5, mtol=None, convert_var_to_const=True, constant_fold=True,
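The new `get_dtype` helper exists because loop outputs can now be ONNX sequences, where the element dtype sits one level deeper in the value_info than for plain tensors. A standalone sketch of the same unwrapping (the value name here is illustrative):

```python
import onnx
from onnx import TensorProto

# Build a value_info typed as a sequence of float tensors.
seq_type = onnx.TypeProto()
seq_type.sequence_type.elem_type.tensor_type.elem_type = TensorProto.FLOAT
info = onnx.helper.make_value_info("loop_seq", seq_type)

# For tensors the dtype is at type.tensor_type.elem_type; for sequences
# it is nested under type.sequence_type.elem_type.tensor_type.elem_type.
tensor_type = info.type.tensor_type
if info.type.HasField("sequence_type"):
    tensor_type = info.type.sequence_type.elem_type.tensor_type
print(tensor_type.elem_type == TensorProto.FLOAT)  # True
```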

tests/test_backend.py

Lines changed: 36 additions & 0 deletions

@@ -4581,6 +4581,42 @@ def func(starts, limits, deltas):
         self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: starts_val, _INPUT1: limits_val,
                                                         _INPUT2: deltas_val})

+    @check_tf_min_version("2.0", "ragged variant needs tf 2.0")
+    @check_opset_min_version(13, "Loop over tensor sequences")
+    def test_ragged_to_variant(self):
+        splits_val = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
+        dense_vals_val = np.arange(10 * 3 * 2, dtype=np.float32).reshape([10, 3, 2])
+
+        def fn(elem):
+            res = elem + elem * elem
+            return res
+
+        def func(splits, rt_dense_values):
+            x = tf.RaggedTensor.from_nested_row_splits(rt_dense_values, [splits], validate=True)
+            y = tf.map_fn(fn, x)
+            return tf.identity(y.row_splits, name=_TFOUTPUT), tf.identity(y.flat_values, name=_TFOUTPUT1)
+        self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: splits_val, _INPUT1: dense_vals_val})
+
+    @check_tf_min_version("2.0", "ragged variant needs tf 2.0")
+    @check_opset_min_version(13, "Loop over tensor sequences")
+    def test_ragged_to_variant_unknown_shape(self):
+        splits_val = np.array([0, 3, 3, 5, 9, 10], dtype=np.int64)
+        dense_vals_shape = np.array([10, 3, 2], dtype=np.int32)
+        splits_pads_val = np.array([[0, 0]], dtype=np.int32)
+
+        def fn(elem):
+            res = elem + elem * elem
+            return res
+
+        def func(splits, rt_dense_values_shape, splits_pads):
+            rt_dense_values = tf.ones(rt_dense_values_shape, dtype=tf.int32)
+            splits = tf.pad(splits, splits_pads)
+            x = tf.RaggedTensor.from_nested_row_splits(rt_dense_values, [splits], validate=True)
+            y = tf.map_fn(fn, x)
+            return tf.identity(y.row_splits, name=_TFOUTPUT), tf.identity(y.flat_values, name=_TFOUTPUT1)
+        self._run_test_case(func, [_OUTPUT, _OUTPUT1],
+                            {_INPUT: splits_val, _INPUT1: dense_vals_shape, _INPUT2: splits_pads_val})
+
     @check_opset_min_version(9, "Compress")
     def test_dynamic_partition_both_vector(self):
         data_val = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float32)
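In plain numpy terms, both tests check that the converted loop applies `fn` to each variable-length row while leaving the row structure alone. A rough model of the expected result:

```python
import numpy as np

splits = np.array([0, 3, 3, 5, 9, 10], dtype=np.int32)
values = np.arange(10 * 3 * 2, dtype=np.float32).reshape([10, 3, 2])
fn = lambda elem: elem + elem * elem

# Apply fn row by row and reassemble: flat values change, splits do not.
rows = [fn(values[splits[i]:splits[i + 1]]) for i in range(len(splits) - 1)]
flat_values = np.concatenate(rows, axis=0)
assert flat_values.shape == values.shape
assert np.allclose(flat_values, fn(values))  # fn is elementwise here
```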

tf2onnx/graph.py

Lines changed: 2 additions & 0 deletions

@@ -469,6 +469,8 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
         # Used by the tflite while loop handler
         self.scan_outputs = []
         self.func_inputs = []
+        self.ragged_variant_list_reads = []
+        self.ragged_variant_list_writes = []

         self._target = set(target)
         self._dtypes = dtypes

tf2onnx/onnx_opset/controlflow.py

Lines changed: 87 additions & 8 deletions

@@ -19,6 +19,7 @@
 from tf2onnx import utils
 from tf2onnx.handler import tf_op
 from tf2onnx.tf_loader import find_function
+from tf2onnx.graph_builder import GraphBuilder


 logger = logging.getLogger(__name__)
@@ -401,6 +402,7 @@ def version_7(cls, ctx, node, **kwargs):
         cond_input_to_state_var = {}
         scan_outputs = []
         input_idx_to_remove = []
+        idx_to_ragged_writes = dict(body.ragged_variant_list_writes)
         # remove TensorListReserve
         for idx, name in enumerate(tf_while_inputs):
             if idx == 1:
@@ -416,9 +418,15 @@ def version_7(cls, ctx, node, **kwargs):
                 # there is no equivalent step in onnx and we should remove it.
                 output_shape = None
                 output_dtype = n.get_attr_value("element_dtype")
+                is_ragged = False
                 if n.type == "TensorListReserve" and n.inputs[0].is_const() and not n.inputs[0].is_scalar():
                     output_shape = [-1] + n.inputs[0].get_tensor_value(as_list=True)
-                scan_outputs.append((idx, n, output_shape, output_dtype))
+                if idx in idx_to_ragged_writes:
+                    output_shape = None
+                    output_dtype = body.get_dtype(idx_to_ragged_writes[idx].input[0])
+                    is_ragged = True
+                    loop_vars.append(name)
+                scan_outputs.append((idx, n, output_shape, output_dtype, is_ragged))
                 continue

             # tensor arrays we read from can't be loop_vars and we fetch them from the outer context instead
@@ -437,8 +445,29 @@ def version_7(cls, ctx, node, **kwargs):
             del body.outputs[idx]

         scan_output_names = []
-        # remove tensor array that are passed in to the loop
-        for idx, n, output_shape, output_dtype in reversed(scan_outputs):
+        ragged_scan_output_names = []
+        ragged_scan_output_to_len = {}
+
+        # remove tensor arrays that are passed in to the loop
+        for idx, n, output_shape, output_dtype, is_ragged in reversed(scan_outputs):
+            if is_ragged:
+                out = n.output[0]
+                ctx.remove_node(n.name)
+                seq_empty = ctx.make_node("SequenceEmpty", [], attr={'dtype': output_dtype}, name=n.name,
+                                          outputs=[out], shapes=[None], dtypes=[utils.SeqType(output_dtype)])
+                ctx.replace_all_inputs(n.output[0], seq_empty.output[0])
+                # Ragged tensors also must track the length of each row
+                output_shapes.append([-1])
+                output_dtypes.append(TensorProto.INT64)
+                output_shapes[idx] = None
+                output_dtypes[idx] = utils.SeqType(output_dtype)
+                body_ragged_name = utils.make_name("ragged_scan_output")
+                external_ragged_name = utils.make_name("ragged_output")
+                scan_output_names.append(body_ragged_name)
+                output_names.append(external_ragged_name)
+                ragged_scan_output_names.append(body_ragged_name)
+                ragged_scan_output_to_len[output_names[idx]] = external_ragged_name
+                continue
             ctx.remove_node(n.name)
             # make the node output bad
             ctx.replace_all_inputs(n.output[0], "@@ALLOC")  # ops=ctx.get_nodes()
@@ -475,11 +504,16 @@ def version_7(cls, ctx, node, **kwargs):

         # shift output consumers
         for k, v in output_map.items():
-            ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()
+            if k not in ragged_scan_output_to_len.values():
+                ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()
+
+        ragged_scan_output_to_len = {output_map[k]: output_map[v] for k, v in ragged_scan_output_to_len.items()}

         wire_while_body(ctx, body, loop_node, body_input_to_state_var, cond_input_to_state_var, output_shapes,
-                        output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, scan_output_names)
+                        output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, scan_output_names,
+                        ragged_scan_output_names)

+        loop_node.ragged_scan_output_to_len = ragged_scan_output_to_len
         # if there was a tensorflow variant type, bind in a real type here
         # FIXME: I don't think this is needed anymore
         for i, n in enumerate(body.inputs):
@@ -488,7 +522,8 @@ def version_7(cls, ctx, node, **kwargs):


 def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_to_state_var, output_shapes,
-                    output_dtypes, scope, parent, cond_graph, tf_while_inputs, scan_output_names):
+                    output_dtypes, scope, parent, cond_graph, tf_while_inputs, scan_output_names,
+                    ragged_scan_output_names):
     """Wire subgraph graph into main."""
     remove_parents = []
     to_remove = []
@@ -519,8 +554,25 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_

     # this is a tensor array write - make it an identity
     scan_outputs = []
+    ragged_scan_outputs_cnt = 0
+    names_to_scan_outputs = {}
+
     for node in g.get_nodes():
         if node.type == "TensorListSetItem":
+            if node.inputs[2].type == "RaggedTensorToVariant":
+                node.type = "SequenceInsert"
+                row_content = node.inputs[2].input[0]
+                g.replace_inputs(node, [node.input[0], row_content])
+                g.set_shape(node.output[0], g.get_shape(node.input[1]))
+                g.set_dtype(node.output[0], utils.SeqType(g.get_dtype(node.input[1])))
+                dense_shape = g.make_node("Shape", [row_content]).output[0]
+                zero_const = g.make_const(utils.make_name("zero_const"), np.array(0, np.int64)).output[0]
+                row_length = g.make_node("Gather", [dense_shape, zero_const]).output[0]
+                row_length_id = g.make_node("Identity", [row_length])
+                scan_outputs.append(row_length_id.output[0])
+                names_to_scan_outputs[ragged_scan_output_names[ragged_scan_outputs_cnt]] = row_length_id.output[0]
+                ragged_scan_outputs_cnt += 1
+                continue
             remove_parents.append(node.input[0])
             node.type = "Identity"
             g.set_shape(node.output[0], g.get_shape(node.input[2]))
@@ -531,8 +583,9 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_
     if len(scan_outputs) != len(scan_output_names):
         raise ValueError("While loop couldn't find scan output index for nodes")

-    names_to_scan_outputs = {}
     for output in scan_outputs:
+        if output in names_to_scan_outputs.values():
+            continue
         last_output = output
         consumers = g.find_output_consumers(last_output)
         while consumers:
@@ -547,8 +600,9 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_

     # Reorder scan outputs
     scan_outputs = [names_to_scan_outputs[name] for name in scan_output_names]
+
+    # Use shapes from subgraph if loop node shapes for scan outputs are missing
     for i in range(-len(scan_output_names), 0):
-        # Use shapes from subgraph if loop node shapes for scan outputs are missing
         if loop_node.output_shapes[i] is None:
             shape = g.get_shape(scan_outputs[i])
             if shape is not None:
@@ -580,6 +634,31 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_
         if node.type in ["Identity"]:
             g.set_dtype(o, node.inputs[0].output_dtypes[0])

+    for node in g.ragged_variant_list_reads:
+        # Requires opset 11
+        gather = node.inputs[0]
+        inp = gather.inputs[0]
+        while inp.type == "Identity":
+            inp = inp.inputs[0]
+        err_msg1 = "Could not find corresponding RaggedTensorToVariant for node %s" % node.name
+        err_msg2 = "Input to RaggedTensorToVariant for loop has batched_input=False for node %s" % inp.name
+        err_msg3 = "RAGGED_RANK != 1 for RaggedTensorToVariant node %s" % node.name
+        utils.make_sure(inp.type == "RaggedTensorToVariant", err_msg1)
+        utils.make_sure(inp.get_attr_value("batched_input"), err_msg2)
+        utils.make_sure(inp.get_attr_value("RAGGED_RANK") == 1, err_msg3)
+        idx = gather.input[1]
+        idx_unsq = GraphBuilder(g).make_unsqueeze({'data': idx, 'axes': [0]})
+        np_dtype = utils.map_onnx_to_numpy_type(g.get_dtype(idx_unsq))
+        const_one = g.make_const(utils.make_name("const_1"), np.array(1, np_dtype)).output[0]
+        idx_plus_1 = g.make_node("Add", [idx_unsq, const_one]).output[0]
+        splits, values = inp.input
+        start = g.make_node("Gather", [splits, idx_unsq]).output[0]
+        end = g.make_node("Gather", [splits, idx_plus_1]).output[0]
+        np_dtype2 = utils.map_onnx_to_numpy_type(g.get_dtype(splits))
+        axes = g.make_const(utils.make_name("const_zero"), np.array([0], np_dtype2)).output[0]
+        sliced_vals = g.make_node("Slice", [values, start, end, axes]).output[0]
+        g.replace_all_inputs(node.output[0], sliced_vals)
+
     return g
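The `ragged_variant_list_reads` rewrite at the end of `wire_while_body` replaces each per-iteration variant read with a `Gather` on the splits tensor and a `Slice` on the flat values. The numpy equivalent of the ops it emits, for a rank-1 ragged tensor:

```python
import numpy as np

def read_ragged_row(splits, values, idx):
    # Mirrors the emitted ONNX: Gather(splits, [idx]) and Gather(splits, [idx + 1])
    # give start/end, then Slice(values, start, end, axes=[0]) cuts out the row.
    return values[splits[idx]:splits[idx + 1]]

splits = np.array([0, 3, 3, 5], dtype=np.int64)
values = np.arange(5 * 2, dtype=np.float32).reshape([5, 2])
print(read_ragged_row(splits, values, 0).shape)  # (3, 2)
print(read_ragged_row(splits, values, 1).shape)  # (0, 2), an empty row
```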
tf2onnx/onnx_opset/tensor.py

Lines changed: 57 additions & 0 deletions

@@ -2584,6 +2584,63 @@ def version_11(cls, ctx, node, **kwargs):
         ctx.remove_node(node.name)


+@tf_op("RaggedTensorFromVariant")
+class RaggedTensorFromVariant:
+    @classmethod
+    def version_13(cls, ctx, node, **kwargs):
+        inp = node.inputs[0]
+        if inp.is_while():
+            row_lengths = inp.ragged_scan_output_to_len.get(node.input[0])
+            utils.make_sure(row_lengths is not None, "Couldn't find lengths for %s node %s" % (node.type, node.name))
+            dense_values = ctx.make_node("ConcatFromSequence", [node.input[0]], attr={'axis': 0}).output[0]
+            const_zero = ctx.make_const(utils.make_name("const_zero"), np.array(0, np.int64)).output[0]
+            const_zero_unsq = ctx.make_const(utils.make_name("const_zero"), np.array([0], np.int64)).output[0]
+            row_splits = ctx.make_node("CumSum", [row_lengths, const_zero]).output[0]
+            row_splits_w_zero = ctx.make_node("Concat", [const_zero_unsq, row_splits], attr={'axis': 0}).output[0]
+            idx_dtype = ctx.get_dtype(node.output[0])
+            if idx_dtype != TensorProto.INT64:
+                row_splits_w_zero = ctx.make_node("Cast", [row_splits_w_zero], attr={'to': idx_dtype}).output[0]
+            ctx.replace_all_inputs(node.output[0], row_splits_w_zero)
+            ctx.replace_all_inputs(node.output[1], dense_values)
+            ctx.remove_node(node.name)
+            return
+
+        utils.make_sure(inp.type == "Gather", "RaggedTensorFromVariant only supported after TensorListGetItem")
+        variant = inp.inputs[0]
+        err_msg = "RaggedTensorFromVariant only supported if variant is a graph input"
+        # Variant input will be found during loop conversion
+        utils.make_sure(variant.type == "Placeholder", err_msg)
+        ctx.ragged_variant_list_reads.append(node)
+
+
+@tf_op("RaggedTensorToVariant")
+class RaggedTensorToVariant:
+    @classmethod
+    def version_13(cls, ctx, node, **kwargs):
+        cons = ctx.find_output_consumers(node.output[0])
+        err_msg = "RaggedTensorToVariant only supported as input/output to loops"
+        utils.make_sure(len(cons) == 1, err_msg)
+        if cons[0].type == "TensorListFromTensor":
+            # Will be dealt with in loop
+            cons = ctx.find_output_consumers(cons[0].output[0])
+            utils.make_sure(all(n.is_while() for n in cons), err_msg)
+            return
+        utils.make_sure(cons[0].type == "TensorListSetItem", err_msg)
+        tensor_set_item = cons[0]
+        list_output = tensor_set_item.output[0]
+        cons = ctx.find_output_consumers(list_output)
+        while len(cons) == 1 and cons[0].type == "Identity":
+            list_output = cons[0].output[0]
+            cons = ctx.find_output_consumers(list_output)
+        utils.make_sure(not cons, err_msg)
+        utils.make_sure(list_output in ctx.outputs, err_msg)
+        err_msg2 = "RaggedTensorToVariant within loop requires RAGGED_RANK=0"
+        err_msg3 = "RaggedTensorToVariant within loop requires batched_input=False"
+        utils.make_sure(node.get_attr_value("RAGGED_RANK") == 0, err_msg2)
+        utils.make_sure(not node.get_attr_value("batched_input"), err_msg3)
+        ctx.ragged_variant_list_writes.append((ctx.outputs.index(list_output), node))
+
+
 @tf_op("SparseReshape")
 class SparseReshape:
     @classmethod
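The loop-output branch of `RaggedTensorFromVariant` rebuilds `row_splits` from the per-row lengths that the loop scans out: a cumulative sum with a zero prepended, which is what the `CumSum` plus `Concat` pair computes. A numpy sketch:

```python
import numpy as np

# Per-row lengths collected as a scan output of the ONNX Loop.
row_lengths = np.array([3, 0, 2, 4, 1], dtype=np.int64)

# CumSum, then Concat([0], ...) in front, reconstructs the splits vector.
row_splits = np.concatenate([[0], np.cumsum(row_lengths)])
print(row_splits)  # [ 0  3  3  5  9 10]
```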

tf2onnx/rewriter/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -24,6 +24,7 @@
 from tf2onnx.rewriter.conv2d_with_add_rewriter import rewrite_biasadd_with_conv2d
 from tf2onnx.rewriter.quantization_ops_rewriter import rewrite_quantize_and_dequantize
 from tf2onnx.rewriter.layer_normalization_rewriter import rewrite_layer_normalization
+from tf2onnx.rewriter.ragged_variant_shape_rewriter import rewrite_ragged_variant_shape


 __all__ = [
@@ -48,5 +49,6 @@
     "rewrite_biasadd_with_conv2d",
     "rewrite_quantize_and_dequantize",
     "rewrite_layer_normalization",
-    "rewrite_conv_dilations"
+    "rewrite_conv_dilations",
+    "rewrite_ragged_variant_shape"
 ]

tf2onnx/rewriter/ragged_variant_shape_rewriter.py (new file)

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+# SPDX-License-Identifier: Apache-2.0
+
+
+"""
+tf2onnx.rewriter - RaggedTensorToVariant -> Shape pattern
+"""
+
+import numpy as np
+from tf2onnx import utils
+from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
+
+
+# pylint: disable=missing-docstring
+
+
+def rewrite_ragged_variant_shape(g, ops):
+    pattern1 = \
+        OpTypePattern('Shape', name='shape', inputs=[
+            OpTypePattern('RaggedTensorToVariant', name='raggedtovariant')
+        ])
+
+    pattern_list = [pattern1]
+    for pattern in pattern_list:
+        matcher = GraphMatcher(pattern)
+        match_results = list(matcher.match_ops(ops))
+        for match in match_results:
+            shape = match.get_op('shape')
+            raggedtovariant = match.get_op('raggedtovariant')
+            if raggedtovariant.get_attr_value("batched_input") != 1:
+                continue
+            if raggedtovariant.get_attr_value("RAGGED_RANK") != 1:
+                continue
+            # Shape of batched variant from ragged is same as number of splits minus 1
+            g.replace_inputs(shape, [raggedtovariant.input[0]])
+            np_dtype = utils.map_onnx_to_numpy_type(g.get_dtype(shape.output[0]))
+            const_one = g.make_const(utils.make_name("const_one"), np.array(1, np_dtype)).output[0]
+            g.insert_new_node_on_output("Sub", shape.output[0], inputs=[shape.output[0], const_one])
+
+    return ops
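The rewrite works because the shape of a batched ragged variant is just the number of rows, and the splits tensor always has one more entry than there are rows. A quick check of the identity the `Shape` plus `Sub` pair relies on:

```python
import numpy as np

splits = np.array([0, 3, 3, 5, 9, 10], dtype=np.int64)

# After the rewrite, Shape runs on the splits tensor instead of the variant:
# Shape(splits) == [6], and Sub(..., 1) == [5] == [number of ragged rows].
nrows = np.array(splits.shape, dtype=np.int64) - 1
print(nrows)  # [5]
```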

tf2onnx/tfonnx.py

Lines changed: 1 addition & 0 deletions

@@ -619,6 +619,7 @@ def compat_handler(ctx, node, **kwargs):
         rewrite_biasadd_with_conv2d,
         rewrite_layer_normalization,
         rewrite_gemm,
+        rewrite_ragged_variant_shape,
     ]

     if custom_rewriter is not None:
