@@ -34,7 +34,7 @@ def get_inputs_for_current_iteration(g, input_id, iter_index):
34
34
35
35
36
36
def create_loop_body_graph (parent_g , gather_input_ids , output_data_type , output_shape , trip_count_input_ids ,
37
- rank , loop_name ):
37
+ rank ):
38
38
g = parent_g .create_new_graph_with_same_config ()
39
39
g .parent_graph = parent_g
40
40
iter_name = utils .make_name ("i" )
@@ -112,9 +112,9 @@ def create_if_op(g, input_ids, output_data_type, output_shape):
112
112
out_name = utils .port_name (op_name )
113
113
114
114
# output a scalar
115
- if_node = g . make_node ( "If" , [ input_ids [ 0 ]], outputs = [ out_name ], name = op_name , skip_conversion = True )
116
- if_node . set_body_graph_as_attr ( "then_branch " , true_graph )
117
- if_node . set_body_graph_as_attr ( "else_branch" , false_graph )
115
+ branches = { "then_branch" : true_graph , "else_branch" : false_graph }
116
+ if_node = g . make_node ( "If " , [ input_ids [ 0 ]], outputs = [ out_name ], name = op_name ,
117
+ skip_conversion = True , branches = branches )
118
118
return if_node , out_name
119
119
120
120
@@ -152,12 +152,11 @@ def create_loop_op(g, gather_input_ids, output_type, output_shape, trip_count_in
152
152
cond_var_name , # termination condition
153
153
fake_val_name # initial value of loop-carried dependencies
154
154
]
155
+ loop_body = create_loop_body_graph (g , gather_input_ids , output_type , output_shape , trip_count_input_ids , rank )
155
156
# define an extra scan output
157
+ branches = {"body" : loop_body }
156
158
loop_node = g .make_node ("Loop" , loop_inputs , output_count = 2 , op_name_scope = "select_loop" ,
157
- skip_conversion = False )
158
- loop_body = create_loop_body_graph (g , gather_input_ids , output_type , output_shape , trip_count_input_ids ,
159
- rank , loop_node .name )
160
- loop_node .set_body_graph_as_attr ("body" , loop_body )
159
+ skip_conversion = False , branches = branches )
161
160
return loop_node
162
161
163
162
@@ -223,8 +222,9 @@ def make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, dt
223
222
224
223
# loop
225
224
loop_inputs = [trip_count_node .output [0 ], cond_name , start ]
226
- loop_node = ctx .make_node ("Loop" , loop_inputs , output_count = 2 , op_name_scope = base_name , name = "loop" )
227
- loop_node .set_body_graph_as_attr ("body" , g )
225
+ branches = {"body" : g }
226
+ loop_node = ctx .make_node ("Loop" , loop_inputs ,
227
+ output_count = 2 , op_name_scope = base_name , name = "loop" , branches = branches )
228
228
229
229
ctx .make_node ("Identity" , [loop_node .output [1 ]], name = base_name , shapes = [shape ], dtypes = [dtype ], outputs = [output ])
230
230
@@ -404,15 +404,16 @@ def version_1(cls, ctx, node, **kwargs):
404
404
ctx .remove_node (node .name )
405
405
406
406
# replace the original node
407
- if_node = ctx .make_node ("If" , node .input [:1 ], name = node .name , output_count = len (output_shapes ),
408
- shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True )
409
-
407
+ branches = {}
410
408
for branch in ["then_branch" , "else_branch" ]:
411
409
func_name = node .get_attr_str (branch )
412
410
g = find_function (func_name )
413
411
g .parent_graph = ctx
414
412
wire_if_branch (ctx , g , inputs , output_shapes , output_dtypes , func_name , node .name )
415
- if_node .set_body_graph_as_attr (branch , g )
413
+ branches [branch ] = g
414
+
415
+ _ = ctx .make_node ("If" , node .input [:1 ], name = node .name , output_count = len (output_shapes ),
416
+ shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True , branches = branches )
416
417
417
418
418
419
@tf_op (["If" ])
@@ -431,15 +432,16 @@ def version_1(cls, ctx, node, **kwargs):
431
432
ctx .remove_node (node .name )
432
433
433
434
# replace the original node
434
- if_node = ctx .make_node ("If" , node .input [:1 ], name = node .name , output_count = len (output_shapes ),
435
- shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True )
436
-
435
+ branches = {}
437
436
for branch in ["then_branch" , "else_branch" ]:
438
437
func_name = node .get_attr_str (branch )
439
438
g = find_function (func_name )
440
439
g .parent_graph = ctx
441
440
wire_if_branch (ctx , g , inputs , output_shapes , output_dtypes , func_name , node .name )
442
- if_node .set_body_graph_as_attr (branch , g )
441
+ branches [branch ] = g
442
+
443
+ _ = ctx .make_node ("If" , node .input [:1 ], name = node .name , output_count = len (output_shapes ),
444
+ shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True , branches = branches )
443
445
444
446
445
447
@tf_op (["TensorListSetItem" ])
@@ -610,9 +612,11 @@ def version_7(cls, ctx, node, **kwargs):
610
612
output_dtypes = output_dtypes [2 :]
611
613
output_names = output_names [2 :]
612
614
615
+ branches = {"body" : body }
613
616
loop_node = ctx .make_node ("Loop" , [maximum_iterations_name , cond_outputs [0 ]] + loop_vars ,
614
617
output_count = len (output_shapes ), name = node .name + "_loop" ,
615
- shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True )
618
+ shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True ,
619
+ branches = branches )
616
620
617
621
output_map = dict (zip (output_names , loop_node .output ))
618
622
@@ -628,7 +632,6 @@ def version_7(cls, ctx, node, **kwargs):
628
632
for i , n in enumerate (body .inputs ):
629
633
if body .get_dtype (n .output [0 ]) == onnx_pb .TensorProto .UNDEFINED :
630
634
body .set_dtype (n .output [0 ], ctx .get_dtype (loop_node .input [i ]))
631
- loop_node .set_body_graph_as_attr ("body" , body )
632
635
633
636
634
637
def wire_while_body (parent_g , g , loop_node_inputs , body_input_to_state_var , cond_input_to_state_var , output_shapes ,
@@ -801,13 +804,14 @@ def prefix_graph(g, scope):
801
804
attr = node .attr
802
805
if node .is_graph_input ():
803
806
continue
804
- new_node = g .make_node (node .type , node .input , name = node .name , output_count = len (node .output ),
805
- shapes = output_shapes , dtypes = output_dtypes , attr = attr ,
806
- op_name_scope = scope , skip_conversion = True )
807
+ branches = {}
807
808
attr_graphs = node .get_body_graphs ()
808
809
if attr_graphs :
809
810
for k , v in attr_graphs .items ():
810
- new_node .set_body_graph_as_attr (k , v )
811
+ branches [k ] = v
812
+ new_node = g .make_node (node .type , node .input , name = node .name , output_count = len (node .output ),
813
+ shapes = output_shapes , dtypes = output_dtypes , attr = attr ,
814
+ op_name_scope = scope , skip_conversion = True , branches = branches )
811
815
for old_output , new_output in zip (node .output , new_node .output ):
812
816
for i , oname in enumerate (g .outputs ):
813
817
if old_output == oname :
0 commit comments