from __future__ import division
from __future__ import print_function
import logging
-from onnx import TensorProto
-from tf2onnx.graph_builder import GraphBuilder
from tf2onnx.rewriter.loop_rewriter_base import LoopRewriterBase, Context
from tf2onnx.rewriter.rnn_utils import REWRITER_RESULT, get_pattern, \
    get_rnn_scope_name, parse_rnn_loop, seq_len_pattern
@@ -36,7 +34,9 @@ def __init__(self):
        self.hidden_size = None

        self.attributes = {}  # onnx attributes
-        self.onnx_input_ids = {}  # onnx inputs: [X, W, R, B, sequence_lens, initial_h, initial_c, P]
+        # onnx inputs: [X, W, R, B, sequence_lens, initial_h, initial_c, P],
+        # sequence_lens is optional, i.e., None
+        self.onnx_input_ids = {}


class UnitRnnRewriterBase(LoopRewriterBase):
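The reworded comment makes onnx_input_ids the single place where optional inputs are tracked. A minimal sketch of what the dict can look like after parsing, assuming the behavior introduced below; the tensor names are made up for illustration:

    onnx_input_ids = {
        "X": "lstm/X:0",        # hypothetical tensor names
        "W": "lstm/W:0",
        "R": "lstm/R:0",
        "B": "lstm/B:0",
        "sequence_lens": None,  # optional: stays None when no length node is found
        "initial_h": "lstm/h0:0",
        "initial_c": "lstm/c0:0",
        "P": None,              # peepholes, also optional
    }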
@@ -114,7 +114,9 @@ def parse_unit_rnn(self, context):
        seq_len_node = self.find_sequence_length_node(context)
        if seq_len_node:
            logger.debug("find sequence node: %s", seq_len_node.name)
-            context.seq_len_node = seq_len_node
+            context.onnx_input_ids["sequence_lens"] = seq_len_node.output[0]
+        else:
+            context.onnx_input_ids["sequence_lens"] = None

        # require exact one input
        inputs = context.loop_properties.scan_inputs_initial_values
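With this change the parse step records the real sequence-length tensor when one exists and an explicit None otherwise, instead of leaving the decision to a later rewrite pass. Downstream, ONNX marks an omitted optional input with an empty string in a node's input list, so a None entry can simply be dropped. A sketch under that assumption; the helper and names are illustrative, not from this patch:

    def build_input_list(ids):
        # ONNX convention: "" in the input list means "optional input omitted"
        order = ["X", "W", "R", "B", "sequence_lens", "initial_h", "initial_c", "P"]
        return [ids.get(name) or "" for name in order]

    inputs = build_input_list({"X": "lstm/X:0", "W": "lstm/W:0", "R": "lstm/R:0",
                               "B": "lstm/B:0", "sequence_lens": None,
                               "initial_h": "lstm/h0:0", "initial_c": "lstm/c0:0",
                               "P": None})
    assert inputs[4] == ""  # sequence_lens slot left empty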
@@ -166,8 +168,6 @@ def rewrite(self, context):
            logger.debug("process the weights/bias/ft_bias, to fit onnx weights/bias requirements")
            self.process_weights_and_bias(context)

-            self.process_seq_length(context)
-
            self.process_var_init_nodes(context)

            logger.debug("start to build new rnn node")
@@ -222,40 +222,6 @@ def find_sequence_length_node(self, context):
    def process_weights_and_bias(self, context):
        raise NotImplementedError()

-    def process_seq_length(self, context):
-        # output: [time step, batch size, input size]
-        seq_len_node = context.seq_len_node
-        shape_node = self.g.make_node("Shape", [context.onnx_input_ids["X"]])
-        # LSTMCell only allow inputs of [batch size, input_size], so we assume dynamic_rnn has 3 dims.
-        # Slice cannot support Int64 in OPSET 7, so we cast here.
-        cast_shape_node = self.g.make_node(
-            "Cast", [shape_node.output[0]],
-            attr={"to": TensorProto.FLOAT},
-            shapes=[self.g.get_shape(shape_node.output[0])]
-        )
-
-        attr = {"axes": [0], "starts": [1], "ends": [2]}
-        inputs_map = {"data": cast_shape_node.output[0], **attr}
-        batchsize_node = GraphBuilder(self.g).make_slice(inputs_map)
-        if not seq_len_node:
-            # Tile's repeats must be INT64
-            repeat_node = self.g.make_node(
-                "Cast", [batchsize_node],
-                attr={"to": TensorProto.INT64}
-            )
-
-            attr = {"axes": [0], "starts": [0], "ends": [1]}
-            inputs_map = {"data": cast_shape_node.output[0], **attr}
-            timestep_node = GraphBuilder(self.g).make_slice(inputs_map)
-            tile_node = self.g.make_node("Tile", [timestep_node, repeat_node.output[0]])
-
-            # LSTM sequence_lens needs to be int32
-            seq_len_node = self.g.make_node(
-                "Cast", [tile_node.output[0]],
-                attr={"to": TensorProto.INT32}
-            )
-        context.onnx_input_ids["sequence_lens"] = seq_len_node.output[0]
-
    def process_var_init_nodes(self, context):
        raise NotImplementedError()
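For context on the deletion above: when no sequence-length node was found, process_seq_length fabricated a default one in the graph via Shape(X) -> Slice -> Tile -> Cast, i.e. a [batch_size] tensor filled with the full time-step count, with extra casts because opset-7 Slice could not take int64 and LSTM requires int32 sequence_lens. A plain-numpy sketch of the value that fallback computed:

    import numpy as np

    def default_sequence_lens(x):
        # x: [time_steps, batch_size, input_size]; every batch entry is
        # assumed to run for the full number of time steps
        time_steps, batch_size = x.shape[0], x.shape[1]
        return np.full([batch_size], time_steps, dtype=np.int32)

    lens = default_sequence_lens(np.zeros([7, 3, 16], dtype=np.float32))
    assert lens.tolist() == [7, 7, 7]

After this patch the fabricated tensor is gone: a missing sequence length is represented as None, and the optional input can be omitted from the ONNX node entirely.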