Skip to content

Commit 06339ac

Browse files
authored
Merge pull request #70 from onnx/gs/py27
changes to make it work on py2.7
2 parents 80e8177 + 549c6b8 commit 06339ac

File tree

13 files changed

+72
-32
lines changed

13 files changed

+72
-32
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def run(self):
7474
version=VersionInfo.version,
7575
description='Tensorflow to ONNX converter',
7676
setup_requires=['pytest-runner'],
77-
tests_require=['pytest', 'pytest-cov', 'psutil', 'graphviz', 'pyyaml'],
77+
tests_require=['requests', 'pytest', 'pytest-cov', 'psutil', 'graphviz', 'pyyaml'],
7878
cmdclass=cmdclass,
7979
packages=find_packages(),
8080

tests/run_pretrained_models.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT license.
33

4+
from __future__ import absolute_import
5+
from __future__ import division
6+
from __future__ import print_function
7+
48
import argparse
59
import os
610
import tarfile
711
import time
812
import tempfile
9-
import urllib
10-
import urllib.request
13+
import requests
1114
import zipfile
1215

1316
import PIL.Image
@@ -136,7 +139,11 @@ def download_file(self):
136139
os.makedirs(dir_name, exist_ok=True)
137140
fpath = os.path.join(dir_name, fname)
138141
if not os.path.exists(fpath):
139-
urllib.request.urlretrieve(url, fpath)
142+
response = requests.get(url)
143+
if response.status_code not in [200]:
144+
response.raise_for_status()
145+
with open(fpath, "wb") as f:
146+
f.write(response.content)
140147
model_path = os.path.join(dir_name, self.local)
141148
if not os.path.exists(model_path):
142149
if ftype == 'tgz':

tests/test_backend.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT license.
33

4+
from __future__ import division
5+
from __future__ import print_function
6+
47
import argparse
58
import os
69
import sys
@@ -325,7 +328,6 @@ def test_conv2d_6(self):
325328
expected, actual = self._conv_test(x_val, kernel_val, strides=strides, padding="VALID")
326329
self.assertAllClose(expected, actual, rtol=1e-05)
327330

328-
329331
def test_conv2d_7(self):
330332
x_shape = [1, 35, 35, 288] # out: [1, 17, 17, 384]
331333
kernel_shape = [3, 3, 288, 384]
@@ -567,7 +569,7 @@ def test_logicaland(self):
567569
x2 = tf.placeholder(tf.bool, [2, 2], name=_TFINPUT1)
568570
mi = tf.logical_and(x1, x2)
569571
output = tf.identity(mi, name=_TFOUTPUT)
570-
actual, expected = self._run(output, {x1: x_val1, x2: x_val2}, {_INPUT: x_val1, _INPUT1: x_val2,})
572+
actual, expected = self._run(output, {x1: x_val1, x2: x_val2}, {_INPUT: x_val1, _INPUT1: x_val2})
571573
self.assertAllClose(expected, actual)
572574

573575
def test_greater(self):
@@ -577,7 +579,7 @@ def test_greater(self):
577579
x2 = tf.placeholder(tf.float32, [2, 2], name=_TFINPUT1)
578580
mi = tf.greater(x1, x2)
579581
output = tf.identity(mi, name=_TFOUTPUT)
580-
actual, expected = self._run(output, {x1: x_val1, x2: x_val2}, {_INPUT: x_val1, _INPUT1: x_val2,})
582+
actual, expected = self._run(output, {x1: x_val1, x2: x_val2}, {_INPUT: x_val1, _INPUT1: x_val2})
581583
self.assertAllClose(expected, actual)
582584

583585
def test_sequeeze(self):
@@ -780,9 +782,7 @@ def test_reducemean(self):
780782
@unittest.skip
781783
def test_slice1(self):
782784
# FIXME: only 1 dimension supported by caffe2 and msrt
783-
x_val = np.array([[[1, 1, 1], [2, 2, 2]],
784-
[[3, 3, 3], [4, 4, 4]],
785-
[[5, 5, 5], [6, 6, 6]]], dtype=np.float32)
785+
x_val = np.array([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]], dtype=np.float32)
786786
t1 = tf.constant([1, 0, 0], dtype=tf.int32)
787787
t2 = tf.constant([1, 1, 3], dtype=tf.int32)
788788
x0 = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
@@ -920,7 +920,7 @@ def test_topk(self):
920920
x_val = np.arange(3*2*3).astype("float32")
921921
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
922922
values, indices = tf.nn.top_k(x, 5, sorted=True)
923-
output = tf.identity(values, name=_TFOUTPUT)
923+
output = tf.identity(values, name=_TFOUTPUT)
924924
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
925925
self.assertAllClose(expected, actual)
926926

@@ -949,7 +949,7 @@ def test_space_to_depth(self):
949949
x_val = make_xval([1, 2, 2, 1])
950950
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
951951
x_ = tf.space_to_depth(x, block_size=2)
952-
output = tf.identity(x_, name=_TFOUTPUT)
952+
output = tf.identity(x_, name=_TFOUTPUT)
953953
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
954954
self.assertAllClose(expected, actual)
955955

@@ -958,7 +958,7 @@ def test_addn(self):
958958
x_val = np.arange(3*2*3).astype("float32")
959959
x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
960960
x_ = tf.add_n([x, x, x])
961-
output = tf.identity(x_, name=_TFOUTPUT)
961+
output = tf.identity(x_, name=_TFOUTPUT)
962962
actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
963963
self.assertAllClose(expected, actual)
964964

@@ -1020,7 +1020,7 @@ def test_batchnorm(self):
10201020
x_shape = [1, 28, 28, 2]
10211021
x_dtype = np.float32
10221022
scale_dtype = np.float32
1023-
scale_shape = [2]
1023+
scale_shape = [2]
10241024
# only nhwc is support on cpu for tensorflow
10251025
data_format = "NHWC"
10261026
x_val = np.random.random_sample(x_shape).astype(x_dtype)

tests/test_graph.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT license.
33

4+
from __future__ import division
5+
from __future__ import print_function
6+
47
import unittest
58
from collections import namedtuple
69

@@ -116,7 +119,7 @@ def test_add(self):
116119
_ = tf.identity(x_, name="output")
117120
g = process_tf_graph(sess.graph)
118121
self.assertEqual(
119-
'digraph { Add [op_type=Add] output [op_type=Identity] input1:0 -> Add input2:0 -> ' \
122+
'digraph { Add [op_type=Add] output [op_type=Identity] input1:0 -> Add input2:0 -> '
120123
'Add Add:0 -> output }',
121124
onnx_to_graphviz(g))
122125

@@ -200,9 +203,9 @@ def test_conv2d(self):
200203

201204
g = process_tf_graph(sess.graph)
202205
self.assertEqual(
203-
'digraph { Conv2D__2 [op_type=Transpose] kernel [op_type=Reshape] Conv2D__3 [op_type=Transpose] ' \
204-
'Conv2D [op_type=Conv] Conv2D__4 [op_type=Transpose] output [op_type=Identity] input1:0 -> ' \
205-
'Conv2D__2 k:0 -> kernel "kernel/shape":0 -> kernel kernel:0 -> Conv2D__3 Conv2D__2:0 -> Conv2D ' \
206+
'digraph { Conv2D__2 [op_type=Transpose] kernel [op_type=Reshape] Conv2D__3 [op_type=Transpose] '
207+
'Conv2D [op_type=Conv] Conv2D__4 [op_type=Transpose] output [op_type=Identity] input1:0 -> '
208+
'Conv2D__2 k:0 -> kernel "kernel/shape":0 -> kernel kernel:0 -> Conv2D__3 Conv2D__2:0 -> Conv2D '
206209
'Conv2D__3:0 -> Conv2D Conv2D:0 -> Conv2D__4 Conv2D__4:0 -> output }',
207210
onnx_to_graphviz(g))
208211

@@ -234,7 +237,7 @@ def test_reshape(self):
234237
_ = tf.identity(x_, name="output")
235238
g = process_tf_graph(sess.graph)
236239
self.assertEqual(
237-
'digraph { Reshape [op_type=Reshape] output [op_type=Identity] input1:0 -> Reshape ' \
240+
'digraph { Reshape [op_type=Reshape] output [op_type=Identity] input1:0 -> Reshape '
238241
'"Reshape/shape":0 -> Reshape Reshape:0 -> output }',
239242
onnx_to_graphviz(g))
240243

@@ -257,7 +260,7 @@ def rewrite_test(g, ops):
257260
_ = tf.identity(x_, name="output")
258261
g = process_tf_graph(sess.graph, custom_rewriter=[rewrite_test])
259262
self.assertEqual(
260-
'digraph { Add [op_type=Mul] output [op_type=Identity] input1:0 -> ' \
263+
'digraph { Add [op_type=Mul] output [op_type=Identity] input1:0 -> '
261264
'Add input1:0 -> Add Add:0 -> output }',
262265
onnx_to_graphviz(g))
263266

tests/test_internals.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT license.
33

4+
from __future__ import division
5+
from __future__ import print_function
6+
47
import unittest
58
from collections import namedtuple
69

@@ -136,7 +139,6 @@ def test_rewrite_subgraph(self):
136139
'n3:0 -> ReplacedOp__2 ReplacedOp__2:0 -> n6 }'
137140
self.assertEqual(expected, result)
138141

139-
140142
def test_match_flipped(self):
141143
n1 = helper.make_node("Sub", ["i1", "i1"], ["n1:0"], name="n1")
142144
n2 = helper.make_node("Add", ["i2", "i2"], ["n2:0"], name="n2")

tf2onnx/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT license.
33

4-
from __future__ import absolute_import
54
from __future__ import division
65
from __future__ import print_function
76
from __future__ import unicode_literals
87

98
from .version import version as __version__
109

11-
__all__ = ["utils", "graph_matcher", "graph", "tf2onnx"]
10+
__all__ = ["utils", "graph_matcher", "graph", "tfonnx"]
11+
12+
import tf2onnx
1213
from tf2onnx import tfonnx, utils, graph, graph_matcher

tf2onnx/convert.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
python -m tf2onnx.convert : tool to convert a frozen tensorflow to onnx
66
"""
77

8+
from __future__ import division
9+
from __future__ import print_function
10+
811
import argparse
912
import sys
1013

tf2onnx/graph.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,13 @@
44
"""
55
tf2onnx.graph - class to manage graph manipulation on top of onnx
66
"""
7+
8+
from __future__ import division
9+
from __future__ import print_function
10+
11+
import tf2onnx
712
from onnx import numpy_helper, optimizer, ModelProto, defs, OperatorSetIdProto
8-
from tf2onnx import utils, tfonnx, __version__
13+
from tf2onnx import utils, __version__
914
from tf2onnx.utils import *
1015

1116

@@ -31,7 +36,7 @@ def __init__(self, node, graph):
3136
for a in node.attribute:
3237
self._attr[a.name] = a
3338
# try to find a dtype for this node
34-
dtype = graph._dtypes.get(node.name)
39+
dtype = graph.get_dtype(node.name)
3540
if not dtype:
3641
dtype = self._attr.get("dtype")
3742
if dtype:
@@ -230,7 +235,7 @@ def __init__(self, nodes, output_shapes=None, dtypes=None, target=None, opset=No
230235
dtypes: dict of tensorflow dtype
231236
"""
232237
if target is None:
233-
target = tfonnx.DEFAULT_TARGET
238+
target = tf2onnx.tfonnx.DEFAULT_TARGET
234239
self._nodes = []
235240
self._initializers = {}
236241
self._nodes_by_name = {}

tf2onnx/tfonnx.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
"""
55
tf2onnx.tf2onnx - rewrite tensorflow graph to onnx graph
66
"""
7+
8+
from __future__ import division
9+
from __future__ import print_function
10+
711
import collections
812
import logging
913
import sys
@@ -87,7 +91,7 @@ def tensorflow_to_onnx(graph):
8791
onnx_tensor = utils.tf_to_onnx_tensor(node.get_attr(a), name=node.name + ":0")
8892
attr[a] = onnx_tensor
8993
elif a == "DstT":
90-
attr["to"] = utils.map_tf_dtype(node.get_attr("DstT"))
94+
attr["to"] = utils.map_tf_dtype(node.get_attr("DstT"))
9195
elif a == "SrcT":
9296
continue
9397
elif a in ignored_attr:
@@ -127,6 +131,7 @@ def _convert_shapenode_to_int64(ctx, node, input_number):
127131

128132
# pylint: disable=W0613,C0111,W0612
129133

134+
130135
def no_op(ctx, node, name, args):
131136
"""Skip node."""
132137
return None
@@ -304,12 +309,14 @@ def reshape_op5(ctx, node, name, args):
304309
HWCN_TO_NCHW = [3, 2, 0, 1]
305310
NCHW_TO_HWCN = [2, 3, 1, 0]
306311

312+
307313
def spatial_map(shape, perm):
308314
new_shape = shape[:]
309315
for i in perm:
310316
new_shape[i] = shape[perm[i]]
311317
return new_shape
312318

319+
313320
def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
314321
input_indices=None, output_indices=None):
315322
"""Convert input and kernel from tensorflow to onnx. This maybe require to
@@ -988,6 +995,7 @@ def onehot_op(ctx, node, name, args):
988995
return [node, transpose_op]
989996
return node
990997

998+
991999
def fused_batchnorm_op7(ctx, node, name, args):
9921000
node.type = "BatchNormalization"
9931001
# tf inputs: x, scale, bias, mean, variance
@@ -999,7 +1007,6 @@ def fused_batchnorm_op7(ctx, node, name, args):
9991007
return nodes
10001008

10011009

1002-
10031010
# pylint: enable=W0613,C0111,W0612
10041011

10051012
# map tensorflow ops to onnx ops. The format below is
@@ -1132,6 +1139,7 @@ def fused_batchnorm_op7(ctx, node, name, args):
11321139
(7, _OPSET_7),
11331140
]
11341141

1142+
11351143
def rewrite_random_uniform(g, ops):
11361144
pattern = \
11371145
OpTypePattern('Add', name='output', inputs=[
@@ -1315,7 +1323,7 @@ def tensorflow_onnx_mapping(g, continue_on_error, custom_op_handlers):
13151323
def tf_optimize(sess, inputs, outputs, graph_def):
13161324
"""Optimize tensorflow graph for inference."""
13171325
transforms = [
1318-
#"fold_constants(ignore_errors=true)",
1326+
# "fold_constants(ignore_errors=true)",
13191327
"fold_batch_norms",
13201328
"fold_old_batch_norms",
13211329
]

tf2onnx/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
tf2onnx.utils - misc utilities for tf2onnx
66
"""
77

8+
from __future__ import division
9+
from __future__ import print_function
10+
811
import numpy as np
912
import tensorflow as tf
1013
from onnx import helper, onnx_pb
@@ -149,6 +152,7 @@ def get_shape(node):
149152
pass
150153
return dims
151154

155+
152156
def map_tf_dtype(dtype):
153157
if dtype:
154158
dtype = TF_TO_ONNX_DTYPE[dtype]
@@ -157,8 +161,6 @@ def map_tf_dtype(dtype):
157161

158162
def node_name(name):
159163
"""Get node name without io#."""
160-
# FIXME: do we use this ?
161-
assert isinstance(name, str)
162164
pos = name.find(":")
163165
if pos >= 0:
164166
return name[:pos]

0 commit comments

Comments
 (0)