@@ -40,14 +40,14 @@ def find_cell(self, context):
40
40
cell_match = self ._match_cell (context , cell_type )
41
41
if cell_match and len (cell_match ) >= 1 :
42
42
self .num_lstm_layers = len (cell_match )
43
- logger .debug ("number of lstm layers: " + str ( self .num_lstm_layers ) )
43
+ logger .debug ("number of LSTM layers: %s" , self .num_lstm_layers )
44
44
for i in range (self .num_lstm_layers ):
45
45
self .state_variable_handlers .append ({
46
46
"ct" + str (i ): (self ._ct_variable_finder , self ._connect_lstm_yc_to_graph , i ),
47
47
"ht" + str (i ): (self ._ht_variable_finder , self ._connect_lstm_yh_to_graph , i )
48
48
})
49
49
self .state_variable_handlers .append ({
50
- "ct_ht" + str (i ): (self ._ct_ht_shared_variable_finder , self ._connect_lstm_ych_to_graph , i )
50
+ "ct_ht" + str (i ): (self ._ct_ht_shared_variable_finder , self ._connect_lstm_ych_to_graph , i )
51
51
})
52
52
logger .debug ("parsing unit is %s, num layers is %d" , cell_type , self .num_lstm_layers )
53
53
if cell_match :
@@ -287,9 +287,9 @@ def process_var_init_nodes(self, context):
287
287
def process_var_init_nodes_per_layer (self , context , i ):
288
288
init_h_id = None
289
289
init_c_id = None
290
- if ( "ct_ht" + str (i ) ) in context .state_variables :
290
+ if "ct_ht" + str (i ) in context .state_variables :
291
291
init_h_id , init_c_id = self ._process_non_tuple_ch_init_nodes (context , i )
292
- elif ( "ct" + str (i ) ) in context .state_variables and ("ht" + str (i )) in context .state_variables :
292
+ elif "ct" + str (i ) in context .state_variables and ("ht" + str (i )) in context .state_variables :
293
293
init_h_id , init_c_id = self ._process_tuple_ch_init_nodes (context , i )
294
294
else :
295
295
raise ValueError ("no initializers, unexpected" )
@@ -363,11 +363,10 @@ def create_single_rnn_node(self, context, i):
363
363
def create_rnn_node (self , context ):
364
364
rnn_nodes = list ()
365
365
outputs = context .loop_properties .scan_outputs_exits
366
- logger .debug ("number of rnn node outputs:" + str (len (outputs )))
367
- for i in range (len (outputs )):
368
- logger .debug ("output " + str (i ) + " with id=" + outputs [i ].id )
366
+ logger .debug ("number of rnn node outputs: %s" , len (outputs ))
367
+
369
368
for i in range (self .num_lstm_layers ):
370
- logger .debug ("creating rnn node for layer: " + str ( i ) )
369
+ logger .debug ("creating rnn node for layer: %s" , i )
371
370
rnn_nodes .append (self .create_single_rnn_node (context , i ))
372
371
output_id = rnn_nodes [i ].output [0 ]
373
372
rnn_output_shape = self .g .get_shape (output_id )
@@ -376,7 +375,7 @@ def create_rnn_node(self, context):
376
375
shapes = [squeeze_output_shape ],
377
376
dtypes = [self .g .get_dtype (output_id )])
378
377
if i + 1 < self .num_lstm_layers :
379
- logger .debug ("setting input for layer: " + str ( i + 1 ) )
378
+ logger .debug ("setting input for layer: %s" , i + 1 )
380
379
context .onnx_input_ids [i + 1 ]["X" ] = squeeze_node .output [0 ]
381
380
return rnn_nodes
382
381
@@ -420,4 +419,4 @@ def _connect_lstm_ych_to_graph(self, context, i):
420
419
shapes = [squeeze_output_shape ],
421
420
dtypes = [self .g .get_dtype (concat .output [0 ])])
422
421
423
- self .g .replace_all_inputs (self .g .get_nodes (), exit_output .id , squeeze_node .output [0 ])
422
+ self .g .replace_all_inputs (self .g .get_nodes (), exit_output .id , squeeze_node .output [0 ])
0 commit comments