Commit 5497030

onnx requires all inputs for bn to be the same T (#1344)
* onnx requires all inputs for bn to be the same T
  Signed-off-by: Guenther Schmuelling <[email protected]>
* training space
  Signed-off-by: Guenther Schmuelling <[email protected]>
* pylint
  Signed-off-by: Guenther Schmuelling <[email protected]>

Co-authored-by: TomWildenhain-Microsoft <[email protected]>
1 parent 6f87228 commit 5497030
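
Background for the change: at opset 11, ONNX BatchNormalization binds one type variable T across all five inputs (x, scale, B, mean, var), so a graph that feeds it float16 data together with float32 statistics fails ONNX's strict checks. A minimal sketch of the constraint, hypothetical and not part of this commit, using only the stock onnx Python package:

# Hypothetical sketch, not part of this commit: demonstrates the ONNX
# constraint named in the commit title. All five BatchNormalization inputs
# share one type variable T at opset 11, so mixing float16 and float32
# should be rejected by the checker's strict mode.
import onnx
from onnx import TensorProto, helper

def vi(name, elem_type, shape):
    # shorthand: declare a typed graph input/output
    return helper.make_tensor_value_info(name, elem_type, shape)

bn = helper.make_node(
    "BatchNormalization",
    inputs=["x", "scale", "offset", "mean", "var"],
    outputs=["y"])
graph = helper.make_graph(
    [bn], "bn_mixed_t",
    [vi("x", TensorProto.FLOAT16, [1, 2, 32, 32])] +
    [vi(n, TensorProto.FLOAT, [2]) for n in ("scale", "offset", "mean", "var")],
    [vi("y", TensorProto.FLOAT16, [1, 2, 32, 32])])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])

try:
    # full_check runs strict shape/type inference, which should flag
    # T being bound to both float16 and float32
    onnx.checker.check_model(model, full_check=True)
except Exception as err:
    print("rejected as expected:", err)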

File tree

2 files changed: +36 / -0 lines changed


tests/test_backend.py

Lines changed: 21 additions & 0 deletions

@@ -2335,6 +2335,27 @@ def func(x):
             return tf.identity(y, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04)
 
+    @skip_tflite("tflite conversion aborts")
+    @check_opset_min_version(11, "batchnorm")
+    @check_tf_min_version("2.4")
+    def test_batchnorm_mixed(self):
+        x_shape = [1, 32, 32, 2]
+        x_dtype = np.float16
+        scale_dtype = np.float32
+        scale_shape = [2]
+        x_val = np.random.random_sample(x_shape).astype(x_dtype)
+        scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        mean_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        var_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        def func(x, mean, offset, var):
+            scale = tf.constant(scale_val, name='scale')
+            y = tf.raw_ops.FusedBatchNormV3(x=x, scale=scale, offset=offset, mean=mean, variance=var,
+                                            is_training=False, name=_TFOUTPUT)
+            return y
+        self._run_test_case(func, [_OUTPUT],
+                            {_INPUT: x_val, _INPUT1: mean_val, _INPUT2: offset_val, _INPUT3: var_val})
+
     @check_opset_min_version(7, "batchnorm")
     @check_tf_min_version("1.13")
     def test_batchnorm(self):
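
The new test covers the mixed-precision case TensorFlow itself permits: FusedBatchNormV3 types x with T and scale/offset/mean/variance with U, so float16 data with float32 statistics is legal on the TF side. A hedged sketch of the same case outside the test harness, assuming a tf2onnx build that contains this fix and exposes tf2onnx.convert.from_function:

# Hedged sketch, not part of the commit: reproduces the mixed-dtype case the
# new test covers, via tf2onnx's Python API (convert.from_function is assumed
# to be available in the installed tf2onnx version).
import numpy as np
import tensorflow as tf
import tf2onnx

stats = {name: tf.constant(np.random.random_sample([2]).astype(np.float32))
         for name in ("scale", "offset", "mean", "var")}

@tf.function(input_signature=[tf.TensorSpec([1, 32, 32, 2], tf.float16)])
def bn(x):
    # TF allows x (T=float16) with float32 statistics (U); ONNX at opset 11
    # does not, which is the gap the converter change closes.
    outputs = tf.raw_ops.FusedBatchNormV3(
        x=x, scale=stats["scale"], offset=stats["offset"],
        mean=stats["mean"], variance=stats["var"], is_training=False)
    return outputs[0]  # y; the rest are batch stats and reserve space

model_proto, _ = tf2onnx.convert.from_function(
    bn, input_signature=[tf.TensorSpec([1, 32, 32, 2], tf.float16)], opset=11)
# expect a Cast before and after the BatchNormalization node
print([n.op_type for n in model_proto.graph.node])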

tf2onnx/onnx_opset/nn.py

Lines changed: 15 additions & 0 deletions

@@ -780,6 +780,21 @@ def version_6(cls, ctx, node, **kwargs):
         # output: y, mean, var, savedmean, savedvar,
         # detach unused outputs. While we could let the unused outputs dangle,
         # some runtimes like pytorch/caffe2 do complain about it.
+
+        # onnx batchnorm requires the same T for all inputs
+        mean_type = ctx.get_dtype(node.input[3])
+        x_dtype = ctx.get_dtype(node.input[0])
+        if x_dtype != mean_type:
+            # TODO: this works, but it would be more efficient to flip the other inputs instead.
+            # TODO: we'd first need to verify that the onnx implementation handles that, so it is left for later.
+            ctx.insert_new_node_on_input(node, "Cast", node.input[0], to=mean_type)
+            # casting input[0] changes the output dtype of bn, so we need to cast back
+            cast_back_node = ctx.insert_new_node_on_output("Cast", node.output[0],
+                                                           name=utils.make_name(node.name) + "_castback",
+                                                           to=x_dtype)
+            ctx.set_dtype(cast_back_node.output[0], x_dtype)
+            ctx.copy_shape(node.name, cast_back_node.output[0])
+
         consumers = [ctx.find_output_consumers(output_name) for output_name in node.output[1:]]
         if not any(consumers):
             new_output = [node.output[0]]
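
The net effect of the handler is to sandwich BatchNormalization between two Casts so that all of its inputs share T. A hypothetical illustration of the resulting pattern (tf2onnx builds it via insert_new_node_on_input/insert_new_node_on_output rather than like this):

# Hypothetical illustration of the pattern the fix produces for float16 x
# with float32 statistics; not the converter's literal output.
from onnx import TensorProto, helper

cast_in = helper.make_node("Cast", ["x"], ["x_f32"], to=TensorProto.FLOAT)
bn = helper.make_node(
    "BatchNormalization",
    ["x_f32", "scale", "offset", "mean", "var"], ["y_f32"])
cast_back = helper.make_node("Cast", ["y_f32"], ["y"], to=TensorProto.FLOAT16)
# x (float16) -> Cast -> BatchNormalization (everything float32) -> Cast -> y (float16)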
