16 | 16 | # pylint: disable=abstract-method,arguments-differ
17 | 17 |
18 | 18 | if is_tf2():
   | 19 | +     # no test for tf2 in this file
19 | 20 |     pass
20 | 21 | else:
21 | 22 |     LSTMBlockCell = tf.contrib.rnn.LSTMBlockCell
26 | 27 |     dynamic_rnn = tf.nn.dynamic_rnn
27 | 28 |     bidirectional_dynamic_rnn = tf.nn.bidirectional_dynamic_rnn
28 | 29 |
   | 30 | +     class GatedGRUCell(RNNCell):
   | 31 | +         def __init__(self, hidden_dim, reuse=None):
   | 32 | +             super().__init__(_reuse=reuse)
   | 33 | +             self._num_units = hidden_dim
   | 34 | +             self._activation = tf.tanh
   | 35 | +
   | 36 | +         @property
   | 37 | +         def state_size(self):
   | 38 | +             return self._num_units
   | 39 | +
   | 40 | +         @property
   | 41 | +         def output_size(self):
   | 42 | +             return self._num_units
   | 43 | +
   | 44 | +         def call(self, inputs, state):
   | 45 | +             # inputs shape: [batch size, time step, input size] = [1, 3, 2]
   | 46 | +             # num_units: 5
   | 47 | +             # W shape: [2, 3 * 5] = [2, 15]
   | 48 | +             # U shape: [5, 3 * 5] = [5, 15]
   | 49 | +             # b shape: [1, 3 * 5] = [1, 15]
   | 50 | +             # state shape: [batch size, state size] = [1, 5]
   | 51 | +
   | 52 | +             input_dim = inputs.get_shape()[-1]
   | 53 | +             assert input_dim is not None, "input dimension must be defined"
   | 54 | +             # W = tf.get_variable(name="W", shape=[input_dim, 3 * self._num_units], dtype=tf.float32)
   | 55 | +             W = np.arange(30.0, dtype=np.float32).reshape((2, 15))
   | 56 | +             # U = tf.get_variable(name='U', shape=[self._num_units, 3 * self._num_units], dtype=tf.float32)
   | 57 | +             U = np.arange(75.0, dtype=np.float32).reshape((5, 15))
   | 58 | +             # b = tf.get_variable(name='b', shape=[1, 3 * self._num_units], dtype=tf.float32)
   | 59 | +             b = np.arange(15.0, dtype=np.float32).reshape((1, 15))
   | 60 | +
   | 61 | +             xw = tf.split(tf.matmul(inputs, W) + b, 3, 1)
   | 62 | +             hu = tf.split(tf.matmul(state, U), 3, 1)
   | 63 | +             r = tf.sigmoid(xw[0] + hu[0])  # reset gate
   | 64 | +             z = tf.sigmoid(xw[1] + hu[1])  # update gate
   | 65 | +             h1 = self._activation(xw[2] + r * hu[2])  # candidate state
   | 66 | +             next_h = h1 * (1 - z) + state * z  # interpolate candidate and previous state
   | 67 | +             return next_h, next_h
   | 68 | +
29 | 69 |
30 | 70 | class CustomRnnCellTests(Tf2OnnxBackendTestBase):
31 | 71 |     @check_opset_min_version(8, "Scan")

@@ -370,45 +410,5 @@ def func(encoder_x, decoder_x, seq_length):
370 | 410 |         self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, 0.1)
371 | 411 |
372 | 412 |
373 |     | - class GatedGRUCell(RNNCell):
374 |     | -     def __init__(self, hidden_dim, reuse=None):
375 |     | -         super().__init__(self, _reuse=reuse)
376 |     | -         self._num_units = hidden_dim
377 |     | -         self._activation = tf.tanh
378 |     | -
379 |     | -     @property
380 |     | -     def state_size(self):
381 |     | -         return self._num_units
382 |     | -
383 |     | -     @property
384 |     | -     def output_size(self):
385 |     | -         return self._num_units
386 |     | -
387 |     | -     def call(self, inputs, state):
388 |     | -         # inputs shape: [batch size, time step, input size] = [1, 3, 2]
389 |     | -         # num_units: 5
390 |     | -         # W shape: [2, 3 * 5] = [2, 15]
391 |     | -         # U shape: [5, 3 * 5] = [5, 15]
392 |     | -         # b shape: [1, 3 * 5] = [1, 15]
393 |     | -         # state shape: [batch size, state size] = [1, 5]
394 |     | -
395 |     | -         input_dim = inputs.get_shape()[-1]
396 |     | -         assert input_dim is not None, "input dimension must be defined"
397 |     | -         # W = tf.get_variable(name="W", shape=[input_dim, 3 * self._num_units], dtype=tf.float32)
398 |     | -         W = np.arange(30.0, dtype=np.float32).reshape((2, 15))
399 |     | -         # U = tf.get_variable(name='U', shape=[self._num_units, 3 * self._num_units], dtype=tf.float32)
400 |     | -         U = np.arange(75.0, dtype=np.float32).reshape((5, 15))
401 |     | -         # b = tf.get_variable(name='b', shape=[1, 3 * self._num_units], dtype=tf.float32)
402 |     | -         b = np.arange(15.0, dtype=np.float32).reshape((1, 15))
403 |     | -
404 |     | -         xw = tf.split(tf.matmul(inputs, W) + b, 3, 1)
405 |     | -         hu = tf.split(tf.matmul(state, U), 3, 1)
406 |     | -         r = tf.sigmoid(xw[0] + hu[0])
407 |     | -         z = tf.sigmoid(xw[1] + hu[1])
408 |     | -         h1 = self._activation(xw[2] + r * hu[2])
409 |     | -         next_h = h1 * (1 - z) + state * z
410 |     | -         return next_h, next_h
411 |     | -
412 |     | -
413 | 413 | if __name__ == '__main__':
414 | 414 |     unittest_main()
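
For context, the relocated `GatedGRUCell` is a plain TF1 `RNNCell` whose weights are hard-coded numpy constants, so it can be driven directly through `tf.nn.dynamic_rnn`. The snippet below is a minimal sketch of that usage, assuming the TF1 graph-mode API (`tf.placeholder` / `tf.Session`) and the `[1, 3, 2]` input shape from the comments in `call()`; it is illustrative only, not the repo's actual test harness.

```python
import numpy as np
import tensorflow as tf  # assumes TF1.x graph mode (the cell lives in the non-tf2 branch)

# Input shaped like the comment in call(): [batch=1, time=3, input_size=2]
x_val = np.array([[[1., 1.], [2., 2.], [3., 3.]]], dtype=np.float32)

x = tf.placeholder(tf.float32, shape=[1, 3, 2])
cell = GatedGRUCell(5)  # 5 hidden units, matching the hard-coded W/U/b shapes
outputs, state = tf.nn.dynamic_rnn(cell, x, dtype=tf.float32)

with tf.Session() as sess:
    # no variable initializer needed: the cell uses numpy constants, not tf.Variables
    out_val, state_val = sess.run([outputs, state], feed_dict={x: x_val})
    # out_val: [1, 3, 5] per-step hidden states; state_val: [1, 5] final state
```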