
Commit 05ee1a7

support, test for fusedbatchnorm
1 parent 5338003 commit 05ee1a7

5 files changed (+111 −23 lines)

tests/run_pretrained_models.py

Lines changed: 2 additions & 1 deletion
@@ -166,7 +166,7 @@ def run_tensorflow(self, sess, inputs):
     @staticmethod
     def to_onnx(tf_graph, opset=None):
         """Convert graph to onnx."""
-        return process_tf_graph(tf_graph, opset=opset)
+        return process_tf_graph(tf_graph, continue_on_error=False, opset=opset)

     def run_caffe2(self, name, onnx_graph, inputs):
         """Run test against caffe2 backend."""
@@ -271,6 +271,7 @@ def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=None
         graph_def = graph_pb2.GraphDef()
         with open(model_path, "rb") as f:
             graph_def.ParseFromString(f.read())
+        graph_def = tf2onnx.tfonnx.tf_optimize(None, inputs, self.output_names, graph_def)

         g = tf.import_graph_def(graph_def, name='')
         with tf.Session(graph=g) as sess:
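Taken together, the two hunks above make the pretrained-model runner fail fast (continue_on_error=False) and strip training-only nodes before import. A minimal sketch of that load → optimize → convert flow; the model path, feed dict, and opset value are hypothetical stand-ins for the runner's real arguments:

    import tensorflow as tf
    from tensorflow.core.framework import graph_pb2
    from tf2onnx.tfonnx import process_tf_graph, tf_optimize

    # hypothetical frozen model and tensor names, for illustration only
    model_path = "tests/models/convbn-layers/frozen.pb"
    inputs = {"X:0": None}        # name -> feed value, as the runner builds it
    output_names = ["output:0"]

    graph_def = graph_pb2.GraphDef()
    with open(model_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    # remove training-only nodes before importing, as the new hunk does
    graph_def = tf_optimize(None, inputs, output_names, graph_def)

    g = tf.import_graph_def(graph_def, name='')
    with tf.Session(graph=g) as sess:
        onnx_graph = process_tf_graph(sess.graph, continue_on_error=False, opset=7)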

tests/run_pretrained_models.yaml

Lines changed: 10 additions & 0 deletions
@@ -25,6 +25,16 @@ benchtf-conv:
   outputs:
     - output:0

+benchtf-convbn:
+  disabled: true
+  # fails with: expects to be colocated with unknown node 'batch_normalization_1/gamma'
+  model: tests/models/convbn-layers/frozen.pb
+  input_get: get_ramp
+  inputs:
+    "X:0": [1, 784]
+  outputs:
+    - output:0
+
 benchtf-ae0:
   model: tests/models/ae0/frozen.pb
   input_get: get_ramp
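Each top-level key in this file describes one model test; the new benchtf-convbn entry stays behind disabled: true until the colocation error is resolved. A hedged sketch of how such an entry can be read (the runner's actual loader may differ):

    import yaml

    with open("tests/run_pretrained_models.yaml") as f:
        tests = yaml.safe_load(f)

    entry = tests["benchtf-convbn"]
    if not entry.get("disabled", False):
        # model path, feed shapes, and fetch names drive one test run
        print(entry["model"], entry["inputs"], entry["outputs"])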

tests/test_backend.py

Lines changed: 59 additions & 3 deletions
@@ -95,6 +95,7 @@ def run_onnxmsrtnext(onnx_graph, inputs, output_names, test_name):
     """Run test against msrt-next backend."""
     import lotus
     model_path = os.path.join(TMPPATH, test_name + ".pb")
+    # print(model_path)
     with open(model_path, "wb") as f:
         f.write(onnx_graph.SerializeToString())
     m = lotus.InferenceSession(model_path)
@@ -801,7 +802,6 @@ def test_cast(self):
         self.assertAllClose(expected, actual)

     def test_onehot0(self):
-        # no such op in onnx
         x_val = np.array([0, 1, 2], dtype=np.int32)
         depth = 5
         for axis in [-1, 0, 1]:
@@ -814,7 +814,7 @@ def test_onehot0(self):

     @unittest.skip
     def test_onehot1(self):
-        # no such op in onnx
+        # only rank 1 is currently implemented
         x_val = np.array([[0, 2], [1, -1]], dtype=np.int32)
         depth = 3
         x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
@@ -824,7 +824,6 @@ def test_onehot1(self):
         self.assertAllClose(expected, actual)

     def test_onehot2(self):
-        # no such op in onnx
         x_val = np.array([0, 1, 2, 1, 2, 0, 1, 2, 1, 2], dtype=np.int32)
         depth = 20
         x = tf.placeholder(tf.int32, x_val.shape, name=_TFINPUT)
@@ -924,6 +923,60 @@ def test_strided_slice2(self):
         actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
         self.assertAllClose(expected, actual)

+    @unittest.skip
+    def test_strided_slice3(self):
+        x_val = np.arange(3*2*3).astype("float32").reshape(3, 2, 3)
+        x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
+        x_ = x[1:]
+        output = tf.identity(x_, name=_TFOUTPUT)
+        actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
+        self.assertAllClose(expected, actual)
+
+    @unittest.skip
+    def test_strided_slice4(self):
+        x_val = np.arange(3*2*3).astype("float32").reshape(3, 2, 3)
+        x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
+        x_ = x[:2]
+        output = tf.identity(x_, name=_TFOUTPUT)
+        actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
+        self.assertAllClose(expected, actual)
+
+    @unittest.skip
+    def test_strided_slice5(self):
+        x_val = np.arange(3*2*3).astype("float32").reshape(3, 2, 3)
+        x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
+        x_ = x[:2, 0:1, 1:]
+        output = tf.identity(x_, name=_TFOUTPUT)
+        actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
+        self.assertAllClose(expected, actual)
+
+    @unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "fails with schema error")
+    def test_batchnorm(self):
+        x_shape = [1, 28, 28, 2]
+        x_dtype = np.float32
+        scale_dtype = np.float32
+        scale_shape = [2]
+        # only NHWC is supported on CPU for tensorflow
+        data_format = "NHWC"
+        x_val = np.random.random_sample(x_shape).astype(x_dtype)
+        scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        var_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+
+        x = tf.placeholder(tf.float32, x_val.shape, name=_TFINPUT)
+        scale = tf.constant(scale_val, name='scale')
+        offset = tf.constant(offset_val, name='offset')
+        mean = tf.constant(mean_val, name='mean')
+        var = tf.constant(var_val, name='variance')
+        epsilon = 0.001
+        y, _, _ = tf.nn.fused_batch_norm(
+            x, scale, offset, mean=mean, variance=var,
+            epsilon=epsilon, data_format=data_format, is_training=False)
+        output = tf.identity(y, name=_TFOUTPUT)
+        actual, expected = self._run(output, {x: x_val}, {_INPUT: x_val})
+        self.assertAllClose(expected, actual, rtol=1e-04)
+
     @unittest.skipIf(BACKEND in ["caffe2", "onnxmsrt"], "not correctly supported")
     def test_resize_nearest_neighbor(self):
         x_shape = [1, 15, 20, 2]
@@ -964,10 +1017,13 @@ def test_fill(self):
     parser.add_argument('--backend', default='caffe2',
                         choices=["caffe2", "onnxmsrt", "onnxmsrtnext", "onnx-tensorflow"],
                         help="backend to test against")
+    parser.add_argument('--opset', default=OPSET,
+                        help="opset to test against")
     parser.add_argument('unittest_args', nargs='*')

     args = parser.parse_args()
     BACKEND = args.backend
+    OPSET = args.opset
     # Now set the sys.argv to the unittest_args (leaving sys.argv[0] alone)
     sys.argv[1:] = args.unittest_args
     unittest.main()
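For reference, the inference-mode computation that test_batchnorm exercises is a per-channel affine transform; the sketch below writes it out in plain numpy with the same shapes and epsilon as the test (this is the standard formula, not tf2onnx API). The rtol=1e-04 in the assertion absorbs float32 rounding differences between backends.

    import numpy as np

    x = np.random.random_sample([1, 28, 28, 2]).astype(np.float32)  # NHWC
    scale, offset, mean, var = (
        np.random.random_sample([2]).astype(np.float32) for _ in range(4))
    epsilon = 0.001

    # inference-mode fused batch norm: broadcast over the channel axis
    y = scale * (x - mean) / np.sqrt(var + epsilon) + offset
    assert y.shape == x.shape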

tf2onnx/tfonnx.py

Lines changed: 39 additions & 18 deletions
@@ -305,7 +305,8 @@ def reshape_op5(ctx, node, name, args):
 NCHW_TO_HWCN = [2, 3, 1, 0]


-def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None):
+def conv_convert_inputs(ctx, node, with_kernel=False, new_kernel_shape=None,
+                        input_indices=None, output_indices=None):
     """Convert input and kernel from tensorflow to onnx. This may require
     inserting transpose ops for input, kernel and output unless they are constants
     and we can transpose the constant.
@@ -324,25 +325,32 @@ def calc_shape(a, b):
             return [a[b[i]] for i in b]
         return None

+    if input_indices is None:
+        input_indices = [0]
+    if output_indices is None:
+        output_indices = [0]
+
     nodes = []

     if node.is_nhwc():
         # transpose input if needed, no need to record shapes on input
-        if node.inputs[0].is_const():
-            # if input is a constant, transpose that one
-            parent = node.inputs[0]
-            if not parent.data_format:
-                val = parent.get_tensor_value()
-                parent.set_tensor_value(val.transpose(NHWC_TO_NCHW))
-                parent.data_format = "NCHW"
-        else:
-            # if input comes from an op, insert transpose op
-            input_name = node.input[0]
-            transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
-            transpose.set_attr("perm", NHWC_TO_NCHW)
-            transpose.inserted_nchw = True
-            ctx.set_shape(transpose.output[0], calc_shape(ctx.get_shape(input_name), NHWC_TO_NCHW))
-            nodes.append(transpose)
+        for idx in input_indices:
+            if node.inputs[idx].is_const():
+                # if input is a constant, transpose that one
+                parent = node.inputs[idx]
+                if not parent.data_format:
+                    val = parent.get_tensor_value()
+                    parent.set_tensor_value(val.transpose(NHWC_TO_NCHW))
+                    parent.data_format = "NCHW"
+            else:
+                # if input comes from an op, insert transpose op
+                input_name = node.input[idx]
+                transpose = ctx.insert_new_node_on_input(node, "Transpose", input_name)
+                transpose.set_attr("perm", NHWC_TO_NCHW)
+                transpose.inserted_nchw = True
+                if idx == 0:
+                    ctx.set_shape(transpose.output[0], calc_shape(ctx.get_shape(input_name), NHWC_TO_NCHW))
+                nodes.append(transpose)

     # kernel must be transposed
     if with_kernel:
@@ -388,12 +396,13 @@ def calc_shape(a, b):
     # transpose outputs if needed
     if node.is_nhwc():
         # TODO: what if len(output) > 1?
-        for i, output_name in enumerate(node.output):
+        for idx in output_indices:
+            output_name = node.output[idx]
             op_name = utils.make_name(node.name)
             transpose = ctx.insert_new_node_on_output("Transpose", output_name, name=op_name)
             transpose.set_attr("perm", NCHW_TO_NHWC)
             transpose.inserted_nchw = True
-            ctx.set_shape(transpose.output[0], calc_shape(ctx.get_shape(node.output[0]), NCHW_TO_NHWC))
+            ctx.set_shape(transpose.output[0], calc_shape(ctx.get_shape(node.output[idx]), NCHW_TO_NHWC))
             nodes.append(transpose)
     return nodes
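The layout bookkeeping in the hunks above is plain permutation arithmetic; a standalone numpy sketch (the perm values are the usual NHWC/NCHW constants, shown here for illustration):

    import numpy as np

    NHWC_TO_NCHW = [0, 3, 1, 2]
    NCHW_TO_NHWC = [0, 2, 3, 1]

    x = np.zeros([1, 28, 28, 2], dtype=np.float32)   # NHWC
    x_nchw = x.transpose(NHWC_TO_NCHW)               # shape (1, 2, 28, 28)

    # permuting a shape list the same way a transpose permutes axes
    assert [x.shape[i] for i in NHWC_TO_NCHW] == list(x_nchw.shape)

    # the two perms are inverses, so a round trip restores the layout
    assert x_nchw.transpose(NCHW_TO_NHWC).shape == x.shape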

@@ -925,6 +934,17 @@ def onehot_op(ctx, node, name, args):
         return [node, transpose_op]
     return node

+def fused_batchnorm_op7(ctx, node, name, args):
+    node.type = "BatchNormalization"
+    # tf inputs: x, scale, bias, mean, variance
+    # tf outputs: y, batch_mean, batch_var
+    # tf attributes: data_format, epsilon, is_training
+    # onnx inputs: X, scale, B, mean, variance; attributes: epsilon, momentum=0.9, spatial=1
+    # onnx outputs: Y, mean, var, saved_mean, saved_var
+    nodes = conv_convert_inputs(ctx, node, with_kernel=False)
+    return nodes
+
+
 # pylint: enable=W0613,C0111,W0612

@@ -1048,6 +1068,7 @@ def onehot_op(ctx, node, name, args):
     "Sin": (direct_op, []),
     "Tan": (direct_op, []),
     "Multinomial": (multinomial_op, []),
+    "FusedBatchNorm": (fused_batchnorm_op7, []),
 }

 _OPSETS = [
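The mapping itself is nearly a rename: as the comments in fused_batchnorm_op7 note, TF's FusedBatchNorm and ONNX's BatchNormalization take the same five inputs in the same order, so the handler only sets the new op type and lets conv_convert_inputs insert the NHWC→NCHW transposes. A hand-built ONNX node equivalent to what the converter emits for the test might look like this (tensor names are illustrative):

    from onnx import helper

    bn = helper.make_node(
        "BatchNormalization",
        inputs=["X", "scale", "offset", "mean", "variance"],
        outputs=["Y"],
        epsilon=0.001,   # carried over from the tf attribute of the same name
    )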

tf2onnx/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
     'dtype', 'output_shape', 'spatial', 'split', 'input_forget', 'keepdims', 'transA', 'auto_pad', 'border', 'low',
     'linear_before_reset', 'height_scale', 'output_padding', 'shape', 'kernel_shape', 'epsilon', 'size', 'starts',
     'direction', 'max', 'clip', 'across_channels', 'value', 'strides', 'extra_shape', 'scales', 'k', 'sample_size',
-    'blocksize'
+    'blocksize', 'momentum'
 }
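The set extended here whitelists attribute names allowed on generated ONNX nodes; 'momentum' is needed because BatchNormalization carries it ('epsilon' is already in the set two lines above). A sketch of the kind of filtering such a whitelist enables; the helper name is hypothetical, not tf2onnx API:

    ONNX_VALID_ATTRIBUTES = {'epsilon', 'momentum', 'blocksize'}  # excerpt

    def filter_attributes(attrs):
        # drop tf-only attributes that have no ONNX counterpart
        return {k: v for k, v in attrs.items() if k in ONNX_VALID_ATTRIBUTES}

    print(filter_attributes({"epsilon": 0.001, "data_format": "NHWC"}))
    # -> {'epsilon': 0.001}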
