Skip to content

Commit 92374ac

Browse files
committed
set max tf version to test
1 parent 8895b31 commit 92374ac

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

tests/test_cudnn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313

1414
from tensorflow.python.ops import init_ops
1515
from backend_test_base import Tf2OnnxBackendTestBase
16-
from common import skip_tf2, skip_tf_cpu, check_opset_min_version, unittest_main
16+
from common import check_tf_max_version, skip_tf_cpu, check_opset_min_version, unittest_main
1717

1818

1919
class CudnnTests(Tf2OnnxBackendTestBase):
2020
""" test cudnn cases """
21-
@skip_tf2()
21+
@check_tf_max_version("1.15.0", "not supported in tf-2.0")
2222
@skip_tf_cpu("only tf_gpu can run CudnnGPU")
23-
@check_opset_min_version(11, "CudnnGRU")
23+
@check_opset_min_version(10, "CudnnGRU")
2424
def test_cudnngru(self):
2525
""" test contrib cudnn gru """
2626
seq_length = 3

tf2onnx/onnx_opset/rnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def version_7(cls, ctx, node, **kwargs):
173173
@tf_op("CudnnRNN")
174174
class CudnnRNN:
175175
@classmethod
176-
def version_11(cls, ctx, node, **kwargs):
176+
def version_10(cls, ctx, node, **kwargs):
177177
x = node.input[0]
178178
x_shape = ctx.get_shape(x)
179179
h = node.input[1]

0 commit comments

Comments
 (0)