Skip to content

Commit a27312a

Browse files
authored
Merge pull request #532 from lucienwang1009/remove_seq_len
if seq_len node doesn't exist, set sequence_lens of RNN empty
2 parents 119f602 + 7a01190 commit a27312a

File tree

5 files changed

+19
-44
lines changed

5 files changed

+19
-44
lines changed

tf2onnx/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,8 +578,8 @@ def update_node_shape_dtype(self, node, override=False):
578578
# op needs the "Shape" value to infer output shape.
579579
initializers = []
580580
for i, inp in enumerate(node.inputs):
581-
if not inp:
582-
if logger.isEnabledFor(logging.VERBOSE):
581+
if inp is None:
582+
if logger.isEnabledFor(logging.INFO):
583583
logger.warning(
584584
"[%s] infer a inexistent node: [%s], please check the code",
585585
node.name, node.input[i]

tf2onnx/rewriter/gru_rewriter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,13 @@ def create_rnn_node(self, context):
182182
context.attributes["direction"] = "forward"
183183
context.attributes["hidden_size"] = context.hidden_size
184184
inputs = context.onnx_input_ids
185+
# sequence length is optional
186+
seq_len_input = utils.ONNX_EMPTY_INPUT
187+
if inputs["sequence_lens"]:
188+
seq_len_input = inputs["sequence_lens"]
185189
gru_inputs = [
186190
inputs["X"], inputs["W"], inputs["R"], inputs["B"],
187-
inputs["sequence_lens"], inputs["initial_state"]]
191+
seq_len_input, inputs["initial_state"]]
188192
x_shape = self.g.get_shape(gru_inputs[0])
189193
x_seq_length = x_shape[0]
190194
x_batch_size = x_shape[1]

tf2onnx/rewriter/lstm_rewriter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,9 +320,13 @@ def create_rnn_node(self, context):
320320
context.attributes["direction"] = "forward"
321321
context.attributes["hidden_size"] = context.hidden_size
322322
inputs = context.onnx_input_ids
323+
# sequence len input is optional
324+
seq_len_input = utils.ONNX_EMPTY_INPUT
325+
if inputs["sequence_lens"]:
326+
seq_len_input = inputs["sequence_lens"]
323327
lstm_inputs = [
324328
inputs["X"], inputs["W"], inputs["R"], inputs["B"],
325-
inputs["sequence_lens"], inputs["initial_h"], inputs["initial_c"]]
329+
seq_len_input, inputs["initial_h"], inputs["initial_c"]]
326330

327331
x_shape = self.g.get_shape(lstm_inputs[0])
328332
x_seq_length = x_shape[0]

tf2onnx/rewriter/unit_rnn_rewriter_base.py

Lines changed: 6 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@
88
from __future__ import division
99
from __future__ import print_function
1010
import logging
11-
from onnx import TensorProto
1211

13-
from tf2onnx.graph_builder import GraphBuilder
1412
from tf2onnx.rewriter.loop_rewriter_base import LoopRewriterBase, Context
1513
from tf2onnx.rewriter.rnn_utils import REWRITER_RESULT, get_pattern, \
1614
get_rnn_scope_name, parse_rnn_loop, seq_len_pattern
@@ -36,7 +34,9 @@ def __init__(self):
3634
self.hidden_size = None
3735

3836
self.attributes = {} # onnx attributes
39-
self.onnx_input_ids = {} # onnx inputs: [X, W, R, B, sequence_lens, initial_h, initial_c, P]
37+
# onnx inputs: [X, W, R, B, sequence_lens, initial_h, initial_c, P],
38+
# sequence_lens is optional, i.e., None
39+
self.onnx_input_ids = {}
4040

4141

4242
class UnitRnnRewriterBase(LoopRewriterBase):
@@ -114,7 +114,9 @@ def parse_unit_rnn(self, context):
114114
seq_len_node = self.find_sequence_length_node(context)
115115
if seq_len_node:
116116
logger.debug("find sequence node: %s", seq_len_node.name)
117-
context.seq_len_node = seq_len_node
117+
context.onnx_input_ids["sequence_lens"] = seq_len_node.output[0]
118+
else:
119+
context.onnx_input_ids["sequence_lens"] = None
118120

119121
# require exact one input
120122
inputs = context.loop_properties.scan_inputs_initial_values
@@ -166,8 +168,6 @@ def rewrite(self, context):
166168
logger.debug("process the weights/bias/ft_bias, to fit onnx weights/bias requirements")
167169
self.process_weights_and_bias(context)
168170

169-
self.process_seq_length(context)
170-
171171
self.process_var_init_nodes(context)
172172

173173
logger.debug("start to build new rnn node")
@@ -222,40 +222,6 @@ def find_sequence_length_node(self, context):
222222
def process_weights_and_bias(self, context):
223223
raise NotImplementedError()
224224

225-
def process_seq_length(self, context):
226-
# output: [time step, batch size, input size]
227-
seq_len_node = context.seq_len_node
228-
shape_node = self.g.make_node("Shape", [context.onnx_input_ids["X"]])
229-
# LSTMCell only allow inputs of [batch size, input_size], so we assume dynamic_rnn has 3 dims.
230-
# Slice cannot support Int64 in OPSET 7, so we cast here.
231-
cast_shape_node = self.g.make_node(
232-
"Cast", [shape_node.output[0]],
233-
attr={"to": TensorProto.FLOAT},
234-
shapes=[self.g.get_shape(shape_node.output[0])]
235-
)
236-
237-
attr = {"axes": [0], "starts": [1], "ends": [2]}
238-
inputs_map = {"data": cast_shape_node.output[0], **attr}
239-
batchsize_node = GraphBuilder(self.g).make_slice(inputs_map)
240-
if not seq_len_node:
241-
# Tile's repeats must be INT64
242-
repeat_node = self.g.make_node(
243-
"Cast", [batchsize_node],
244-
attr={"to": TensorProto.INT64}
245-
)
246-
247-
attr = {"axes": [0], "starts": [0], "ends": [1]}
248-
inputs_map = {"data": cast_shape_node.output[0], **attr}
249-
timestep_node = GraphBuilder(self.g).make_slice(inputs_map)
250-
tile_node = self.g.make_node("Tile", [timestep_node, repeat_node.output[0]])
251-
252-
# LSTM sequence_lens needs to be int32
253-
seq_len_node = self.g.make_node(
254-
"Cast", [tile_node.output[0]],
255-
attr={"to": TensorProto.INT32}
256-
)
257-
context.onnx_input_ids["sequence_lens"] = seq_len_node.output[0]
258-
259225
def process_var_init_nodes(self, context):
260226
raise NotImplementedError()
261227

tf2onnx/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(self, tensor_id, g):
9393

9494

9595
ONNX_UNKNOWN_DIMENSION = -1
96+
ONNX_EMPTY_INPUT = ""
9697

9798
# index for internally generated names
9899
INTERNAL_NAME = 1

0 commit comments

Comments
 (0)