@@ -19,6 +19,7 @@
from tf2onnx import utils
from tf2onnx.handler import tf_op
from tf2onnx.tf_loader import find_function
+from tf2onnx.graph_builder import GraphBuilder


logger = logging.getLogger(__name__)
@@ -401,6 +402,7 @@ def version_7(cls, ctx, node, **kwargs):
        cond_input_to_state_var = {}
        scan_outputs = []
        input_idx_to_remove = []
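+        # loop-variable index -> the corresponding ragged TensorList write in the body graph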
+        idx_to_ragged_writes = dict(body.ragged_variant_list_writes)
        # remove TensorListReserve
        for idx, name in enumerate(tf_while_inputs):
            if idx == 1:
@@ -416,9 +418,15 @@ def version_7(cls, ctx, node, **kwargs):
                # there is no equivalent step in onnx and we should remove it.
                output_shape = None
                output_dtype = n.get_attr_value("element_dtype")
+                is_ragged = False
                if n.type == "TensorListReserve" and n.inputs[0].is_const() and not n.inputs[0].is_scalar():
                    output_shape = [-1] + n.inputs[0].get_tensor_value(as_list=True)
-                scan_outputs.append((idx, n, output_shape, output_dtype))
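+                # ragged lists have no static shape and are carried as sequence-typed loop variables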
+                if idx in idx_to_ragged_writes:
+                    output_shape = None
+                    output_dtype = body.get_dtype(idx_to_ragged_writes[idx].input[0])
+                    is_ragged = True
+                    loop_vars.append(name)
+                scan_outputs.append((idx, n, output_shape, output_dtype, is_ragged))
                continue

            # tensor arrays we read from can't be loop_vars and we fetch them from the outer context instead
@@ -437,8 +445,29 @@ def version_7(cls, ctx, node, **kwargs):
            del body.outputs[idx]

        scan_output_names = []
-        # remove tensor array that are passed in to the loop
-        for idx, n, output_shape, output_dtype in reversed(scan_outputs):
+        ragged_scan_output_names = []
+        ragged_scan_output_to_len = {}
+
+        # remove tensor arrays that are passed in to the loop
+        for idx, n, output_shape, output_dtype, is_ragged in reversed(scan_outputs):
+            if is_ragged:
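+                # the reserve node becomes an empty sequence; the body appends one row per iteration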
+                out = n.output[0]
+                ctx.remove_node(n.name)
+                seq_empty = ctx.make_node("SequenceEmpty", [], attr={'dtype': output_dtype}, name=n.name,
+                                          outputs=[out], shapes=[None], dtypes=[utils.SeqType(output_dtype)])
+                ctx.replace_all_inputs(n.output[0], seq_empty.output[0])
+                # Ragged tensors also must track the length of each row
+                output_shapes.append([-1])
+                output_dtypes.append(TensorProto.INT64)
+                output_shapes[idx] = None
+                output_dtypes[idx] = utils.SeqType(output_dtype)
+                body_ragged_name = utils.make_name("ragged_scan_output")
+                external_ragged_name = utils.make_name("ragged_output")
+                scan_output_names.append(body_ragged_name)
+                output_names.append(external_ragged_name)
+                ragged_scan_output_names.append(body_ragged_name)
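+                # map the sequence output to the extra scan output holding its per-row lengths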
+                ragged_scan_output_to_len[output_names[idx]] = external_ragged_name
+                continue
            ctx.remove_node(n.name)
            # make the node output bad
            ctx.replace_all_inputs(n.output[0], "@@ALLOC")  # ops=ctx.get_nodes()
@@ -475,11 +504,16 @@ def version_7(cls, ctx, node, **kwargs):

        # shift output consumers
        for k, v in output_map.items():
-            ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()
+            if k not in ragged_scan_output_to_len.values():
+                ctx.replace_all_inputs(k, v)  # ops=ctx.get_nodes()
+
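+        # re-key the mapping to the Loop node's renamed outputs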
+        ragged_scan_output_to_len = {output_map[k]: output_map[v] for k, v in ragged_scan_output_to_len.items()}

        wire_while_body(ctx, body, loop_node, body_input_to_state_var, cond_input_to_state_var, output_shapes,
-                        output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, scan_output_names)
+                        output_dtypes, body_name, node.name, cond_graph, tf_while_inputs, scan_output_names,
+                        ragged_scan_output_names)

+        loop_node.ragged_scan_output_to_len = ragged_scan_output_to_len
        # if there was a tensorflow variant type, bind in a real type here
        # FIXME: I don't think this is needed anymore
        for i, n in enumerate(body.inputs):
@@ -488,7 +522,8 @@ def version_7(cls, ctx, node, **kwargs):


def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_to_state_var, output_shapes,
-                    output_dtypes, scope, parent, cond_graph, tf_while_inputs, scan_output_names):
+                    output_dtypes, scope, parent, cond_graph, tf_while_inputs, scan_output_names,
+                    ragged_scan_output_names):
    """Wire subgraph graph into main."""
    remove_parents = []
    to_remove = []
@@ -519,8 +554,25 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_

    # this is a tensor array write - make it an identity
    scan_outputs = []
+    ragged_scan_outputs_cnt = 0
+    names_to_scan_outputs = {}
+
    for node in g.get_nodes():
        if node.type == "TensorListSetItem":
+            if node.inputs[2].type == "RaggedTensorToVariant":
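+                # a ragged row write becomes SequenceInsert of the raw row values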
+                node.type = "SequenceInsert"
+                row_content = node.inputs[2].input[0]
+                g.replace_inputs(node, [node.input[0], row_content])
+                g.set_shape(node.output[0], g.get_shape(node.input[1]))
+                g.set_dtype(node.output[0], utils.SeqType(g.get_dtype(node.input[1])))
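+                # record the row's length (dim 0 of the values) as the matching scan output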
+                dense_shape = g.make_node("Shape", [row_content]).output[0]
+                zero_const = g.make_const(utils.make_name("zero_const"), np.array(0, np.int64)).output[0]
+                row_length = g.make_node("Gather", [dense_shape, zero_const]).output[0]
+                row_length_id = g.make_node("Identity", [row_length])
+                scan_outputs.append(row_length_id.output[0])
+                names_to_scan_outputs[ragged_scan_output_names[ragged_scan_outputs_cnt]] = row_length_id.output[0]
+                ragged_scan_outputs_cnt += 1
+                continue
            remove_parents.append(node.input[0])
            node.type = "Identity"
            g.set_shape(node.output[0], g.get_shape(node.input[2]))
@@ -531,8 +583,9 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_
    if len(scan_outputs) != len(scan_output_names):
        raise ValueError("While loop couldn't find scan output index for nodes")

-    names_to_scan_outputs = {}
    for output in scan_outputs:
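+        # ragged row-length outputs were already mapped to their names above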
+        if output in names_to_scan_outputs.values():
+            continue
        last_output = output
        consumers = g.find_output_consumers(last_output)
        while consumers:
@@ -547,8 +600,9 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_

    # Reorder scan outputs
    scan_outputs = [names_to_scan_outputs[name] for name in scan_output_names]
+
+    # Use shapes from subgraph if loop node shapes for scan outputs are missing
    for i in range(-len(scan_output_names), 0):
-        # Use shapes from subgraph if loop node shapes for scan outputs are missing
        if loop_node.output_shapes[i] is None:
            shape = g.get_shape(scan_outputs[i])
            if shape is not None:
@@ -580,6 +634,31 @@ def wire_while_body(parent_g, g, loop_node, body_input_to_state_var, cond_input_
        if node.type in ["Identity"]:
            g.set_dtype(o, node.inputs[0].output_dtypes[0])

+    for node in g.ragged_variant_list_reads:
+        # Requires opset 11
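+        # a read of row idx becomes values[splits[idx]:splits[idx + 1]]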
+        gather = node.inputs[0]
+        inp = gather.inputs[0]
+        while inp.type == "Identity":
+            inp = inp.inputs[0]
+        err_msg1 = "Could not find corresponding RaggedTensorToVariant for node %s" % node.name
+        err_msg2 = "Input to RaggedTensorToVariant for loop has batched_input=False for node %s" % inp.name
+        err_msg3 = "RAGGED_RANK != 1 for RaggedTensorToVariant node %s" % node.name
+        utils.make_sure(inp.type == "RaggedTensorToVariant", err_msg1)
+        utils.make_sure(inp.get_attr_value("batched_input"), err_msg2)
+        utils.make_sure(inp.get_attr_value("RAGGED_RANK") == 1, err_msg3)
+        idx = gather.input[1]
+        idx_unsq = GraphBuilder(g).make_unsqueeze({'data': idx, 'axes': [0]})
+        np_dtype = utils.map_onnx_to_numpy_type(g.get_dtype(idx_unsq))
+        const_one = g.make_const(utils.make_name("const_1"), np.array(1, np_dtype)).output[0]
+        idx_plus_1 = g.make_node("Add", [idx_unsq, const_one]).output[0]
+        splits, values = inp.input
+        start = g.make_node("Gather", [splits, idx_unsq]).output[0]
+        end = g.make_node("Gather", [splits, idx_plus_1]).output[0]
+        np_dtype2 = utils.map_onnx_to_numpy_type(g.get_dtype(splits))
+        axes = g.make_const(utils.make_name("const_zero"), np.array([0], np_dtype2)).output[0]
+        sliced_vals = g.make_node("Slice", [values, start, end, axes]).output[0]
+        g.replace_all_inputs(node.output[0], sliced_vals)
+
    return g