Commit 972adea

Merge pull request #357 from lucienwang1009/lstm_block_cell
Support LSTMBlockCell in tf2onnx mapping
2 parents: ad365f6 + d954d43

4 files changed: +333 −1 lines

tests/test_lstmblock.py

Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

"""Unit Tests for lstm block cell."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np
import tensorflow as tf

from tensorflow.contrib import rnn
from backend_test_base import Tf2OnnxBackendTestBase
from common import unittest_main, check_tf_min_version, check_opset_min_version


# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test


class LSTMBlockTests(Tf2OnnxBackendTestBase):
    @check_opset_min_version(8, "Scan")
    def test_single_dynamic_lstm(self):
        units = 5
        batch_size = 6
        x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]], dtype=np.float32)
        x_val = np.stack([x_val] * batch_size)

        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")

        # no scope
        cell = rnn.LSTMBlockCell(units, use_peephole=False)
        outputs, cell_state = tf.nn.dynamic_rnn(
            cell,
            x,
            dtype=tf.float32)

        _ = tf.identity(outputs, name="output")
        _ = tf.identity(cell_state, name="cell_state")

        input_names_with_port = ["input_1:0"]
        feed_dict = {"input_1:0": x_val}

        output_names_with_port = ["output:0", "cell_state:0"]
        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06, atol=1e-07)

    # ==========================================================================================
    # NOTE: the unittest above should be converted into a single LSTM op, while the following
    # unittests should first be converted into a Scan op with LSTMBlockCell, then decoupled
    # into several ops.
    # ==========================================================================================

    @check_opset_min_version(8, "Scan")
    def test_single_dynamic_lstm_with_peephole(self):
        units = 5
        batch_size = 6
        x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]], dtype=np.float32)
        x_val = np.stack([x_val] * batch_size)

        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")

        # no scope
        cell = rnn.LSTMBlockCell(units, use_peephole=True)
        outputs, cell_state = tf.nn.dynamic_rnn(
            cell,
            x,
            dtype=tf.float32)

        _ = tf.identity(outputs, name="output")
        _ = tf.identity(cell_state, name="cell_state")

        input_names_with_port = ["input_1:0"]
        feed_dict = {"input_1:0": x_val}

        output_names_with_port = ["output:0", "cell_state:0"]
        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06, atol=1e-07)

    @check_opset_min_version(8, "Scan")
    def test_single_dynamic_lstm_with_cell_clip(self):
        units = 5
        batch_size = 6
        x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]], dtype=np.float32)
        x_val = np.stack([x_val] * batch_size)

        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")

        # no scope
        cell = rnn.LSTMBlockCell(units, cell_clip=0.05)
        outputs, cell_state = tf.nn.dynamic_rnn(
            cell,
            x,
            dtype=tf.float32)

        _ = tf.identity(outputs, name="output")
        _ = tf.identity(cell_state, name="cell_state")

        input_names_with_port = ["input_1:0"]
        feed_dict = {"input_1:0": x_val}

        output_names_with_port = ["output:0", "cell_state:0"]
        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06, atol=1e-07)

    @check_opset_min_version(8, "Scan")
    @check_tf_min_version("1.8")
    def test_attention_wrapper_lstm_encoder(self):
        size = 5
        time_step = 3
        input_size = 4
        attn_size = size

        batch_size = 9

        # shape [batch size, time step, size]
        # attention_state: usually the output of an RNN encoder.
        # This tensor should be shaped `[batch_size, max_time, ...]`
        encoder_time_step = time_step
        encoder_x_val = np.random.randn(encoder_time_step, input_size).astype('f')
        encoder_x_val = np.stack([encoder_x_val] * batch_size)
        encoder_x = tf.placeholder(tf.float32, encoder_x_val.shape, name="input_1")
        encoder_cell = rnn.LSTMBlockCell(size)
        output, attr_state = tf.nn.dynamic_rnn(encoder_cell, encoder_x, dtype=tf.float32)
        _ = tf.identity(output, name="output_0")
        attention_states = output
        attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(attn_size,
                                                                   attention_states)

        match_input_fn = lambda curr_input, state: tf.concat([curr_input, state], axis=-1)
        cell = rnn.LSTMBlockCell(size)
        match_cell_fw = tf.contrib.seq2seq.AttentionWrapper(cell,
                                                            attention_mechanism,
                                                            attention_layer_size=attn_size,
                                                            cell_input_fn=match_input_fn,
                                                            output_attention=False)

        decoder_time_step = 6
        decoder_x_val = np.random.randn(decoder_time_step, input_size).astype('f')
        decoder_x_val = np.stack([decoder_x_val] * batch_size)

        decoder_x = tf.placeholder(tf.float32, decoder_x_val.shape, name="input_2")
        output, attr_state = tf.nn.dynamic_rnn(match_cell_fw, decoder_x, dtype=tf.float32)

        _ = tf.identity(output, name="output")
        _ = tf.identity(attr_state.cell_state, name="final_state")

        feed_dict = {"input_1:0": encoder_x_val, "input_2:0": decoder_x_val}
        input_names_with_port = ["input_1:0", "input_2:0"]
        output_names_with_port = ["output_0:0", "output:0", "final_state:0"]
        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, 0.1)

    @check_opset_min_version(8, "Scan")
    def test_multi_rnn_lstm(self):
        units = 5
        batch_size = 6
        x_val = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]], dtype=np.float32)
        x_val = np.stack([x_val] * batch_size)

        x = tf.placeholder(tf.float32, x_val.shape, name="input_1")

        cell_0 = rnn.LSTMBlockCell(units)

        cell_1 = rnn.LSTMBlockCell(units)

        cell_2 = rnn.LSTMBlockCell(units)

        cells = rnn.MultiRNNCell([cell_0, cell_1, cell_2], state_is_tuple=True)
        outputs, cell_state = tf.nn.dynamic_rnn(cells,
                                                x,
                                                dtype=tf.float32)

        _ = tf.identity(outputs, name="output")
        _ = tf.identity(cell_state, name="cell_state")

        input_names_with_port = ["input_1:0"]
        feed_dict = {"input_1:0": x_val}

        output_names_with_port = ["output:0", "cell_state:0"]
        self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06, atol=1e-07)


if __name__ == '__main__':
    unittest_main()
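As the NOTE in the tests says, only the first graph maps to a single ONNX LSTM op; the peephole, cell-clip, attention, and multi-cell variants decompose into elementwise ops inside a Scan body. To check which path a converted model took, one can inspect its op types with the onnx package; a minimal sketch (the model path is illustrative, not part of this commit):

```python
import onnx

# A graph matching the canonical pattern should contain a single LSTM node;
# the decomposed variants instead show Scan plus elementwise ops.
model = onnx.load("lstm_block.onnx")  # illustrative path
print(sorted({n.op_type for n in model.graph.node}))
```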

tf2onnx/function/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -7,9 +7,17 @@
 from __future__ import unicode_literals

 from .gathernd import gathernd_op
+from .lstm_block_cell import lstm_block_cell_op
 from .matrixbandpart import matrixbandpart_op
 from .range import range_op7
 from .select import select_op8
 from .sparse_softmax_cross_entropy_with_logits import sparse_softmax_cross_entropy_with_logits_op

-__all__ = ["gathernd_op", "matrixbandpart_op", "range_op7", "select_op8", "sparse_softmax_cross_entropy_with_logits_op"]
+__all__ = [
+    "gathernd_op",
+    "lstm_block_cell_op",
+    "matrixbandpart_op",
+    "range_op7",
+    "select_op8",
+    "sparse_softmax_cross_entropy_with_logits_op"
+]

tf2onnx/function/lstm_block_cell.py

Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

"""
tf2onnx.tf2onnx - lstm block cell conversion
"""
import numpy as np
from tf2onnx import utils

# pylint: disable=unused-argument


def lstm_block_cell_op(ctx, node, name, args):
    """
    Args:
      x: A `Tensor`. Must be one of the following types: `float32`.
        The input to the LSTM cell, shape (batch_size, num_inputs).
      cs_prev: A `Tensor`. Must have the same type as `x`.
        Value of the cell state at the previous time step.
      h_prev: A `Tensor`. Must have the same type as `x`.
        Output of the cell at the previous time step.
      w: A `Tensor`. Must have the same type as `x`. The weight matrix.
      wci: A `Tensor`. Must have the same type as `x`.
        The weight matrix for the input gate peephole connection.
      wcf: A `Tensor`. Must have the same type as `x`.
        The weight matrix for the forget gate peephole connection.
      wco: A `Tensor`. Must have the same type as `x`.
        The weight matrix for the output gate peephole connection.
      b: A `Tensor`. Must have the same type as `x`. The bias vector.
      forget_bias: An optional `float`. Defaults to `1`. The forget gate bias.
      cell_clip: An optional `float`. Defaults to `-1` (no clipping).
        Value to clip the 'cs' value to. Disable by setting to a negative value.
      use_peephole: An optional `bool`. Defaults to `False`.
        Whether to use peephole weights.
      name: A name for the operation (optional).

    Returns:
      A tuple of `Tensor` objects (i, cs, f, o, ci, co, h).
      i: A `Tensor`. Has the same type as `x`. The input gate.
      cs: A `Tensor`. Has the same type as `x`. The cell state before the tanh.
      f: A `Tensor`. Has the same type as `x`. The forget gate.
      o: A `Tensor`. Has the same type as `x`. The output gate.
      ci: A `Tensor`. Has the same type as `x`. The cell input.
      co: A `Tensor`. Has the same type as `x`. The cell after the tanh.
      h: A `Tensor`. Has the same type as `x`. The output h vector.

    ```python
    xh = [x, h_prev]
    [i, ci, f, o] = xh * w + b
    f = f + forget_bias
    if not use_peephole:
        wci = wcf = wco = 0
    i = sigmoid(cs_prev .* wci + i)
    f = sigmoid(cs_prev .* wcf + f)
    ci = tanh(ci)
    cs = ci .* i + cs_prev .* f
    cs = clip(cs, cell_clip)
    o = sigmoid(cs .* wco + o)
    co = tanh(cs)
    h = co .* o
    ```
    """
    nodes = []
    x, cs_prev, h_prev, w, wci, wcf, wco, b = node.input
    forget_bias = float(node.get_attr("forget_bias").f)
    cell_clip = float(node.get_attr("cell_clip").f)
    use_peephole = bool(node.get_attr("use_peephole").i)

    def make_sigmoid(i, w, b):
        i_w_node = ctx.make_node("Mul", [i, w])
        i_w_b_node = ctx.make_node("Add", [i_w_node.output[0], b])
        output_node = ctx.make_node("Sigmoid", [i_w_b_node.output[0]])
        nodes.extend([i_w_node, i_w_b_node, output_node])
        return output_node.output[0]

    # xh = [x, h]
    xh_node = ctx.make_node("Concat", [x, h_prev], attr={"axis": 1})

    # i, ci, f, o = xh * w + b
    xh_w_node = ctx.make_node("MatMul", [xh_node.output[0], w])
    w_shape = ctx.get_shape(w)
    if len(w_shape) != 2 or w_shape[1] % 4 != 0:
        raise RuntimeError("shape of W of LSTMBlockCell {} should be [*, 4 * hidden_size]".format(name))
    merged_output_node = ctx.make_node("Add", [xh_w_node.output[0], b])
    w_last_dim = int(w_shape[1] / 4)
    split = [w_last_dim] * 4
    split_output_node = ctx.make_node(
        "Split", [merged_output_node.output[0]],
        attr={"axis": 1, "split": split},
        output_count=4
    )
    i, ci, f, o = split_output_node.output

    # f = f + forget_bias
    forget_bias_const = ctx.make_const(
        utils.make_name("{}__forget_bias".format(name)),
        np.array(forget_bias, dtype=np.float32)
    )
    f_node = ctx.make_node("Add", [f, forget_bias_const.output[0]])

    if not use_peephole:
        zeros_const = ctx.make_const(
            utils.make_name("{}__zeros_const".format(name)),
            np.zeros([w_last_dim], dtype=np.float32)
        )
        nodes.append(zeros_const)
        wci = zeros_const.output[0]
        wcf = zeros_const.output[0]
        wco = zeros_const.output[0]

    # i = sigmoid(cs_prev .* wci + i)
    i = make_sigmoid(cs_prev, wci, i)
    # f = sigmoid(cs_prev .* wcf + f)
    f = make_sigmoid(cs_prev, wcf, f_node.output[0])
    # ci = tanh(ci)
    ci_node = ctx.make_node("Tanh", [ci])
    # cs = ci .* i + cs_prev .* f
    ci_i_node = ctx.make_node("Mul", [ci_node.output[0], i])
    cs_prev_f_node = ctx.make_node("Mul", [cs_prev, f])
    cs_node = ctx.make_node("Add", [ci_i_node.output[0], cs_prev_f_node.output[0]])
    cs = cs_node.output[0]
    # cs = clip(cs, cell_clip)
    if cell_clip > 0:
        cs_clip_node = ctx.make_node("Clip", [cs], attr={"max": cell_clip, "min": -cell_clip})
        nodes.append(cs_clip_node)
        cs = cs_clip_node.output[0]
    # o = sigmoid(cs .* wco + o)
    o = make_sigmoid(cs, wco, o)
    # co = tanh(cs)
    co_node = ctx.make_node("Tanh", [cs])
    # h = co .* o
    h_node = ctx.make_node("Mul", [co_node.output[0], o])

    def replace_output(old_output, new_output):
        ctx.replace_all_inputs(ctx.get_nodes(), old_output, new_output)
        ctx.copy_dtype(old_output, new_output)
        ctx.copy_shape(old_output, new_output)

    replace_output(node.output[0], i)
    replace_output(node.output[1], cs)
    replace_output(node.output[2], f)
    replace_output(node.output[3], o)
    replace_output(node.output[4], ci_node.output[0])
    replace_output(node.output[5], co_node.output[0])
    replace_output(node.output[6], h_node.output[0])
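For reference, here is a minimal NumPy sketch of the cell equations from the docstring above, handy for sanity-checking the ONNX ops emitted by this handler; the function name and defaults are illustrative, not part of this commit:

```python
import numpy as np

def _sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_block_cell_ref(x, cs_prev, h_prev, w, b,
                        wci=0.0, wcf=0.0, wco=0.0,
                        forget_bias=1.0, cell_clip=-1.0, use_peephole=False):
    # xh = [x, h_prev]; w packs the four gates [i, ci, f, o] along its last axis.
    xh = np.concatenate([x, h_prev], axis=1)
    i, ci, f, o = np.split(np.matmul(xh, w) + b, 4, axis=1)
    f = f + forget_bias
    if not use_peephole:
        wci = wcf = wco = 0.0
    i = _sigmoid(cs_prev * wci + i)
    f = _sigmoid(cs_prev * wcf + f)
    ci = np.tanh(ci)
    cs = ci * i + cs_prev * f
    if cell_clip > 0:
        cs = np.clip(cs, -cell_clip, cell_clip)
    o = _sigmoid(cs * wco + o)
    co = np.tanh(cs)
    h = co * o
    return i, cs, f, o, ci, co, h
```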

tf2onnx/tfonnx.py

Lines changed: 1 addition & 0 deletions
@@ -1662,6 +1662,7 @@ def where_op(ctx, node, name, args):
     "Log": (direct_op, []),
     "LogSoftmax": (direct_op, ["LogSoftmax"]),
     "LRN": (lrn_op, []),
+    "LSTMBlockCell": (lstm_block_cell_op, []),
     "LogicalAnd": (broadcast_op, ["And"]),
     "LogicalNot": (direct_op, ["Not"]),
     "LogicalOr": (broadcast_op, ["Or"]),
