8
8
import mock_keras2onnx
9
9
from mock_keras2onnx .proto import keras , is_keras_older_than
10
10
from mock_keras2onnx .proto .tfcompat import is_tf2
11
+ from packaging .version import Version
12
+ from tf2onnx .keras2onnx_api import convert_keras
11
13
import time
12
14
import json
13
15
import urllib
14
16
17
+
18
+ # Mapping opset to ONNXRuntime version.
19
+ ORT_OPSET_VERSION = {
20
+ "1.6.0" : 13 , "1.7.0" : 13 , "1.8.0" : 14 , "1.9.0" : 15 , "1.10.0" : 15 , "1.11.0" : 16 , "1.12.0" : 17
21
+ }
22
+
15
23
working_path = os .path .abspath (os .path .dirname (__file__ ))
16
24
tmp_path = os .path .join (working_path , 'temp' )
17
25
test_level_0 = True
@@ -299,3 +307,26 @@ def is_bloburl_access(url):
299
307
return response .getcode () == 200
300
308
except urllib .error .URLError :
301
309
return False
310
+
311
+
312
+ def get_max_opset_supported_by_ort ():
313
+ try :
314
+ import onnxruntime as ort
315
+ ort_ver = Version (ort .__version__ ).base_version
316
+
317
+ if ort_ver in ORT_OPSET_VERSION .keys ():
318
+ return ORT_OPSET_VERSION [ort_ver ]
319
+ else :
320
+ print ("Given onnxruntime version doesn't exist in ORT_OPSET_VERSION: {}" .format (ort_ver ))
321
+ return None
322
+ except ImportError :
323
+ return None
324
+
325
+
326
+ def convert_keras_for_test (model , name = None , target_opset = None , ** kwargs ):
327
+ if target_opset is None :
328
+ target_opset = get_max_opset_supported_by_ort ()
329
+
330
+ print ("Trying to run test with opset version: {}" .format (target_opset ))
331
+
332
+ return convert_keras (model = model , name = name , target_opset = target_opset , ** kwargs )
0 commit comments