@@ -105,13 +105,19 @@ def rewrite(self, context):
105
105
106
106
state_inputs_initial_values = []
107
107
for state_input in scan_props .state_inputs_initial_values :
108
- nodes = self ._adapt_scan_sequence_input_or_output ("input" , state_input , False )
109
- state_inputs_initial_values .append (nodes [- 1 ].output [0 ])
108
+ if self .g .opset == 8 :
109
+ nodes = self ._adapt_scan_sequence_input_or_output ("input" , state_input , False )
110
+ state_inputs_initial_values .append (nodes [- 1 ].output [0 ])
111
+ else :
112
+ state_inputs_initial_values .append (state_input )
110
113
111
114
scan_inputs_initial_values = []
112
115
for scan_input in scan_props .scan_inputs_initial_values :
113
- nodes = self ._adapt_scan_sequence_input_or_output ("input" , scan_input , False )
114
- scan_inputs_initial_values .append (nodes [- 1 ].output [0 ])
116
+ if self .g .opset == 8 :
117
+ nodes = self ._adapt_scan_sequence_input_or_output ("input" , scan_input , False )
118
+ scan_inputs_initial_values .append (nodes [- 1 ].output [0 ])
119
+ else :
120
+ scan_inputs_initial_values .append (scan_input )
115
121
116
122
cell_g_info = context .cell_graph
117
123
scan_body_g = LoopRewriterBase .construct_graph_from_nodes (self .g , cell_g_info .nodes , cell_g_info .outputs )
@@ -155,17 +161,24 @@ def _create_scan_node(self, context, scan_props, init_values):
155
161
n = self .g .get_node_by_output (tensor_value_info .id )
156
162
self .g .remove_node (n .name )
157
163
else :
158
- loop_outputs_shapes .append (None )
164
+ loop_outputs_shapes .append ([ - 1 ] )
159
165
loop_outputs_dtypes .append (None )
160
166
161
- # here we did not give the sequence_length, because
162
- # current batch size is 1, not original batch size
163
- # original seq_length will be used by the loop body of Scan op.
164
- scan_node = self .g .make_node ("Scan" , ["" ] + init_values , op_name_scope = "custom_rnn_scan" ,
165
- attr = {"num_scan_inputs" : len (scan_props .scan_inputs )},
166
- output_count = len (scan_props .state_outputs + scan_props .scan_outputs ),
167
- shapes = loop_outputs_shapes , dtypes = loop_outputs_dtypes ,
168
- skip_conversion = False )
167
+ if self .g .opset == 8 :
168
+ # here we did not give the sequence_length, because
169
+ # current batch size is 1, not original batch size
170
+ # original seq_length will be used by the loop body of Scan op.
171
+ scan_node = self .g .make_node ("Scan" , ["" ] + init_values , op_name_scope = "custom_rnn_scan" ,
172
+ attr = {"num_scan_inputs" : len (scan_props .scan_inputs )},
173
+ output_count = len (scan_props .state_outputs + scan_props .scan_outputs ),
174
+ shapes = loop_outputs_shapes , dtypes = loop_outputs_dtypes ,
175
+ skip_conversion = False )
176
+ else :
177
+ scan_node = self .g .make_node ("Scan" , init_values , op_name_scope = "custom_rnn_scan" ,
178
+ attr = {"num_scan_inputs" : len (scan_props .scan_inputs )},
179
+ output_count = len (scan_props .state_outputs + scan_props .scan_outputs ),
180
+ shapes = loop_outputs_shapes , dtypes = loop_outputs_dtypes ,
181
+ skip_conversion = False )
169
182
170
183
return scan_node
171
184
@@ -175,17 +188,22 @@ def _connect_scan_with_output(self, context, scan_node):
175
188
index = 0
176
189
for out_tensor_value_info in context .loop_properties .state_outputs_exits :
177
190
if out_tensor_value_info .id :
178
- nodes = self ._adapt_scan_sequence_input_or_output ("state_output_reshape" ,
179
- scan_node .output [index ], True )
180
- self .g .replace_all_inputs (self .g .get_nodes (), out_tensor_value_info .id , nodes [- 1 ].output [0 ])
181
-
191
+ if self .g .opset == 8 :
192
+ nodes = self ._adapt_scan_sequence_input_or_output ("state_output_reshape" ,
193
+ scan_node .output [index ], True )
194
+ self .g .replace_all_inputs (self .g .get_nodes (), out_tensor_value_info .id , nodes [- 1 ].output [0 ])
195
+ else :
196
+ self .g .replace_all_inputs (self .g .get_nodes (), out_tensor_value_info .id , scan_node .output [index ])
182
197
index += 1
183
198
184
199
for out_tensor_value_info in context .loop_properties .scan_outputs_exits :
185
200
if out_tensor_value_info .id :
186
- nodes = self ._adapt_scan_sequence_input_or_output ("scan_output_reshape" ,
187
- scan_node .output [index ], True )
188
- self .g .replace_all_inputs (self .g .get_nodes (), out_tensor_value_info .id , nodes [- 1 ].output [0 ])
201
+ if self .g .opset == 8 :
202
+ nodes = self ._adapt_scan_sequence_input_or_output ("scan_output_reshape" ,
203
+ scan_node .output [index ], True )
204
+ self .g .replace_all_inputs (self .g .get_nodes (), out_tensor_value_info .id , nodes [- 1 ].output [0 ])
205
+ else :
206
+ self .g .replace_all_inputs (self .g .get_nodes (), out_tensor_value_info .id , scan_node .output [index ])
189
207
index += 1
190
208
191
209
0 commit comments