@@ -33,7 +33,7 @@ def get_inputs_for_current_iteration(g, input_id, iter_index):
33
33
34
34
35
35
def create_loop_body_graph (parent_g , gather_input_ids , output_data_type , output_shape , trip_count_input_ids ,
36
- rank , loop_name ):
36
+ rank ):
37
37
g = parent_g .create_new_graph_with_same_config ()
38
38
g .parent_graph = parent_g
39
39
iter_name = utils .make_name ("i" )
@@ -111,9 +111,9 @@ def create_if_op(g, input_ids, output_data_type, output_shape):
111
111
out_name = utils .port_name (op_name )
112
112
113
113
# output a scalar
114
- if_node = g . make_node ( "If" , [ input_ids [ 0 ]], outputs = [ out_name ], name = op_name , skip_conversion = True )
115
- if_node . set_body_graph_as_attr ( "then_branch " , true_graph )
116
- if_node . set_body_graph_as_attr ( "else_branch" , false_graph )
114
+ branches = { "then_branch" : true_graph , "else_branch" : false_graph }
115
+ if_node = g . make_node ( "If " , [ input_ids [ 0 ]], outputs = [ out_name ], name = op_name ,
116
+ skip_conversion = True , branches = branches )
117
117
return if_node , out_name
118
118
119
119
@@ -151,12 +151,11 @@ def create_loop_op(g, gather_input_ids, output_type, output_shape, trip_count_in
151
151
cond_var_name , # termination condition
152
152
fake_val_name # initial value of loop-carried dependencies
153
153
]
154
+ loop_body = create_loop_body_graph (g , gather_input_ids , output_type , output_shape , trip_count_input_ids , rank )
154
155
# define an extra scan output
156
+ branches = {"body" : loop_body }
155
157
loop_node = g .make_node ("Loop" , loop_inputs , output_count = 2 , op_name_scope = "select_loop" ,
156
- skip_conversion = False )
157
- loop_body = create_loop_body_graph (g , gather_input_ids , output_type , output_shape , trip_count_input_ids ,
158
- rank , loop_node .name )
159
- loop_node .set_body_graph_as_attr ("body" , loop_body )
158
+ skip_conversion = False , branches = branches )
160
159
return loop_node
161
160
162
161
@@ -222,8 +221,9 @@ def make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, dt
222
221
223
222
# loop
224
223
loop_inputs = [trip_count_node .output [0 ], cond_name , start ]
225
- loop_node = ctx .make_node ("Loop" , loop_inputs , output_count = 2 , op_name_scope = base_name , name = "loop" )
226
- loop_node .set_body_graph_as_attr ("body" , g )
224
+ branches = {"body" : g }
225
+ loop_node = ctx .make_node ("Loop" , loop_inputs ,
226
+ output_count = 2 , op_name_scope = base_name , name = "loop" , branches = branches )
227
227
228
228
ctx .make_node ("Identity" , [loop_node .output [1 ]], name = base_name , shapes = [shape ], dtypes = [dtype ], outputs = [output ])
229
229
@@ -409,15 +409,16 @@ def version_1(cls, ctx, node, **kwargs):
409
409
ctx .remove_node (node .name )
410
410
411
411
# replace the original node
412
- if_node = ctx .make_node ("If" , node .input [:1 ], name = node .name , output_count = len (output_shapes ),
413
- shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True )
414
-
412
+ branches = {}
415
413
for branch in ["then_branch" , "else_branch" ]:
416
414
func_name = node .get_attr_str (branch )
417
415
g = find_function (func_name )
418
416
g .parent_graph = ctx
419
417
wire_if_branch (ctx , g , inputs , output_shapes , output_dtypes , func_name , node .name )
420
- if_node .set_body_graph_as_attr (branch , g )
418
+ branches [branch ] = g
419
+
420
+ _ = ctx .make_node ("If" , node .input [:1 ], name = node .name , output_count = len (output_shapes ),
421
+ shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True , branches = branches )
421
422
422
423
423
424
@tf_op (["If" ])
@@ -436,15 +437,16 @@ def version_1(cls, ctx, node, **kwargs):
436
437
ctx .remove_node (node .name )
437
438
438
439
# replace the original node
439
- if_node = ctx .make_node ("If" , node .input [:1 ], name = node .name , output_count = len (output_shapes ),
440
- shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True )
441
-
440
+ branches = {}
442
441
for branch in ["then_branch" , "else_branch" ]:
443
442
func_name = node .get_attr_str (branch )
444
443
g = find_function (func_name )
445
444
g .parent_graph = ctx
446
445
wire_if_branch (ctx , g , inputs , output_shapes , output_dtypes , func_name , node .name )
447
- if_node .set_body_graph_as_attr (branch , g )
446
+ branches [branch ] = g
447
+
448
+ _ = ctx .make_node ("If" , node .input [:1 ], name = node .name , output_count = len (output_shapes ),
449
+ shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True , branches = branches )
448
450
449
451
450
452
@tf_op (["TensorListSetItem" ])
@@ -615,9 +617,11 @@ def version_7(cls, ctx, node, **kwargs):
615
617
output_dtypes = output_dtypes [2 :]
616
618
output_names = output_names [2 :]
617
619
620
+ branches = {"body" : body }
618
621
loop_node = ctx .make_node ("Loop" , [maximum_iterations_name , cond_outputs [0 ]] + loop_vars ,
619
622
output_count = len (output_shapes ), name = node .name + "_loop" ,
620
- shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True )
623
+ shapes = output_shapes , dtypes = output_dtypes , skip_conversion = True ,
624
+ branches = branches )
621
625
622
626
output_map = dict (zip (output_names , loop_node .output ))
623
627
@@ -633,7 +637,6 @@ def version_7(cls, ctx, node, **kwargs):
633
637
for i , n in enumerate (body .inputs ):
634
638
if body .get_dtype (n .output [0 ]) == onnx_pb .TensorProto .UNDEFINED :
635
639
body .set_dtype (n .output [0 ], ctx .get_dtype (loop_node .input [i ]))
636
- loop_node .set_body_graph_as_attr ("body" , body )
637
640
638
641
639
642
def wire_while_body (parent_g , g , loop_node_inputs , body_input_to_state_var , cond_input_to_state_var , output_shapes ,
@@ -806,13 +809,14 @@ def prefix_graph(g, scope):
806
809
attr = node .attr
807
810
if node .is_graph_input ():
808
811
continue
809
- new_node = g .make_node (node .type , node .input , name = node .name , output_count = len (node .output ),
810
- shapes = output_shapes , dtypes = output_dtypes , attr = attr ,
811
- op_name_scope = scope , skip_conversion = True )
812
+ branches = {}
812
813
attr_graphs = node .get_body_graphs ()
813
814
if attr_graphs :
814
815
for k , v in attr_graphs .items ():
815
- new_node .set_body_graph_as_attr (k , v )
816
+ branches [k ] = v
817
+ new_node = g .make_node (node .type , node .input , name = node .name , output_count = len (node .output ),
818
+ shapes = output_shapes , dtypes = output_dtypes , attr = attr ,
819
+ op_name_scope = scope , skip_conversion = True , branches = branches )
816
820
for old_output , new_output in zip (node .output , new_node .output ):
817
821
for i , oname in enumerate (g .outputs ):
818
822
if old_output == oname :
0 commit comments