Skip to content

Commit 76924df

Browse files
committed
Use packaging library to avoid DeprecationWarning from distutils
Signed-off-by: Deyu Huang <[email protected]>
1 parent 3bd3081 commit 76924df

File tree

7 files changed

+26
-26
lines changed

7 files changed

+26
-26
lines changed

tests/common.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import unittest
1010
from collections import defaultdict
1111

12-
from distutils.version import LooseVersion
12+
from packaging.version import Version
1313
from parameterized import parameterized
1414
import numpy as np
1515
import tensorflow as tf
@@ -98,7 +98,7 @@ def _get_backend_version(self):
9898
pass
9999

100100
if version:
101-
version = LooseVersion(version)
101+
version = Version(version)
102102
return version
103103

104104
def __str__(self):
@@ -178,7 +178,7 @@ def check_opset_after_tf_version(tf_version, required_opset, message=""):
178178
""" Skip if tf_version > max_required_version """
179179
config = get_test_config()
180180
reason = _append_message("conversion requires opset {} after tf {}".format(required_opset, tf_version), message)
181-
skip = config.tf_version >= LooseVersion(tf_version) and config.opset < required_opset
181+
skip = config.tf_version >= Version(tf_version) and config.opset < required_opset
182182
return unittest.skipIf(skip, reason)
183183

184184

@@ -284,7 +284,7 @@ def check_tfjs_max_version(max_accepted_version, message=""):
284284
except ModuleNotFoundError:
285285
can_import = False
286286
return unittest.skipIf(can_import and not config.skip_tfjs_tests and \
287-
tensorflowjs.__version__ > LooseVersion(max_accepted_version), reason)
287+
Version(tensorflowjs.__version__) > Version(max_accepted_version), reason)
288288

289289
def check_tfjs_min_version(min_required_version, message=""):
290290
""" Skip if tjs_version < min_required_version """
@@ -296,20 +296,20 @@ def check_tfjs_min_version(min_required_version, message=""):
296296
except ModuleNotFoundError:
297297
can_import = False
298298
return unittest.skipIf(can_import and not config.skip_tfjs_tests and \
299-
tensorflowjs.__version__ < LooseVersion(min_required_version), reason)
299+
Version(tensorflowjs.__version__) < Version(min_required_version), reason)
300300

301301
def check_tf_max_version(max_accepted_version, message=""):
302302
""" Skip if tf_version > max_required_version """
303303
config = get_test_config()
304304
reason = _append_message("conversion requires tf <= {}".format(max_accepted_version), message)
305-
return unittest.skipIf(config.tf_version > LooseVersion(max_accepted_version), reason)
305+
return unittest.skipIf(config.tf_version > Version(max_accepted_version), reason)
306306

307307

308308
def check_tf_min_version(min_required_version, message=""):
309309
""" Skip if tf_version < min_required_version """
310310
config = get_test_config()
311311
reason = _append_message("conversion requires tf >= {}".format(min_required_version), message)
312-
return unittest.skipIf(config.tf_version < LooseVersion(min_required_version), reason)
312+
return unittest.skipIf(config.tf_version < Version(min_required_version), reason)
313313

314314

315315
def skip_tf_versions(excluded_versions, message=""):
@@ -385,7 +385,7 @@ def check_onnxruntime_min_version(min_required_version, message=""):
385385
config = get_test_config()
386386
reason = _append_message("conversion requires onnxruntime >= {}".format(min_required_version), message)
387387
return unittest.skipIf(config.is_onnxruntime_backend and
388-
config.backend_version < LooseVersion(min_required_version), reason)
388+
config.backend_version < Version(min_required_version), reason)
389389

390390

391391
def skip_caffe2_backend(message=""):

tests/run_pretrained_models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import zipfile
1818
import random
1919
from collections import namedtuple
20-
from distutils.version import LooseVersion
20+
from packaging.version import Version
2121

2222

2323
import yaml
@@ -789,7 +789,7 @@ def main():
789789
continue
790790

791791
if t.tf_min_version:
792-
if tf_utils.get_tf_version() < LooseVersion(str(t.tf_min_version)):
792+
if tf_utils.get_tf_version() < Version(str(t.tf_min_version)):
793793
logger.info("Skip %s: %s %s", test, "Min TF version needed:", t.tf_min_version)
794794
continue
795795

tests/test_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55

66
import os
77
import unittest
8-
from packaging.version import Version
98
from itertools import product
109

1110
import numpy as np
1211
from numpy.testing import assert_almost_equal
12+
from packaging.version import Version
1313
import tensorflow as tf
1414

1515
from tensorflow.python.ops import lookup_ops

tf2onnx/convert.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import argparse
1111
import os
1212
import sys
13-
from distutils.version import LooseVersion
13+
from packaging.version import Version
1414

1515
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "3"
1616

@@ -20,7 +20,7 @@
2020
from tf2onnx import constants, logging, utils, optimizer
2121
from tf2onnx import tf_loader
2222
from tf2onnx.graph import ExternalTensorStorage
23-
from tf2onnx.tf_utils import compress_graph_def
23+
from tf2onnx.tf_utils import compress_graph_def, get_tf_version
2424

2525

2626

@@ -431,7 +431,7 @@ def from_keras(model, input_signature=None, opset=None, custom_ops=None, custom_
431431
Returns:
432432
An ONNX model_proto and an external_tensor_storage dict.
433433
"""
434-
if LooseVersion(tf.__version__) < "2.0":
434+
if get_tf_version() < Version("2.0"):
435435
return _from_keras_tf1(model, opset, custom_ops, custom_op_handlers, custom_rewriter, inputs_as_nchw,
436436
outputs_as_nchw, extra_opset, shape_override, target, large_model, output_path)
437437

@@ -540,7 +540,7 @@ def from_function(function, input_signature=None, opset=None, custom_ops=None, c
540540
Returns:
541541
An ONNX model_proto and an external_tensor_storage dict.
542542
"""
543-
if LooseVersion(tf.__version__) < "2.0":
543+
if get_tf_version() < Version("2.0"):
544544
raise NotImplementedError("from_function requires tf-2.0 or newer")
545545

546546
if input_signature is None:

tf2onnx/shape_inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
"""
77

88
import logging
9-
from distutils.version import LooseVersion
109
from collections import defaultdict
1110
import numpy as np
11+
from packaging.version import Version
1212
from tf2onnx import utils
1313
from tf2onnx.tf_utils import get_tf_tensor_shape, get_tf_const_value, get_tf_shape_attr, get_tf_version
1414
from tf2onnx.tf_loader import tf_reload_graph
@@ -32,7 +32,7 @@ def infer_shape(tf_graph, shape_override):
3232

3333
op_outputs_with_none_shape = check_shape_for_tf_graph(tf_graph)
3434
if op_outputs_with_none_shape:
35-
if get_tf_version() > LooseVersion("1.5.0"):
35+
if get_tf_version() > Version("1.5.0"):
3636
for op, outs in op_outputs_with_none_shape.items():
3737
logger.warning(
3838
"Cannot infer shape for %s: %s",

tf2onnx/tf_loader.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import logging
77
import uuid
8-
from distutils.version import LooseVersion
8+
from packaging.version import Version
99

1010
import tensorflow as tf
1111
import numpy as np
@@ -75,7 +75,7 @@ def not_implemented_tf_placeholder(*args, **kwargs):
7575
tf_placeholder = tf.compat.v1.placeholder
7676
tf_placeholder_with_default = tf.compat.v1.placeholder_with_default
7777
extract_sub_graph = tf.compat.v1.graph_util.extract_sub_graph
78-
elif LooseVersion(tf.__version__) >= "1.13":
78+
elif Version(tf.__version__) >= Version("1.13"):
7979
# 1.13 introduced the compat namespace
8080
tf_reset_default_graph = tf.compat.v1.reset_default_graph
8181
tf_global_variables = tf.compat.v1.global_variables
@@ -162,7 +162,7 @@ def make_tensor_proto_wrapped(values, dtype=None, shape=None, verify_shape=False
162162

163163
try:
164164
function_converter = _FunctionConverterData
165-
if LooseVersion(tf.__version__) >= "2.6.0":
165+
if Version(tf.__version__) >= Version("2.6.0"):
166166
from tensorflow.python.eager import context
167167
from tensorflow.python.framework.convert_to_constants import _FunctionConverterDataInEager, \
168168
_FunctionConverterDataInGraph
@@ -267,7 +267,7 @@ def from_function(func, input_names, output_names, large_model=False):
267267
return convert_variables_to_constants_large_model(func)
268268

269269
try:
270-
if get_tf_version() < LooseVersion("2.2"):
270+
if get_tf_version() < Version("2.2"):
271271
frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False)
272272
else:
273273
frozen_func = convert_variables_to_constants_v2(func, lower_control_flow=False, aggressive_inlining=True)
@@ -687,7 +687,7 @@ def tf_optimize_grappler(input_names, output_names, graph_def):
687687
'constfold', 'function'
688688
]
689689

690-
if LooseVersion(tf.__version__) >= "2.5":
690+
if Version(tf.__version__) >= Version("2.5"):
691691
# This flag disables folding QDQ nodes around constants in the network (eg: around conv/FC weights)
692692
rewrite_options.experimental_disable_folding_quantization_emulation = True
693693

@@ -710,7 +710,7 @@ def tf_optimize(input_names, output_names, graph_def):
710710
[utils.node_name(i) for i in output_names]
711711
graph_def = extract_sub_graph(graph_def, needed_names)
712712

713-
want_grappler = is_tf2() or LooseVersion(tf.__version__) >= "1.15"
713+
want_grappler = is_tf2() or Version(tf.__version__) >= Version("1.15")
714714
if want_grappler:
715715
graph_def = tf_optimize_grappler(input_names, output_names, graph_def)
716716
else:
@@ -730,7 +730,7 @@ def tf_optimize(input_names, output_names, graph_def):
730730
def tf_reload_graph(tf_graph):
731731
"""Invoke tensorflow cpp shape inference by reloading graph_def."""
732732
# invoke c api if tf version is below 1.8
733-
if get_tf_version() < LooseVersion("1.8"):
733+
if get_tf_version() < Version("1.8"):
734734
logger.debug(
735735
"On TF < 1.8, graph is constructed by python API, "
736736
"which doesn't invoke shape inference, please set "

tf2onnx/tf_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
import collections
9-
from distutils.version import LooseVersion
9+
from packaging.version import Version
1010

1111
import numpy as np
1212
import tensorflow as tf
@@ -121,7 +121,7 @@ def get_tf_node_attr(node, name):
121121

122122

123123
def get_tf_version():
124-
return LooseVersion(tf.__version__)
124+
return Version(tf.__version__)
125125

126126
def compress_graph_def(graph_def):
127127
"""

0 commit comments

Comments
 (0)