@@ -421,7 +421,7 @@ def version_7(cls, ctx, node, **kwargs):
421
421
del output_names [idx ]
422
422
del body .outputs [idx ]
423
423
424
- removed_scan_outputs = {}
424
+ scan_output_names = []
425
425
# remove tensor array that are passed in to the loop
426
426
for idx , n in reversed (to_remove ):
427
427
ctx .remove_node (n .name )
@@ -430,19 +430,15 @@ def version_7(cls, ctx, node, **kwargs):
430
430
del body .func_inputs [idx ]
431
431
del cond_graph .func_inputs [idx ]
432
432
del tf_while_inputs [idx ]
433
- # save the index of the scan output
434
- removed_scan_outputs [body .outputs [idx ]] = idx
433
+ scan_output_names .append (body .outputs [idx ])
435
434
del body .outputs [idx ]
436
- # FIXME: Output shapes may be in wrong order if there are multiple scan outputs
437
435
output_shapes .append (output_shapes [idx ])
438
436
output_dtypes .append (output_dtypes [idx ])
439
437
output_names .append (output_names [idx ])
440
438
del output_shapes [idx ]
441
439
del output_dtypes [idx ]
442
440
del output_names [idx ]
443
441
444
- utils .make_sure (len (removed_scan_outputs ) <= 1 , "converter only supports while loops with a single scan output" )
445
-
446
442
ctx .remove_node (node .name )
447
443
448
444
# In onnx 'cond' is a variable, not a function. We need to inject the subgraph into the main graph
@@ -467,7 +463,7 @@ def version_7(cls, ctx, node, **kwargs):
467
463
ctx .replace_all_inputs (k , v ) # ops=ctx.get_nodes()
468
464
469
465
wire_while_body (ctx , body , loop_node .inputs , body_input_to_state_var , cond_input_to_state_var , output_shapes ,
470
- output_dtypes , body_name , node .name , cond_graph , tf_while_inputs , removed_scan_outputs )
466
+ output_dtypes , body_name , node .name , cond_graph , tf_while_inputs , scan_output_names )
471
467
472
468
# if there was a tensorflow variant type, bind in a real type here
473
469
# FIXME: I don't think this is needed anymore
@@ -477,7 +473,7 @@ def version_7(cls, ctx, node, **kwargs):
477
473
478
474
479
475
def wire_while_body (parent_g , g , loop_node_inputs , body_input_to_state_var , cond_input_to_state_var , output_shapes ,
480
- output_dtypes , scope , parent , cond_graph , tf_while_inputs , removed_scan_outputs ):
476
+ output_dtypes , scope , parent , cond_graph , tf_while_inputs , scan_output_names ):
481
477
"""Wire subgraph graph into main."""
482
478
remove_parents = []
483
479
to_remove = []
@@ -521,9 +517,10 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
521
517
g .replace_inputs (node , [node .input [2 ]])
522
518
scan_outputs .append (node .output [0 ])
523
519
524
- if len (scan_outputs ) != len (removed_scan_outputs ):
520
+ if len (scan_outputs ) != len (scan_output_names ):
525
521
raise ValueError ("While loop couldn't find scan output index for nodes" )
526
522
523
+ names_to_scan_outputs = {}
527
524
for output in scan_outputs :
528
525
last_output = output
529
526
consumers = g .find_output_consumers (last_output )
@@ -533,10 +530,12 @@ def wire_while_body(parent_g, g, loop_node_inputs, body_input_to_state_var, cond
533
530
raise ValueError ("While loop couldn't find scan output index for node " + node .name )
534
531
last_output = node .output [0 ]
535
532
consumers = g .find_output_consumers (last_output )
536
- if last_output not in removed_scan_outputs :
533
+ if last_output not in scan_output_names :
537
534
raise ValueError ("While loop couldn't find scan output index for node " + node .name )
538
- # TODO: store index to ensure scan outputs are in correct order for multiple outputs
539
- # initial_output_index = removed_scan_outputs[last_output]
535
+ names_to_scan_outputs [last_output ] = output
536
+
537
+ # Reorder scan outputs
538
+ scan_outputs = [names_to_scan_outputs [name ] for name in scan_output_names ]
540
539
541
540
# remove all nodes feeding to TensorListSetItem's reserved tensor
542
541
while remove_parents :
0 commit comments