@@ -45,21 +45,19 @@ def rewrite(self, context):
45
45
cell_g_info = context .cell_graph
46
46
cond_g_info = context .cond_graph
47
47
48
- # todo(pengwa): we don't check the case where loop body won't be executed at all.
48
+ # create a dummy loop to calculate the init condition
49
+ init_cond_output = self ._create_subgraph_initial_cond (cond_g_info )
49
50
50
51
## create Loop body graph with existing nodes
51
52
52
- # replace condition graph's inputs to be cell graph's outputs, because we want condition graph
53
- # to consumer cell graph outputs.
54
- for loop_var in cond_g_info .dependent_vars :
55
- self .g .replace_all_inputs (cond_g_info .nodes , loop_var .switch_true_identity_output .id ,
56
- loop_var .next_iteration_input .id )
57
-
58
53
body_nodes = set (cell_g_info .nodes + cond_g_info .nodes )
59
54
body_outputs = cond_g_info .outputs + cell_g_info .outputs
60
55
for out_tensor_value_info in body_outputs :
61
56
shape = out_tensor_value_info .shape
62
- utils .make_sure (shape is not None , "Shape of {} is None" .format (out_tensor_value_info .id ))
57
+ utils .make_sure (
58
+ shape is not None ,
59
+ "Conversion of Loop requries output shape [{}] exists" .format (out_tensor_value_info .id )
60
+ )
63
61
out_tensor_value_info .shape = utils .create_vague_shape_like (shape )
64
62
65
63
loop_body_g = LoopRewriterBase .construct_graph_from_nodes (self .g , body_nodes , body_outputs )
@@ -90,7 +88,7 @@ def rewrite(self, context):
90
88
loop_body_g .replace_all_inputs (loop_body_g .get_nodes (), input_ta .consumer .id , data_node .output [0 ])
91
89
92
90
## create Loop node
93
- loop_node = self ._create_loop_node (context , loop_props )
91
+ loop_node = self ._create_loop_node (context , loop_props , init_cond_output )
94
92
if not loop_node :
95
93
logger .error ("failed to create loop node during rewrite" )
96
94
return REWRITER_RESULT .FAIL
@@ -104,7 +102,48 @@ def rewrite(self, context):
104
102
logger .error ("loop rewrite failed, due to exception: %s, details:%s" , ex , tb )
105
103
return REWRITER_RESULT .FAIL
106
104
107
- def _create_loop_node (self , context , loop_props ):
105
+ def _create_subgraph_initial_cond (self , cond_graph ):
106
+ """Create subgraph to calculate initial cond."""
107
+ # copy condition subgraph to parent graph
108
+ copied_nodes = []
109
+ name_scope = utils .make_name ("copy" )
110
+ for node in cond_graph .nodes :
111
+ new_name = "{}/{}" .format (name_scope , node .name )
112
+ new_outputs = ["{}/{}" .format (name_scope , out ) for out in node .output ]
113
+ # some inputs are out of cond_graph.nodes, keep them intact
114
+ new_inputs = []
115
+ for inp in node .input :
116
+ if self .g .get_node_by_output (inp ) in cond_graph .nodes :
117
+ new_inputs .append ("{}/{}" .format (name_scope , inp ))
118
+ else :
119
+ new_inputs .append (inp )
120
+
121
+ new_node = self .g .make_node (
122
+ node .type , new_inputs , outputs = new_outputs ,
123
+ attr = node .attr , name = new_name ,
124
+ shapes = node .output_shapes , dtypes = node .output_dtypes ,
125
+ skip_conversion = node .skip_conversion , infer_shape_dtype = False
126
+ )
127
+ body_graphs = node .graph .contained_graphs .pop (node .name , None )
128
+ if body_graphs :
129
+ for attr_name , body_graph in body_graphs .items ():
130
+ body_graph .parent_graph = g
131
+ new_node .set_body_graph_as_attr (attr_name , body_graph )
132
+ copied_nodes .append (new_node )
133
+
134
+ # replace all inputs of condition graph by initializer (enter_input)
135
+ for loop_var in cond_graph .dependent_vars :
136
+ self .g .replace_all_inputs (
137
+ copied_nodes ,
138
+ loop_var .next_iteration_input .id ,
139
+ loop_var .enter_input_id
140
+ )
141
+ init_cond_output = "{}/{}" .format (name_scope , cond_graph .outputs [0 ].id )
142
+ self .g .set_dtype (init_cond_output , cond_graph .outputs [0 ].dtype )
143
+ self .g .set_shape (init_cond_output , cond_graph .outputs [0 ].shape )
144
+ return init_cond_output
145
+
146
+ def _create_loop_node (self , context , loop_props , init_cond_output ):
108
147
loop_outputs = []
109
148
loop_output_shapes = []
110
149
loop_output_dtypes = []
@@ -123,8 +162,7 @@ def _create_loop_node(self, context, loop_props):
123
162
# trip count and cond are not used, giving them values just because bug
124
163
# (https://github.com/Microsoft/onnxruntime/issues/255) of onnxruntime.
125
164
trip_cnt = self .g .make_const (utils .make_name ("trip_count" ), np .array (sys .maxsize , dtype = np .int64 ))
126
- cond = self .g .make_const (utils .make_name ("cond" ), np .array (True , dtype = np .bool ))
127
- loop_node = self .g .make_node ("Loop" , [trip_cnt .output [0 ]] + [cond .output [0 ]] +
165
+ loop_node = self .g .make_node ("Loop" , [trip_cnt .output [0 ]] + [init_cond_output ] +
128
166
loop_props .state_inputs_initial_values , # ONNX Loop support state inputs only
129
167
outputs = loop_outputs , op_name_scope = "generic_loop" ,
130
168
shapes = loop_output_shapes , dtypes = loop_output_dtypes ,
0 commit comments