16
16
import tensorflow as tf
17
17
18
18
from tensorflow .python .ops import lookup_ops
19
+ from tensorflow .python .ops import init_ops
19
20
from backend_test_base import Tf2OnnxBackendTestBase
20
21
# pylint reports unused-wildcard-import which is false positive, __all__ is defined in common
21
22
from common import * # pylint: disable=wildcard-import,unused-wildcard-import
22
23
from tf2onnx import constants , utils
23
24
from tf2onnx .graph_matcher import OpTypePattern , GraphMatcher
24
25
from tf2onnx .tf_loader import is_tf2
25
- from tensorflow .python .ops import init_ops
26
26
27
27
# pylint: disable=missing-docstring,invalid-name,unused-argument,function-redefined,cell-var-from-loop
28
28
@@ -2919,24 +2919,6 @@ def func(query_holder):
2919
2919
self ._run_test_case (func , [_OUTPUT ], {_INPUT : query }, constant_fold = False )
2920
2920
os .remove (filnm )
2921
2921
2922
- @check_opset_min_version (11 , "GRU" )
2923
- def test_cudnngru (self ):
2924
- seq_length = 3
2925
- batch_size = 5
2926
- input_size = 2
2927
- num_layers = 2
2928
- num_units = 2
2929
- num_dirs = 2
2930
- initializer = init_ops .constant_initializer (0.5 )
2931
- x = np .random .randint (0 , 100 , [seq_length , batch_size , input_size ]).astype (np .float32 )
2932
- h = np .random .randint (0 , 100 , [num_layers * num_dirs , batch_size , num_units ]).astype (np .float32 ).reshape (
2933
- [num_layers * num_dirs , batch_size , num_units ])
2934
- cudnngru = tf .contrib .cudnn_rnn .CudnnGRU (num_layers , num_units , 'linear_input' , 'bidirectional' ,
2935
- kernel_initializer = initializer , bias_initializer = initializer )
2936
- cudnngru .build ([seq_length , batch_size , input_size ])
2937
- outputs = cudnngru .call (x , tuple ([h ]))
2938
- self .run_test_case ({}, [], [outputs [0 ].name ], rtol = 1e-05 , atol = 1e-04 )
2939
-
2940
2922
@check_opset_min_version (11 )
2941
2923
def test_matrix_diag_part (self ):
2942
2924
input_vals = [
@@ -2951,6 +2933,26 @@ def func(input_holder):
2951
2933
for input_val in input_vals :
2952
2934
self ._run_test_case (func , [_OUTPUT ], {_INPUT : input_val })
2953
2935
2936
+ @check_opset_min_version (11 , "GRU" )
2937
+ def test_cudnngru (self ):
2938
+ def func ():
2939
+ seq_length = 3
2940
+ batch_size = 5
2941
+ input_size = 2
2942
+ num_layers = 2
2943
+ num_units = 2
2944
+ num_dirs = 2
2945
+ initializer = init_ops .constant_initializer (0.5 )
2946
+ x = np .random .randint (0 , 100 , [seq_length , batch_size , input_size ]).astype (np .float32 )
2947
+ h = np .random .randint (0 , 100 , [num_layers * num_dirs , batch_size , num_units ]).astype (np .float32 ).reshape (
2948
+ [num_layers * num_dirs , batch_size , num_units ])
2949
+ cudnngru = tf .contrib .cudnn_rnn .CudnnGRU (num_layers , num_units , 'linear_input' , 'bidirectional' ,
2950
+ kernel_initializer = initializer , bias_initializer = initializer )
2951
+ cudnngru .build ([seq_length , batch_size , input_size ])
2952
+ outputs = cudnngru .call (x , tuple ([h ]))
2953
+ _ = tf .identity (outputs [0 ], name = _TFOUTPUT )
2954
+ self .run_test_case (func , {}, [], [_OUTPUT ], rtol = 1e-05 , atol = 1e-04 )
2955
+
2954
2956
2955
2957
if __name__ == '__main__' :
2956
2958
unittest_main ()
0 commit comments