|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +"""Unit Tests for cudnn.""" |
| 5 | + |
| 6 | +from __future__ import absolute_import |
| 7 | +from __future__ import division |
| 8 | +from __future__ import print_function |
| 9 | +from __future__ import unicode_literals |
| 10 | + |
| 11 | +import numpy as np |
| 12 | +import tensorflow as tf |
| 13 | + |
| 14 | +from tensorflow.python.ops import init_ops |
| 15 | +from tensorflow.python.ops import variable_scope |
| 16 | +from backend_test_base import Tf2OnnxBackendTestBase |
| 17 | +from common import * |
| 18 | +from tf2onnx.tf_loader import is_tf2 |
| 19 | + |
| 20 | + |
| 21 | +class CudnnTests(Tf2OnnxBackendTestBase): |
| 22 | + @skip_tf2() |
| 23 | + @skip_tf_cpu("only tf_gpu can run CudnnGPU") |
| 24 | + @check_opset_min_version(11, "CudnnGRU") |
| 25 | + def test_cudnngru(self): |
| 26 | + seq_length = 3 |
| 27 | + batch_size = 5 |
| 28 | + input_size = 2 |
| 29 | + num_layers = 2 |
| 30 | + num_units = 2 |
| 31 | + num_dirs = 2 |
| 32 | + x_val = np.random.randint(0, 100, [seq_length, batch_size, input_size]).astype(np.float32) |
| 33 | + h_val = np.random.randint(0, 100, [num_layers * num_dirs, batch_size, num_units]).astype(np.float32).reshape( |
| 34 | + [num_layers * num_dirs, batch_size, num_units]) |
| 35 | + |
| 36 | + def func(x, h): |
| 37 | + initializer = init_ops.constant_initializer(0.5) |
| 38 | + cudnngru = tf.contrib.cudnn_rnn.CudnnGRU(num_layers, num_units, 'linear_input', 'bidirectional', |
| 39 | + kernel_initializer=initializer, bias_initializer=initializer) |
| 40 | + cudnngru.build([seq_length, batch_size, input_size]) |
| 41 | + outputs = cudnngru.call(x, tuple([h])) |
| 42 | + _ = tf.identity(outputs[0], name='output') |
| 43 | + |
| 44 | + feed_dict = {"input_1:0": x_val, "input_2:0": h_val} |
| 45 | + input_names_with_port = ["input_1:0", "input_2:0"] |
| 46 | + output_names_with_port = ["output:0"] |
| 47 | + self.run_test_case(func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-05, atol=1e-04) |
| 48 | + |
| 49 | + |
| 50 | +if __name__ == '__main__': |
| 51 | + unittest_main() |
0 commit comments