9
9
import unittest
10
10
from collections import defaultdict
11
11
12
- from distutils .version import LooseVersion
12
+ from packaging .version import Version
13
13
from parameterized import parameterized
14
14
import numpy as np
15
15
import tensorflow as tf
@@ -98,7 +98,7 @@ def _get_backend_version(self):
98
98
pass
99
99
100
100
if version :
101
- version = LooseVersion (version )
101
+ version = Version (version )
102
102
return version
103
103
104
104
def __str__ (self ):
@@ -178,7 +178,7 @@ def check_opset_after_tf_version(tf_version, required_opset, message=""):
178
178
""" Skip if tf_version > max_required_version """
179
179
config = get_test_config ()
180
180
reason = _append_message ("conversion requires opset {} after tf {}" .format (required_opset , tf_version ), message )
181
- skip = config .tf_version >= LooseVersion (tf_version ) and config .opset < required_opset
181
+ skip = config .tf_version >= Version (tf_version ) and config .opset < required_opset
182
182
return unittest .skipIf (skip , reason )
183
183
184
184
@@ -284,7 +284,7 @@ def check_tfjs_max_version(max_accepted_version, message=""):
284
284
except ModuleNotFoundError :
285
285
can_import = False
286
286
return unittest .skipIf (can_import and not config .skip_tfjs_tests and \
287
- tensorflowjs .__version__ > LooseVersion (max_accepted_version ), reason )
287
+ Version ( tensorflowjs .__version__ ) > Version (max_accepted_version ), reason )
288
288
289
289
def check_tfjs_min_version (min_required_version , message = "" ):
290
290
""" Skip if tjs_version < min_required_version """
@@ -296,20 +296,20 @@ def check_tfjs_min_version(min_required_version, message=""):
296
296
except ModuleNotFoundError :
297
297
can_import = False
298
298
return unittest .skipIf (can_import and not config .skip_tfjs_tests and \
299
- tensorflowjs .__version__ < LooseVersion (min_required_version ), reason )
299
+ Version ( tensorflowjs .__version__ ) < Version (min_required_version ), reason )
300
300
301
301
def check_tf_max_version (max_accepted_version , message = "" ):
302
302
""" Skip if tf_version > max_required_version """
303
303
config = get_test_config ()
304
304
reason = _append_message ("conversion requires tf <= {}" .format (max_accepted_version ), message )
305
- return unittest .skipIf (config .tf_version > LooseVersion (max_accepted_version ), reason )
305
+ return unittest .skipIf (config .tf_version > Version (max_accepted_version ), reason )
306
306
307
307
308
308
def check_tf_min_version (min_required_version , message = "" ):
309
309
""" Skip if tf_version < min_required_version """
310
310
config = get_test_config ()
311
311
reason = _append_message ("conversion requires tf >= {}" .format (min_required_version ), message )
312
- return unittest .skipIf (config .tf_version < LooseVersion (min_required_version ), reason )
312
+ return unittest .skipIf (config .tf_version < Version (min_required_version ), reason )
313
313
314
314
315
315
def skip_tf_versions (excluded_versions , message = "" ):
@@ -385,7 +385,7 @@ def check_onnxruntime_min_version(min_required_version, message=""):
385
385
config = get_test_config ()
386
386
reason = _append_message ("conversion requires onnxruntime >= {}" .format (min_required_version ), message )
387
387
return unittest .skipIf (config .is_onnxruntime_backend and
388
- config .backend_version < LooseVersion (min_required_version ), reason )
388
+ config .backend_version < Version (min_required_version ), reason )
389
389
390
390
391
391
def skip_caffe2_backend (message = "" ):
0 commit comments