|
10 | 10 | from __future__ import unicode_literals
|
11 | 11 |
|
12 | 12 | import logging
|
13 |
| - |
14 | 13 | import numpy as np
|
15 | 14 | from tf2onnx import utils
|
16 | 15 | from tf2onnx.handler import tf_op
|
@@ -169,3 +168,88 @@ def replace_output(old_output, new_output):
|
    @classmethod
    def version_7(cls, ctx, node, **kwargs):
        # Opset 7 requires no changes relative to the opset-1 conversion;
        # delegate straight to version_1.
        cls.version_1(ctx, node, **kwargs)
|
| 171 | + |
| 172 | + |
@tf_op("CudnnRNN")
class CudnnRNN:
    """Convert a TF CudnnRNN node (GRU mode only) into a stack of ONNX GRU nodes."""

    @classmethod
    def version_10(cls, ctx, node, **kwargs):
        """Rewrite ``node`` as one ONNX GRU per (layer, direction) cell.

        CudnnRNN packs the input weights, recurrence weights and biases of
        every layer and direction into a single flat parameter tensor
        (``node.input[3]``).  That blob is sliced into W / R / B regions,
        reshaped, split per cell and wired into a chain of GRU nodes.
        Requires opset 10 because Slice takes starts/ends as inputs.

        Raises (via utils.make_sure) when rnn_mode != gru, dropout != 0,
        or input_mode != linear_input.
        """
        x = node.input[0]          # input sequence; assumed time-major [seq, batch, input_size] -- TODO confirm
        x_shape = ctx.get_shape(x)
        h = node.input[1]          # initial hidden state [num_layers * num_dirs, batch, hidden]
        h_shape = ctx.get_shape(h)
        p = node.input[3]          # flat cudnn parameter blob (all weights + biases)
        utils.make_sure(
            node.attr["rnn_mode"].s == b"gru",
            "rnn mode other than gru are not supported yet"
        )
        utils.make_sure(
            node.attr["dropout"].f == 0,
            "dropout not supported yet"
        )
        utils.make_sure(
            node.attr["input_mode"].s == b"linear_input",
            "input mode must be linear input"
        )
        num_dirs = 1 if node.attr["direction"].s == b"unidirectional" else 2
        # FIX: integer floor division instead of int(h_shape[0] / num_dirs);
        # same value, but stays in integer arithmetic throughout.
        num_layers = h_shape[0] // num_dirs
        num_cells = num_layers * num_dirs  # hoisted: one GRU cell per (layer, direction)
        num_units = hidden_size = h_shape[2]
        input_size = x_shape[2]
        # NOTE(review): every cell's W is shaped with the layer-0 input_size,
        # but layers > 0 consume hidden-sized inputs -- verify against the
        # cudnn parameter layout for num_layers > 1.
        w_shape = [num_cells, 3 * hidden_size, input_size]
        w_shape_const = ctx.make_const(utils.make_name("w_shape"), np.array(w_shape, dtype=np.int64))
        r_shape = [num_cells, 3 * hidden_size, hidden_size]
        r_shape_const = ctx.make_const(utils.make_name("r_shape"), np.array(r_shape, dtype=np.int64))
        b_shape = [num_cells, 6 * hidden_size]
        b_shape_const = ctx.make_const(utils.make_name("b_shape"), np.array(b_shape, dtype=np.int64))
        zero_const = ctx.make_const(utils.make_name("zero"), np.array([0], dtype=np.int64))
        # Byte offsets (element counts) of the W / R / B regions in the flat blob.
        w_end = np.prod(w_shape)
        w_end_const = ctx.make_const(utils.make_name("w_end"), np.array([w_end], dtype=np.int64))
        r_end = w_end + np.prod(r_shape)
        r_end_const = ctx.make_const(utils.make_name("r_end"), np.array([r_end], dtype=np.int64))
        b_end = r_end + np.prod(b_shape)
        b_end_const = ctx.make_const(utils.make_name("b_end"), np.array([b_end], dtype=np.int64))

        def name(nm):
            # Deterministic per-node tensor names so GRU inputs can reference
            # the Split outputs created below by name.
            return node.name + "_" + nm

        ws = [name('W_' + str(i)) for i in range(num_cells)]
        rs = [name('R_' + str(i)) for i in range(num_cells)]
        bs = [name('B_' + str(i)) for i in range(num_cells)]
        hs = [name('H_' + str(i)) for i in range(num_cells)]
        yhs = [name('YH_' + str(i)) for i in range(num_cells)]
        # Carve the flat parameter blob into the three regions, then reshape
        # and split each region into one tensor per cell (axis 0 by default).
        w_flattened = ctx.make_node('Slice', [p, zero_const.output[0], w_end_const.output[0]])
        r_flattened = ctx.make_node('Slice', [p, w_end_const.output[0], r_end_const.output[0]])
        b_flattened = ctx.make_node('Slice', [p, r_end_const.output[0], b_end_const.output[0]])
        w = utils.make_name('W')
        r = utils.make_name('R')
        b = utils.make_name('B')
        ctx.make_node('Reshape', [w_flattened.output[0], w_shape_const.output[0]], outputs=[w])
        ctx.make_node('Reshape', [r_flattened.output[0], r_shape_const.output[0]], outputs=[r])
        ctx.make_node('Reshape', [b_flattened.output[0], b_shape_const.output[0]], outputs=[b])
        ctx.make_node('Split', [w], outputs=ws)
        ctx.make_node('Split', [r], outputs=rs)
        ctx.make_node('Split', [b], outputs=bs)
        ctx.make_node('Split', [h], outputs=hs)
        xnf = xnb = x
        for i in range(num_layers):
            suffix = '_' + str(i * num_dirs)
            ctx.make_node('GRU',
                          [xnf, name('W' + suffix), name('R' + suffix), name('B' + suffix), '', name('H' + suffix)],
                          outputs=[name('Y' + suffix), name('YH' + suffix)],
                          attr={'direction': 'forward', 'hidden_size': num_units})
            xnf = name(x + suffix)
            # GRU's Y is [seq, num_dirs=1, batch, hidden]; drop the direction
            # axis so the next layer sees a rank-3 input again.
            ctx.make_node('Squeeze', [name('Y' + suffix)], outputs=[xnf], attr={'axes': [1]})
            if num_dirs == 2:
                suffix = '_' + str(i * 2 + 1)
                ctx.make_node('GRU',
                              [xnb, name('W' + suffix), name('R' + suffix), name('B' + suffix), '', name('H' + suffix)],
                              outputs=[name('Y' + suffix), name('YH' + suffix)],
                              attr={'direction': 'reverse', 'hidden_size': num_units})
                xnb = name(x + suffix)
                ctx.make_node('Squeeze', [name('Y' + suffix)], outputs=[xnb], attr={'axes': [1]})
        ctx.remove_node(node.name)
        if num_dirs == 2:
            # NOTE(review): each layer above the first receives only its own
            # direction's stream, not the fwd/bwd concat -- confirm this
            # matches cudnn semantics when num_layers > 1.
            ctx.make_node('Concat', [xnf, xnb], outputs=[node.output[0]], attr={'axis': -1})
        else:
            ctx.make_node('Identity', [xnf], outputs=[node.output[0]])
        # Second output: final hidden states of all cells stacked along axis 0.
        ctx.make_node('Concat', yhs, outputs=[node.output[1]], attr={'axis': 0})
0 commit comments