Skip to content

Commit 42d7222

Browse files
authored
Fix keras2onnx_unit_test nightly job opset version failures. (#2032)
* Update tests to fix the issue of keras2onnx_unit_test nightly job failures. Signed-off-by: Jay Zhang <[email protected]>
1 parent bc677a1 commit 42d7222

File tree

3 files changed

+33
-5
lines changed

3 files changed

+33
-5
lines changed

ci_build/azure_pipelines/keras2onnx_unit_test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ jobs:
6868
ONNX_PATH: -i onnx==1.12.0
6969
KERAS: keras==2.9.0
7070
TENSORFLOW_PATH: tensorflow==2.9.0
71-
INSTALL_ORT: pip install onnxruntime==1.11.0
71+
INSTALL_ORT: pip install onnxruntime==1.12.0
7272
INSTALL_NUMPY: pip install numpy==1.23.0
7373

7474
steps:

tests/keras2onnx_unit_tests/test_layers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from mock_keras2onnx.proto import (keras, is_tf_keras,
88
is_tensorflow_older_than, is_tensorflow_later_than,
99
is_keras_older_than, is_keras_later_than, python_keras_is_deprecated)
10-
from test_utils import no_loops_in_tf2, all_recurrents_should_bidirectional
10+
from test_utils import no_loops_in_tf2, all_recurrents_should_bidirectional, convert_keras_for_test as convert_keras
1111

1212
K = keras.backend
1313
Activation = keras.layers.Activation
@@ -88,9 +88,6 @@
8888
def _asarray(*a):
8989
return np.array([a], dtype='f')
9090

91-
########################
92-
from tf2onnx.keras2onnx_api import convert_keras
93-
##########################
9491

9592
def test_keras_lambda(runner):
9693
model = Sequential()

tests/keras2onnx_unit_tests/test_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,18 @@
88
import mock_keras2onnx
99
from mock_keras2onnx.proto import keras, is_keras_older_than
1010
from mock_keras2onnx.proto.tfcompat import is_tf2
11+
from packaging.version import Version
12+
from tf2onnx.keras2onnx_api import convert_keras
1113
import time
1214
import json
1315
import urllib
1416

17+
18+
# Mapping opset to ONNXRuntime version.
19+
ORT_OPSET_VERSION = {
20+
"1.6.0": 13, "1.7.0": 13, "1.8.0": 14, "1.9.0": 15, "1.10.0": 15, "1.11.0": 16, "1.12.0": 17
21+
}
22+
1523
working_path = os.path.abspath(os.path.dirname(__file__))
1624
tmp_path = os.path.join(working_path, 'temp')
1725
test_level_0 = True
@@ -299,3 +307,26 @@ def is_bloburl_access(url):
299307
return response.getcode() == 200
300308
except urllib.error.URLError:
301309
return False
310+
311+
312+
def get_max_opset_supported_by_ort():
313+
try:
314+
import onnxruntime as ort
315+
ort_ver = Version(ort.__version__).base_version
316+
317+
if ort_ver in ORT_OPSET_VERSION.keys():
318+
return ORT_OPSET_VERSION[ort_ver]
319+
else:
320+
print("Given onnxruntime version doesn't exist in ORT_OPSET_VERSION: {}".format(ort_ver))
321+
return None
322+
except ImportError:
323+
return None
324+
325+
326+
def convert_keras_for_test(model, name=None, target_opset=None, **kwargs):
327+
if target_opset is None:
328+
target_opset = get_max_opset_supported_by_ort()
329+
330+
print("Trying to run test with opset version: {}".format(target_opset))
331+
332+
return convert_keras(model=model, name=name, target_opset=target_opset, **kwargs)

0 commit comments

Comments
 (0)