Commit cf8c953

Fixed conversion of Batch Norm when training=true (#1249)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 1c9c02d commit cf8c953

File tree

2 files changed: +50 −4 lines
  tests/test_backend.py
  tf2onnx/onnx_opset/nn.py

tests/test_backend.py

Lines changed: 21 additions & 0 deletions
@@ -2221,6 +2221,27 @@ def func(x):
             return tf.identity(y, name=_TFOUTPUT)
         self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04)
 
+    @check_opset_min_version(7, "batchnorm")
+    def test_fused_batchnorm_training(self):
+        x_shape = [1, 28, 28, 2]
+        x_dtype = np.float32
+        scale_dtype = np.float32
+        scale_shape = [2]
+        # only NHWC is supported on CPU for TensorFlow
+        data_format = "NHWC"
+        x_val = np.random.random_sample(x_shape).astype(x_dtype)
+        scale_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        offset_val = np.random.random_sample(scale_shape).astype(scale_dtype)
+        def func(x):
+            scale = tf.constant(scale_val, name='scale')
+            offset = tf.constant(offset_val, name='offset')
+            epsilon = 0.001
+            y, _, _ = fused_batch_norm(
+                x, scale, offset, mean=None, variance=None,
+                epsilon=epsilon, data_format=data_format, is_training=True)
+            return tf.identity(y, name=_TFOUTPUT)
+        self._run_test_case(func, [_OUTPUT], {_INPUT: x_val}, rtol=1e-04)
+
     @check_opset_min_version(7, "batchnorm")
     @check_tf_min_version("1.13")
     def test_batchnorm(self):
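For context: with is_training=True and no mean/variance inputs, fused_batch_norm normalizes each channel with statistics computed from the batch itself, which is exactly what the converter must now reproduce. A minimal NumPy sketch of that reference computation for the NHWC layout used in the test (the function name and epsilon default are illustrative, not part of the commit):

import numpy as np

def batchnorm_training_reference(x, scale, offset, epsilon=0.001):
    """Reference NHWC batch norm in training mode: statistics come from the batch itself."""
    axes = (0, 1, 2)                      # reduce over N, H, W; keep the channel axis
    mean = x.mean(axis=axes)
    var = x.var(axis=axes)                # biased (population) variance, i.e. divide by N*H*W
    return (x - mean) / np.sqrt(var + epsilon) * scale + offset

x = np.random.random_sample([1, 28, 28, 2]).astype(np.float32)
scale = np.random.random_sample([2]).astype(np.float32)
offset = np.random.random_sample([2]).astype(np.float32)
y_ref = batchnorm_training_reference(x, scale, offset)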

tf2onnx/onnx_opset/nn.py

Lines changed: 29 additions & 4 deletions
@@ -784,24 +784,49 @@ def version_6(cls, ctx, node, **kwargs):
 
         conv_convert_inputs(ctx, node, with_kernel=False)
 
+        inp_shape = ctx.get_shape(node.input[0])
+        inp_rank = len(inp_shape) if inp_shape is not None else None
         scale_shape = ctx.get_shape(node.input[1])
         mean_shape = ctx.get_shape(node.input[3])
         var_shape = ctx.get_shape(node.input[4])
         val_type = utils.map_onnx_to_numpy_type(ctx.get_dtype(node.input[1]))
-
-        if node.get_attr_value('is_training', 1) == 1:
+        is_training = node.get_attr_value('is_training', True)
+
+        if is_training and node.get_attr_value('exponential_avg_factor', 1.0) == 1.0:
+            # Sometimes TF uses a BatchNorm op with training = True and exponential_avg_factor = 1.0
+            # to perform layer mean/variance normalization. In such cases, the mean/var are computed from the input.
+            # TF allows mean/variance to be excluded only if is_training and exponential_avg_factor == 1.0
+            utils.make_sure(inp_rank is not None, "Cannot convert node %s of type %s with input of unknown rank.",
+                            node.name, tf_type)
+            dims = [0] + list(range(2, inp_rank))
+            avg = ctx.make_node("ReduceMean", [node.input[0]], attr={'axes': dims, 'keepdims': True}).output[0]
+            avg_squeezed = GraphBuilder(ctx).make_squeeze({"data": avg, "axes": dims})
+            sub = ctx.make_node("Sub", [node.input[0], avg]).output[0]
+            var_squeezed = ctx.make_node("ReduceSumSquare", [sub], attr={'axes': dims, 'keepdims': False}).output[0]
+
+            inp_shape = ctx.make_node("Shape", [node.input[0]]).output[0]
+            dims_const = ctx.make_const(utils.make_name("axes_const"), np.array(dims, dtype=np.int64)).output[0]
+            reduce_dims = ctx.make_node("Gather", [inp_shape, dims_const]).output[0]
+            dims_product = ctx.make_node("ReduceProd", [reduce_dims], attr={'axes': [0], 'keepdims': False})
+            cnt_float = ctx.make_node("Cast", [dims_product.output[0]], attr={'to': ctx.get_dtype(node.input[0])})
+
+            pop_var_squeezed = ctx.make_node("Div", [var_squeezed, cnt_float.output[0]]).output[0]
+            ctx.replace_inputs(node, node.input[:3] + [avg_squeezed, pop_var_squeezed])
+        else:
             logger.warning("Node %s of type %s has is_training set to true, which is not supported. "
                            "Please re-save the model with training set to false.",
                            node.name, tf_type)
+            # As long as the mean/variance estimates are provided, we should be OK
+            is_training = False
 
-        if mean_shape != scale_shape and all(d >= 0 for d in scale_shape):
+        if not is_training and mean_shape != scale_shape and all(d >= 0 for d in scale_shape):
             new_mean_value = np.array(np.resize(node.inputs[3].get_tensor_value(as_list=False), scale_shape),
                                       dtype=val_type)
             new_mean_node_name = utils.make_name(node.name)
             ctx.make_const(new_mean_node_name, new_mean_value)
             ctx.replace_input(node, node.input[3], new_mean_node_name, 3)
 
-        if var_shape != scale_shape and all(d >= 0 for d in scale_shape):
+        if not is_training and var_shape != scale_shape and all(d >= 0 for d in scale_shape):
             new_var_value = np.array(np.resize(node.inputs[4].get_tensor_value(as_list=False), scale_shape),
                                       dtype=val_type)
             new_val_node_name = utils.make_name(node.name)
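The inserted nodes compute the per-channel batch mean and the biased (population) variance and substitute them for the missing mean/variance inputs of BatchNormalization (via ctx.replace_inputs). Since dims is [0] plus the spatial axes, the reductions assume the data is channel-first at this point, after conv_convert_inputs has run. A NumPy sketch of what the ReduceMean / Sub / ReduceSumSquare / ReduceProd / Div subgraph evaluates to (function and variable names here are illustrative only, not part of the commit):

import numpy as np

def batch_stats_nchw(x):
    """NumPy equivalent of the new subgraph: per-channel batch mean and population variance."""
    dims = tuple([0] + list(range(2, x.ndim)))        # every axis except the channel axis 1
    avg = x.mean(axis=dims, keepdims=True)            # ReduceMean with keepdims=True
    sub = x - avg                                     # Sub
    sum_sq = np.sum(np.square(sub), axis=dims)        # ReduceSumSquare with keepdims=False
    cnt = float(np.prod([x.shape[d] for d in dims]))  # Shape -> Gather -> ReduceProd -> Cast
    pop_var = sum_sq / cnt                            # Div: biased variance (divide by N*H*W)
    return avg.squeeze(axis=dims), pop_var            # squeezed mean and variance, shape [C]

x = np.random.random_sample([1, 2, 28, 28]).astype(np.float32)
mean, var = batch_stats_nchw(x)                       # both have shape [2], one value per channel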
