Skip to content

Commit 7e6899b

Browse files
committed
onnx#227 use CudnnCompatibleGRUCell for linear_before_reset
1 parent afef14f commit 7e6899b

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

onnx_tf/handlers/backend/gru.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ def args_check(cls, node, **kwargs):
1818
num_directions = 2 if direction == "bidirectional" else 1
1919
if "clip" in node.attrs:
2020
exception.OP_UNSUPPORTED_EXCEPT("GRU with clip", "Tensorflow")
21-
if node.attrs.get("linear_before_reset", 0):
22-
exception.OP_UNSUPPORTED_EXCEPT("GRU with linear_before_reset",
23-
"Tensorflow")
2421
if "activations" in node.attrs:
2522
activations = list(map(lambda x: x.lower(), node.attrs["activations"]))
2623
if activations[0] != "sigmoid":
@@ -63,11 +60,16 @@ def _custom_getter(cls,
6360
if names[-2] == "gates":
6461
new_w = tf.transpose(tf.concat([w_r, w_z], 0))
6562
new_r = tf.transpose(tf.concat([r_r, r_z], 0))
66-
elif names[-2] == "candidate":
63+
elif names[-2] == "candidate" or names[-3] == "candidate":
6764
new_w = tf.transpose(w_h)
6865
new_r = tf.transpose(r_h)
69-
kernel = tf.concat([new_w, new_r], 0)
70-
return kernel
66+
if names[-2] == 'input_projection':
67+
return new_w
68+
elif names[-2] == 'hidden_projection':
69+
return new_r
70+
else:
71+
return tf.concat([new_w, new_r], 0)
72+
7173
if names[-1] == "bias":
7274
if len(node.inputs) >= 4:
7375
# onnx Wb[zrh], Rb[zrh]
@@ -81,10 +83,15 @@ def _custom_getter(cls,
8183
if names[-2] == "gates":
8284
w_b = tf.transpose(tf.concat([w_b_r, w_b_z], 0))
8385
r_b = tf.transpose(tf.concat([r_b_r, r_b_z], 0))
84-
elif names[-2] == "candidate":
86+
elif names[-2] == "candidate" or names[-3] == "candidate":
8587
w_b = tf.transpose(w_b_h)
8688
r_b = tf.transpose(r_b_h)
87-
return tf.add(w_b, r_b)
89+
if names[-2] == 'input_projection':
90+
return w_b
91+
elif names[-2] == 'hidden_projection':
92+
return r_b
93+
else:
94+
return tf.add(w_b, r_b)
8895
return getter(name, *args, **kwargs)
8996
return getter(name, *args, **kwargs)
9097

@@ -158,7 +165,12 @@ def _common(cls, node, **kwargs):
158165
rnn_kwargs["time_major"] = True
159166
rnn_kwargs["dtype"] = tf.float32
160167

161-
outputs, states = cls.rnn(x, tf.nn.rnn_cell.GRUCell, cell_kwargs,
168+
if node.attrs.get("linear_before_reset", 0):
169+
cell_class = tf.contrib.cudnn_rnn.CudnnCompatibleGRUCell
170+
else:
171+
cell_class = tf.nn.rnn_cell.GRUCell
172+
173+
outputs, states = cls.rnn(x, cell_class, cell_kwargs,
162174
rnn_kwargs, tf_activations, direction)
163175

164176
if num_directions == 1:

onnx_tf/handlers/backend/rnn_mixin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ class RNNMixin(object):
2424

2525
@classmethod
2626
def rnn(cls, x, cell_class, cell_kwargs, rnn_kwargs, activations, direction):
27-
cell_kwargs["activation"] = activations[0]
27+
if cell_class is not tf.contrib.cudnn_rnn.CudnnCompatibleGRUCell:
28+
cell_kwargs["activation"] = activations[0]
2829

2930
rnn_cell = [cell_class(**cell_kwargs)]
3031
cell_fw = tf.nn.rnn_cell.MultiRNNCell(rnn_cell)

0 commit comments

Comments
 (0)