Skip to content

Commit 186b954

Browse files
Tom/keras gru (#1692)
* Implement GRU rewriter for tf2
  Signed-off-by: Tom Wildenhain <[email protected]>
* Extend GRU rewriter for keras pattern
  Signed-off-by: Tom Wildenhain <[email protected]>
* Fix style and GRU test
  Signed-off-by: Tom Wildenhain <[email protected]>
* Remove no_loops_tf2 assertion from keras tests for now
  Signed-off-by: Tom Wildenhain <[email protected]>

Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent be36457 commit 186b954

File tree

7 files changed

+143
-23
lines changed

7 files changed

+143
-23
lines changed

tests/keras2onnx_unit_tests/test_layers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from mock_keras2onnx.proto import (keras, is_tf_keras,
88
is_tensorflow_older_than, is_tensorflow_later_than,
99
is_keras_older_than, is_keras_later_than)
10+
from test_utils import no_loops_in_tf2
1011

1112
K = keras.backend
1213
Activation = keras.layers.Activation
@@ -1864,7 +1865,7 @@ def test_GRU_2(runner):
18641865
onnx_model = convert_keras(model, name=model.name)
18651866
data = np.random.rand(3, 257).astype(np.float32).reshape((3, 1, 257))
18661867
expected = model.predict(data)
1867-
runner(onnx_model.graph.name, onnx_model, data, expected)
1868+
assert runner(onnx_model.graph.name, onnx_model, data, expected)
18681869

18691870

18701871
@pytest.mark.parametrize('return_sequences', [False, True])

tests/keras2onnx_unit_tests/test_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,10 @@ def parse_profile_results(sess_time, kernel_time_only=False, threshold=0):
157157
return results
158158

159159

160+
def no_loops_in_tf2(onnx_model):
    """Check that a converted model contains no ONNX ``Loop`` nodes under TF2.

    Returns True when ``is_tf2`` is falsy, or when none of the nodes in
    ``onnx_model.graph.node`` has ``op_type == "Loop"``; otherwise False.
    """
    # NOTE(review): `is_tf2` is tested for truthiness, not called. This assumes
    # the module imports it as a boolean flag rather than a function - confirm.
    if not is_tf2:
        return True
    for node in onnx_model.graph.node:
        if node.op_type == "Loop":
            return False
    return True
162+
163+
160164
def run_onnx_runtime(case_name, onnx_model, data, expected, model_files, rtol=1.e-3, atol=1.e-6,
161165
compare_perf=False, enable_profiling=False):
162166
if not os.path.exists(tmp_path):

tests/test_gru.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,29 @@
99
from tensorflow.python.ops import init_ops
1010
from tensorflow.python.ops import variable_scope
1111
from backend_test_base import Tf2OnnxBackendTestBase
12-
from common import unittest_main, check_gru_count, check_opset_after_tf_version, check_op_count
12+
from common import unittest_main, check_gru_count, check_opset_after_tf_version, check_op_count, check_tf_min_version
1313
from tf2onnx.tf_loader import is_tf2
1414

1515
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,cell-var-from-loop
1616

17+
# names for input and outputs for tests
18+
_TFINPUT = "input"
19+
_INPUT = "input:0"
20+
_TFINPUT1 = "input1"
21+
_INPUT1 = "input1:0"
22+
_TFINPUT2 = "input2"
23+
_INPUT2 = "input2:0"
24+
_TFINPUT3 = "input3"
25+
_INPUT3 = "input3:0"
26+
_TFOUTPUT = "output"
27+
_OUTPUT = "output:0"
28+
_TFOUTPUT1 = "output1"
29+
_OUTPUT1 = "output1:0"
30+
_TFOUTPUT2 = "output2"
31+
_OUTPUT2 = "output2:0"
32+
_TFOUTPUT3 = "output3"
33+
_OUTPUT3 = "output3:0"
34+
1735
if is_tf2():
1836
# There is no LSTMBlockCell in tf-2.x
1937
BasicLSTMCell = tf.compat.v1.nn.rnn_cell.BasicLSTMCell
@@ -696,6 +714,23 @@ def func(x, y1, y2):
696714
self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-3, atol=1e-06)
697715
# graph_validator=lambda g: check_gru_count(g, 2))
698716

717+
@check_tf_min_version("2.0")
def test_keras_gru(self):
    """Convert a tf.keras GRU built with return_sequences and return_state.

    Both tensors returned by the layer are exposed as named graph outputs and
    handed to ``run_test_case`` together with a random float32 input.
    """
    seq_shape = [10, 3]
    sample = np.random.uniform(size=[2, 10, 3]).astype(np.float32)

    model_in = tf.keras.layers.Input(tuple(seq_shape), batch_size=2)
    # Seeded initializers so the layer weights do not vary between runs.
    kernel_init = tf.random_uniform_initializer(0.0, 1.0, seed=42)
    recurrent_init = tf.random_uniform_initializer(0.0, 1.0, seed=44)
    bias_init = tf.random_uniform_initializer(0.0, 1.0, seed=43)
    gru_layer = tf.keras.layers.GRU(
        5,
        return_sequences=True,
        return_state=True,
        kernel_initializer=kernel_init,
        recurrent_initializer=recurrent_init,
        bias_initializer=bias_init,
    )
    gru_outputs = gru_layer(model_in)
    model = tf.keras.models.Model(inputs=model_in, outputs=gru_outputs)

    def func(x):
        results = model(x)
        first_out = tf.identity(results[0], name=_TFOUTPUT)
        second_out = tf.identity(results[1], name=_TFOUTPUT1)
        return first_out, second_out

    self.run_test_case(func, {_INPUT: sample}, [], [_OUTPUT, _OUTPUT1], rtol=1e-05, atol=1e-06)
733+
699734

700735
if __name__ == '__main__':
701736
unittest_main()

tf2onnx/rewriter/gru_rewriter.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,40 @@ def is_valid(self, context):
139139
return False
140140
return True
141141

142+
def _make_constants(self, context, W_zrh, R_zrh, B_zrh):
    """Create const nodes for W, R and B and record them on ``context``.

    Each array is registered through ``self.g.make_const`` under a fresh name
    from ``utils.make_name``. Afterwards ``context.input_size`` and
    ``context.hidden_size`` are set from the last axis of ``W_zrh`` and
    ``R_zrh`` respectively, and the first output of every const node is stored
    in ``context.onnx_input_ids`` under the keys "W", "R" and "B".
    """
    in_size = W_zrh.shape[-1]
    hid_size = R_zrh.shape[-1]

    # Create the consts in W, R, B order so generated names keep their sequence.
    created = []
    for key, tensor in (("W", W_zrh), ("R", R_zrh), ("B", B_zrh)):
        const_name = utils.make_name(key)
        const_node = self.g.make_const(const_name, tensor, skip_conversion=True)
        created.append((key, const_node))

    context.input_size = in_size
    context.hidden_size = hid_size
    for key, const_node in created:
        context.onnx_input_ids[key] = const_node.output[0]
159+
160+
def _process_weights_and_bias_keras(self, context):
    """Build W, R and B from the keras-pattern weights and register them as consts.

    Reads "gate_kernel", "hidden_kernel", "gate_bias" and "hidden_bias" from
    ``context.weights``. The two kernels are transposed and the two biases are
    concatenated along axis 0; every result gets a new leading axis of size 1
    before being passed to ``self._make_constants``.
    """
    params = context.weights
    # NOTE(review): no gate reordering is done here; this presumably relies on
    # the keras weights already being in the order the ONNX GRU expects - confirm.
    gate_kernel_t = params["gate_kernel"].transpose()
    W_zrh = np.expand_dims(gate_kernel_t, axis=0)
    hidden_kernel_t = params["hidden_kernel"].transpose()
    R_zrh = np.expand_dims(hidden_kernel_t, axis=0)

    gate_bias = params["gate_bias"]
    hidden_bias = params["hidden_bias"]
    stacked_bias = np.concatenate((gate_bias, hidden_bias), axis=0)
    B_zrh = np.expand_dims(stacked_bias, axis=0)

    self._make_constants(context, W_zrh, R_zrh, B_zrh)
168+
142169
def process_weights_and_bias(self, context):
143170
"""
144171
why split the data in this way should refer to code of tensorflow GRU cell and official document of ONNX GRU
145172
"""
173+
if context.from_keras:
174+
self._process_weights_and_bias_keras(context)
175+
return
146176
weights = context.weights
147177
# from code of tensorflow GRU cell, it can be known that shape of hidden_kernel(or candidate_kernel)
148178
# is (input_size+hidden_unit, hidden_unit)
@@ -157,6 +187,8 @@ def process_weights_and_bias(self, context):
157187
h_kernel = weights["hidden_kernel"]
158188
r_bias, z_bias = np.split(weights["gate_bias"], [hidden_size], axis=0)
159189
h_bias = weights["hidden_bias"]
190+
for k in sorted(weights.keys()):
191+
print(k, weights[k].shape)
160192
# ONNX GRU split weights of input and state, so have to split *_kernel
161193
input_r_kernel, state_r_kernel = np.split(r_kernel, [input_size], axis=0)
162194
input_z_kernel, state_z_kernel = np.split(z_kernel, [input_size], axis=0)
@@ -181,20 +213,7 @@ def process_weights_and_bias(self, context):
181213
B_zrh = B_zrh.astype(bias_dtype)
182214
assert B_zrh.shape == (1, 6*hidden_size)
183215
# create const ONNX node
184-
w_name = utils.make_name("W")
185-
w_node = self.g.make_const(w_name, W_zrh, skip_conversion=True)
186-
187-
r_name = utils.make_name("R")
188-
r_node = self.g.make_const(r_name, R_zrh, skip_conversion=True)
189-
190-
b_name = utils.make_name("B")
191-
b_node = self.g.make_const(b_name, B_zrh, skip_conversion=True)
192-
193-
context.input_size = input_size
194-
context.hidden_size = hidden_size
195-
context.onnx_input_ids["W"] = w_node.output[0]
196-
context.onnx_input_ids["R"] = r_node.output[0]
197-
context.onnx_input_ids["B"] = b_node.output[0]
216+
self._make_constants(context, W_zrh, R_zrh, B_zrh)
198217

199218
def process_var_init_nodes(self, context):
200219
assert "state" in context.state_variables.keys()

tf2onnx/rewriter/gru_tf2_rewriter.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
from tf2onnx.graph_matcher import GraphMatcher
9-
from tf2onnx.rewriter.rnn_utils import make_grucell_pattern
9+
from tf2onnx.rewriter.rnn_utils import make_grucell_pattern, keras_gru_pattern
1010
from tf2onnx.tf_loader import find_function
1111
from tf2onnx.rewriter.unit_rnn_rewriter_base import UnitRnnContext
1212
from tf2onnx.rewriter.gru_rewriter import GRUUnitRewriter
@@ -17,8 +17,9 @@
1717

1818
def rewrite_gru_tf2(g, ops):
1919
pattern1 = make_grucell_pattern("Identity")
20+
pattern2 = keras_gru_pattern
2021

21-
for pattern in [pattern1]:
22+
for pattern in [pattern1, pattern2]:
2223
matcher = GraphMatcher(pattern, allow_reorder=True)
2324
match_results = list(matcher.match_ops(ops))
2425
for match_result in match_results:
@@ -27,17 +28,21 @@ def rewrite_gru_tf2(g, ops):
2728
if activation_op.type not in ["Relu", "Tanh", "Sigmoid"]:
2829
continue
2930

30-
concat = match_result.get_op("cell_inputs")
31-
if len(concat.inputs) != 3:
32-
continue
33-
get_item = concat.inputs[0]
31+
if pattern is pattern1:
32+
concat = match_result.get_op("cell_inputs")
33+
if len(concat.inputs) != 3:
34+
continue
35+
get_item = concat.inputs[0]
36+
init_state = concat.inputs[1]
37+
else:
38+
get_item = match_result.get_op("gru_input")
39+
init_state = match_result.get_op("state")
3440
if not get_item.type == "TensorListGetItem":
3541
continue
3642
x_e = get_item.inputs[0]
3743
if not x_e.is_graph_input():
3844
continue
3945
x_idx = g.input_names.index(x_e.output[0])
40-
init_state = concat.inputs[1]
4146
if not init_state.is_graph_input():
4247
continue
4348
init_state_idx = g.input_names.index(init_state.output[0])
@@ -69,6 +74,8 @@ def has_tensor_list_consumer(n):
6974
out_idx = g.input_names.index(tensor_set_items[0].input[0])
7075

7176
hk = match_result.get_op("hidden_kernel")
77+
while hk.type == "Identity":
78+
hk = hk.inputs[0]
7279
if not hk.is_graph_input():
7380
continue
7481
hk_idx = g.input_names.index(hk.output[0])
@@ -79,6 +86,8 @@ def has_tensor_list_consumer(n):
7986
hb_idx = g.input_names.index(hb.output[0])
8087

8188
gk = match_result.get_op("gate_kernel")
89+
while gk.type == "Identity":
90+
gk = gk.inputs[0]
8291
if not gk.is_graph_input():
8392
continue
8493
gk_idx = g.input_names.index(gk.output[0])
@@ -102,6 +111,8 @@ def has_tensor_list_consumer(n):
102111
"gate_bias_idx": gb_idx,
103112
"seq_len_idx": seq_len_idx,
104113
"activations": activations,
114+
"from_keras": pattern is pattern2,
115+
"linear_before_reset": 1 if pattern is pattern2 else 0,
105116
}
106117

107118
for op in ops:
@@ -125,13 +136,15 @@ def has_tensor_list_consumer(n):
125136
initial_state = GraphBuilder(g).make_unsqueeze({"data": initial_state_sq, "axes": [0]})
126137

127138
context = UnitRnnContext()
139+
context.from_keras = body_context["from_keras"]
128140
context.weights.update({
129141
"hidden_kernel": hk_const,
130142
"hidden_bias": hb_const,
131143
"gate_kernel": gk_const,
132144
"gate_bias": gb_const
133145
})
134146
context.attributes["activations"] = body_context["activations"]
147+
context.attributes["linear_before_reset"] = body_context["linear_before_reset"]
135148
tensor_array_inp = op.inputs[body_context["x_idx"]]
136149
if not tensor_array_inp.type == "TensorListFromTensor":
137150
continue

tf2onnx/rewriter/rnn_utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,53 @@ def make_grucell_pattern(enter_or_id="Enter"):
163163

164164
grucell_pattern = make_grucell_pattern()
165165

166+
def make_keras_gru_split_pattern(bias_name, kernel_name, input_name, input_op_type):
    """Build a ``Split(Const, BiasAdd(MatMul(input, kernel), bias))`` pattern.

    Args:
        bias_name: match name given to the BiasAdd bias operand
            (a Placeholder or PlaceholderV2).
        kernel_name: match name given to the MatMul kernel operand
            (a Placeholder, PlaceholderV2 or Identity).
        input_name: match name given to the MatMul data operand.
        input_op_type: op type expression the MatMul data operand must match.

    Returns:
        The OpTypePattern rooted at the Split op.
    """
    split_arg = OpTypePattern("Const")
    data_operand = OpTypePattern(input_op_type, name=input_name)
    kernel_operand = OpTypePattern("Placeholder|PlaceholderV2|Identity", name=kernel_name)
    # Operand order is fixed for the MatMul: data first, kernel second.
    matmul = OpTypePattern("MatMul", inputs=[data_operand, kernel_operand], allow_reorder=False)
    bias_operand = OpTypePattern("Placeholder|PlaceholderV2", name=bias_name)
    bias_add = OpTypePattern("BiasAdd", inputs=[matmul, bias_operand])
    return OpTypePattern("Split", inputs=[split_arg, bias_add])
177+
178+
# Split branch fed by the per-step input (a TensorListGetItem) with the gate kernel/bias.
keras_gru_split0_pattern = make_keras_gru_split_pattern(
    "gate_bias", "gate_kernel", "gru_input", "TensorListGetItem")
# Split branch fed by the state placeholder with the hidden kernel/bias.
keras_gru_split1_pattern = make_keras_gru_split_pattern(
    "hidden_bias", "hidden_kernel", "state", "Placeholder|PlaceholderV2")

# Sigmoid over the sum of the two split branches; reused at every gate position below.
keras_gru_sigmoid_pattern = OpTypePattern("Sigmoid", inputs=[
    OpTypePattern("Add|AddV2", inputs=[keras_gru_split0_pattern, keras_gru_split1_pattern])
])

# First addend of the cell output: Mul(sigmoid, placeholder).
_keras_gru_gated_placeholder = OpTypePattern("Mul", inputs=[
    keras_gru_sigmoid_pattern,
    OpTypePattern("Placeholder|PlaceholderV2"),
])
# Sub(Const, sigmoid), with the operand order fixed.
_keras_gru_const_minus_sigmoid = OpTypePattern("Sub", inputs=[
    OpTypePattern("Const"),
    keras_gru_sigmoid_pattern,
], allow_reorder=False)
# Any op (captured as "optional_activation") over Add(split0, Mul(sigmoid, split1)).
_keras_gru_activation = OpTypePattern("*", name="optional_activation", inputs=[
    OpTypePattern("Add|AddV2", inputs=[
        keras_gru_split0_pattern,
        OpTypePattern("Mul", inputs=[keras_gru_sigmoid_pattern, keras_gru_split1_pattern]),
    ])
])

keras_gru_pattern = OpTypePattern("Add|AddV2", name="cell_output", inputs=[
    _keras_gru_gated_placeholder,
    OpTypePattern("Mul", inputs=[_keras_gru_const_minus_sigmoid, _keras_gru_activation]),
])
212+
166213
cudnn_compatible_grucell_pattern = \
167214
OpTypePattern("Add", name="cell_output", inputs=[
168215
OpTypePattern("Mul", inputs=[

tf2onnx/rewriter/unit_rnn_rewriter_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(self):
3131
self.state_variables = {}
3232
self.input_size = None
3333
self.hidden_size = None
34+
self.from_keras = False
3435

3536
self.attributes = {} # onnx attributes
3637
# onnx inputs: [X, W, R, B, sequence_lens, initial_h, initial_c, P],

0 commit comments

Comments
 (0)