Commit 02061e9

Implement rewriter for LSTM nodes in tf2 (#1584)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent de67f20 commit 02061e9

6 files changed, +234 -16 lines changed

tests/test_lstm.py

Lines changed: 15 additions & 4 deletions
@@ -9,7 +9,7 @@
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import variable_scope
 from backend_test_base import Tf2OnnxBackendTestBase
-from common import unittest_main, check_opset_after_tf_version, skip_tf2, skip_tf_versions
+from common import unittest_main, check_opset_after_tf_version, skip_tf2, skip_tf_versions, check_op_count

 from tf2onnx.tf_loader import is_tf2

@@ -36,12 +36,22 @@

 class LSTMTests(Tf2OnnxBackendTestBase):

-    def run_test_case(self, *args, **kwargs): #pylint: disable=arguments-differ
+    def run_test_case(self, *args, require_lstm_count=1, **kwargs): #pylint: disable=arguments-differ
         # TF LSTM has an unknown dim
         tmp = self.config.allow_missing_shapes
         self.config.allow_missing_shapes = True
+        def graph_validator(g):
+            good = True
+            if "graph_validator" in kwargs:
+                good = good and kwargs["graph_validator"](g)
+            if require_lstm_count is None or ":" not in g.outputs[0]:
+                # Skip checks for tflite graphs (no ":" in outputs)
+                return good
+            good = good and check_op_count(g, "LSTM", require_lstm_count, disabled=False)
+            good = good and check_op_count(g, "Loop", 0, disabled=False)
+            return good
         try:
-            super().run_test_case(*args, **kwargs)
+            super().run_test_case(*args, graph_validator=graph_validator, **kwargs)
         finally:
             self.config.allow_missing_shapes = tmp

@@ -385,7 +395,8 @@ def func(x):
         feed_dict = {"input_1:0": x_val}
         input_names_with_port = ["input_1:0"]
         output_names_with_port = ["output:0", "cell_state:0"]
-        self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
+        self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06,
+                           require_lstm_count=2)

     @check_opset_after_tf_version("1.15", 8, "might need Scan")
     @skip_tf2() # Still failing likely due to inconsistent random number initialization
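
The wrapper above turns every LSTM test into a conversion-quality check: unless a test opts out (require_lstm_count=None, or a tflite graph), the converted graph must contain exactly require_lstm_count ONNX LSTM nodes and no Loop fallback. A minimal sketch of how a test exercises the new knob; the model and names below are illustrative, not part of this commit:

    import numpy as np
    import tensorflow as tf

    # A stacked two-layer Keras LSTM should convert to exactly two ONNX LSTM
    # nodes and zero Loop nodes once the tf2 rewriter fires.
    def func(x):
        lstm1 = tf.keras.layers.LSTM(5, return_sequences=True)
        lstm2 = tf.keras.layers.LSTM(5, return_state=True)
        y, _, c = lstm2(lstm1(x))
        return tf.identity(y, name="output"), tf.identity(c, name="cell_state")

    x_val = np.random.uniform(size=[2, 4, 3]).astype(np.float32)

    # Inside LSTMTests this would run roughly as:
    #     self.run_test_case(func, {"input_1:0": x_val}, ["input_1:0"],
    #                        ["output:0", "cell_state:0"], rtol=1e-06,
    #                        require_lstm_count=2)
    # The injected graph_validator then requires check_op_count(g, "LSTM", 2)
    # and check_op_count(g, "Loop", 0) to hold on the converted graph.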

tf2onnx/graph.py

Lines changed: 2 additions & 0 deletions
@@ -464,6 +464,8 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
         # A list of index, output tuples of potential scan outputs in this graph
         # Used by the tflite while loop handler
         self.scan_outputs = []
+        # Used by lstm_tf2_rewriter to indicate this subgraph is an LSTM cell
+        self.lstm_rewriter_context = None
         self.func_inputs = []
         self.ragged_variant_list_reads = []
         self.ragged_variant_list_writes = []
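
lstm_rewriter_context is a side channel between the two passes of the new rewriter: the pattern matcher runs on the While body subgraph and stashes what it found there, and the outer-graph pass later retrieves the annotation via find_function. A self-contained plain-Python illustration of that handshake (stub class and names, not tf2onnx code):

    # SubGraph stands in for tf2onnx's Graph; the dict mimics find_function's
    # lookup of a function body by name.
    class SubGraph:
        def __init__(self):
            self.lstm_rewriter_context = None  # same default as in graph.py above

    functions = {"while_body_0": SubGraph()}

    # pass 1: while matching inside the loop body, record what was found
    functions["while_body_0"].lstm_rewriter_context = {"x_idx": 0, "out_idx": 3}

    # pass 2: while walking the outer graph, key the While rewrite off the tag
    body = functions["while_body_0"]
    if body.lstm_rewriter_context is not None:
        print("rewrite this While using", body.lstm_rewriter_context)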

tf2onnx/rewriter/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -21,6 +21,7 @@
 from tf2onnx.rewriter.quantization_ops_rewriter import rewrite_quantize_and_dequantize
 from tf2onnx.rewriter.layer_normalization_rewriter import rewrite_layer_normalization
 from tf2onnx.rewriter.ragged_variant_shape_rewriter import rewrite_ragged_variant_shape
+from tf2onnx.rewriter.lstm_tf2_rewriter import rewriter_lstm_tf2


 __all__ = [
@@ -46,5 +47,6 @@
     "rewrite_quantize_and_dequantize",
     "rewrite_layer_normalization",
     "rewrite_conv_dilations",
-    "rewrite_ragged_variant_shape"
+    "rewrite_ragged_variant_shape",
+    "rewriter_lstm_tf2"
 ]

tf2onnx/rewriter/lstm_tf2_rewriter.py

Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
# SPDX-License-Identifier: Apache-2.0


"""
tf2onnx.rewriter.lstm_tf2_rewriter - Rewrites LSTM pattern used by tf2.
"""

import numpy as np
from tf2onnx.graph_matcher import GraphMatcher
from tf2onnx.rewriter.rnn_utils import make_lstmcell_pattern
from tf2onnx.tf_loader import find_function
from tf2onnx.rewriter.lstm_rewriter_base import LSTMContext
from tf2onnx.rewriter.lstm_rewriter import LSTMRewriter
from tf2onnx.graph_builder import GraphBuilder
from tf2onnx import utils

# pylint: disable=invalid-name,unused-argument,missing-docstring, unused-variable


def rewriter_lstm_tf2(g, ops):
    pattern1 = make_lstmcell_pattern("Identity")

    for pattern in [pattern1]:
        matcher = GraphMatcher(pattern, allow_reorder=False)
        match_results = list(matcher.match_ops(ops))
        for match_result in match_results:
            concat = match_result.get_op("xh")
            if len(concat.inputs) != 3:
                continue
            get_item = concat.inputs[0]
            if not get_item.type == "TensorListGetItem":
                continue
            x_e = get_item.inputs[0]
            if not x_e.is_graph_input():
                continue
            x_idx = g.input_names.index(x_e.output[0])

            ht_mul = match_result.get_op("ht")
            final_consumers = g.find_output_consumers(ht_mul.output[0])
            select_ops = [n for n in final_consumers if n.type == "Select"]
            def has_tensor_list_consumer(n):
                return any(c.type == "TensorListSetItem" for c in g.find_output_consumers(n.output[0]))
            select_ops = [n for n in select_ops if has_tensor_list_consumer(n)]
            if len(select_ops) == 1:
                greater_eq = select_ops[0].inputs[0]
                if greater_eq.type != "GreaterEqual":
                    continue
                seq_len = greater_eq.inputs[1]
                if not seq_len.is_graph_input():
                    continue
                seq_len_idx = g.input_names.index(seq_len.output[0])
                final_consumers = g.find_output_consumers(select_ops[0].output[0])
            else:
                seq_len_idx = None

            tensor_set_items = [n for n in final_consumers if n.type == "TensorListSetItem"]
            if len(tensor_set_items) != 1:
                continue

            if not tensor_set_items[0].inputs[0].is_graph_input():
                continue
            out_idx = g.input_names.index(tensor_set_items[0].input[0])

            if concat.inputs[1].is_graph_input():
                # c and h are separate
                h_idx = g.input_names.index(concat.input[1])
                c_e = match_result.get_op("c")
                if not c_e.is_graph_input():
                    continue
                c_idx = g.input_names.index(c_e.output[0])
                ch_info = {
                    "state_is_tuple": True,
                    "c_idx": c_idx,
                    "h_idx": h_idx,
                }
            else:
                # c and h are concatenated
                if not concat.inputs[1].type == "Slice":
                    continue
                ch_e = concat.inputs[1].inputs[0]
                if not ch_e.is_graph_input():
                    continue
                ch_idx = g.input_names.index(ch_e.output[0])

                c_e = match_result.get_op("c")
                if not c_e.type == "Slice" or c_e.input[0] != ch_e.output[0]:
                    continue
                ch_info = {
                    "state_is_tuple": False,
                    "ch_idx": ch_idx,
                }

            w_e = match_result.get_op("cell_kernel")
            if not w_e.is_graph_input():
                continue
            w_idx = g.input_names.index(w_e.output[0])

            bias_add = match_result.get_op("bias_add")
            if bias_add is not None and bias_add.data_format != "NHWC":
                continue

            b_e = match_result.get_op("cell_bias")
            if not b_e.is_graph_input():
                continue
            b_idx = g.input_names.index(b_e.output[0])

            ft_bias_node = match_result.get_op("ft_bias")
            if not ft_bias_node.is_const():
                continue
            if g.get_dtype(ft_bias_node.output[0]) != g.get_dtype(b_e.output[0]):
                continue
            ft_bias = ft_bias_node.get_tensor_value(as_list=False)

            g.lstm_rewriter_context = {
                "x_idx": x_idx,
                "out_idx": out_idx,
                "weight_idx": w_idx,
                "bias_idx": b_idx,
                "ft_bias": ft_bias,
                "seq_len_idx": seq_len_idx,
                **ch_info
            }

    for op in ops:
        if op.is_while():
            body_graph = find_function(op.get_attr_str("body"))
            if body_graph.lstm_rewriter_context is None:
                continue
            body_context = body_graph.lstm_rewriter_context
            w = op.input[body_context["weight_idx"]]
            b = op.input[body_context["bias_idx"]]
            if not g.is_const(w) or not g.is_const(b):
                continue
            w_const = g.get_tensor_value(w, as_list=False)
            b_const = g.get_tensor_value(b, as_list=False)

            if body_context["state_is_tuple"]:
                initial_c_sq = op.input[body_context["c_idx"]]
                initial_h_sq = op.input[body_context["h_idx"]]
                initial_c = GraphBuilder(g).make_unsqueeze({"data": initial_c_sq, "axes": [0]})
                initial_h = GraphBuilder(g).make_unsqueeze({"data": initial_h_sq, "axes": [0]})
            else:
                initial_ch = op.input[body_context["ch_idx"]]
                if not g.is_const(initial_ch):
                    continue
                initial_ch_const = g.get_tensor_value(initial_ch, as_list=False)
                if not len(initial_ch_const.shape) == 2:
                    continue
                initial_ch_const = np.expand_dims(initial_ch_const, axis=0)
                initial_c_const, initial_h_const = np.split(initial_ch_const, 2, axis=2)
                initial_c = g.make_const(utils.make_name("initial_c"), initial_c_const).output[0]
                initial_h = g.make_const(utils.make_name("initial_h"), initial_h_const).output[0]

            context = LSTMContext()
            context.weights.append({"weight": w_const, "bias": b_const, "ft_bias": body_context["ft_bias"]})
            context.onnx_input_ids.append({})
            context.input_size.append(None)
            context.hidden_size.append(None)
            context.attributes.append({})
            tensor_array_inp = op.inputs[body_context["x_idx"]]
            if not tensor_array_inp.type == "TensorListFromTensor":
                continue

            final_consumers = g.find_output_consumers(op.output[body_context["out_idx"]])
            output_ys = [n.output[0] for n in final_consumers if n.type == "TensorListStack"]

            context.onnx_input_ids[0]["X"] = tensor_array_inp.input[0]
            if body_context["seq_len_idx"] is None:
                context.onnx_input_ids[0]["sequence_lens"] = ""
            else:
                context.onnx_input_ids[0]["sequence_lens"] = op.input[body_context["seq_len_idx"]]
            context.onnx_input_ids[0]["initial_c"] = initial_c
            context.onnx_input_ids[0]["initial_h"] = initial_h

            lstm_rewriter = LSTMRewriter(g)
            lstm_rewriter.num_lstm_layers = 1
            lstm_rewriter.process_weights_and_bias(context)
            lstm_node = lstm_rewriter.create_rnn_node(context)[0]
            squeeze_output = GraphBuilder(g).make_squeeze({"data": lstm_node.output[0], "axes": [1]})
            for output in output_ys:
                g.replace_all_inputs(output, squeeze_output)

            if body_context["state_is_tuple"]:
                c_squeeze = GraphBuilder(g).make_squeeze({"data": lstm_node.output[2], "axes": [0]})
                h_squeeze = GraphBuilder(g).make_squeeze({"data": lstm_node.output[1], "axes": [0]})
                g.replace_all_inputs(op.output[body_context["c_idx"]], c_squeeze)
                g.replace_all_inputs(op.output[body_context["h_idx"]], h_squeeze)
            else:
                concat_ch = g.make_node("Concat", [lstm_node.output[2], lstm_node.output[1]],
                                        attr={"axis": 2}).output[0]
                ch_squeeze = GraphBuilder(g).make_squeeze({"data": concat_ch, "axes": [0]})
                ch_output = op.output[body_context["ch_idx"]]
                g.replace_all_inputs(ch_output, ch_squeeze)

    return g.get_nodes()
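
The squeeze/concat plumbing at the end of the rewriter follows from ONNX LSTM's output layout: Y is [seq_length, num_directions, batch, hidden] while Y_h and Y_c are [num_directions, batch, hidden], so with a single direction the extra axis must be squeezed away to match what the TF loop produced. A numpy illustration of the shape bookkeeping (not tf2onnx code):

    import numpy as np

    seq, ndir, batch, hidden = 7, 1, 2, 5
    Y = np.zeros((seq, ndir, batch, hidden))   # LSTM output 0
    Y_h = np.zeros((ndir, batch, hidden))      # LSTM output 1
    Y_c = np.zeros((ndir, batch, hidden))      # LSTM output 2

    # make_squeeze(..., axes=[1]) on Y recovers the TF loop's [seq, batch, hidden]
    assert np.squeeze(Y, axis=1).shape == (seq, batch, hidden)

    # make_squeeze(..., axes=[0]) on Y_h / Y_c recovers [batch, hidden]
    assert np.squeeze(Y_h, axis=0).shape == (batch, hidden)

    # state_is_tuple=False: Concat(Y_c, Y_h, axis=2) then squeeze axis 0 yields
    # the packed [batch, 2*hidden] c|h state the TF cell carried
    ch = np.concatenate([Y_c, Y_h], axis=2)
    assert np.squeeze(ch, axis=0).shape == (batch, 2 * hidden)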

tf2onnx/rewriter/rnn_utils.py

Lines changed: 18 additions & 11 deletions
@@ -30,44 +30,51 @@ class REWRITER_RESULT(Enum):

 # TensorFlow LSTMCell/BasicLSTMCell computation graph matching

-xc_pattern = \
-    OpTypePattern('Split', inputs=[
+_make_xc_pattern_memo = {}
+
+def make_xc_pattern(enter_or_id="Enter"):
+    return OpTypePattern('Split', inputs=[
         OpTypePattern("Const"), # axis for split
         OpTypePattern("BiasAdd", name="bias_add", inputs=[
             OpTypePattern("MatMul", inputs=[
                 OpTypePattern("ConcatV2|Concat", name="xh"),
-                OpTypePattern("Enter", inputs=[
+                OpTypePattern(enter_or_id, inputs=[
                     OpTypePattern("*", name="cell_kernel"),
                 ]),
             ]),
-            OpTypePattern("Enter", inputs=[
+            OpTypePattern(enter_or_id, inputs=[
                 OpTypePattern("*", name="cell_bias"),
             ]),
         ]),
     ])

-lstmcell_pattern = \
-    OpTypePattern('Mul', name='ht', inputs=[
-        OpTypePattern("Sigmoid", name="ot", inputs=[xc_pattern]),
+xc_pattern = make_xc_pattern()
+
+def make_lstmcell_pattern(enter_or_id="Enter"):
+    my_xc_pattern = make_xc_pattern(enter_or_id)
+    return OpTypePattern('Mul', name='ht', inputs=[
+        OpTypePattern("Sigmoid", name="ot", inputs=[my_xc_pattern]),
         OpTypePattern('Tanh', inputs=[
             OpTypePattern("Add|AddV2", name="ct", inputs=[
                 OpTypePattern("Mul", name="ct_identity_consumer", inputs=[
                     OpTypePattern("Sigmoid", name="ft", inputs=[
                         OpTypePattern("Add|AddV2", inputs=[
-                            xc_pattern,
+                            my_xc_pattern,
                             OpTypePattern("*", name="ft_bias"),
                         ]),
                     ]),
-                    OpTypePattern("*"),
+                    OpTypePattern("*", name="c"),
                 ]),
                 OpTypePattern("Mul", inputs=[
-                    OpTypePattern("Sigmoid", name="it", inputs=[xc_pattern]),
-                    OpTypePattern("Tanh", name="gt", inputs=[xc_pattern]),
+                    OpTypePattern("Sigmoid", name="it", inputs=[my_xc_pattern]),
+                    OpTypePattern("Tanh", name="gt", inputs=[my_xc_pattern]),
                 ]),
             ]),
         ]),
     ])

+lstmcell_pattern = make_lstmcell_pattern()
+
 xc_pattern_optimized = \
     OpTypePattern('Split', inputs=[
         OpTypePattern("Const"),

tf2onnx/tfonnx.py

Lines changed: 1 addition & 0 deletions
@@ -606,6 +606,7 @@ def compat_handler(ctx, node, **kwargs):
             rewrite_leakyrelu,
             rewrite_thresholded_relu,
             rewrite_conv2d_with_pad,
+            rewriter_lstm_tf2,
             rewrite_single_direction_lstm,
             # bi-directional
             rewrite_bi_direction_lstm,
