
Commit 9c63b7c

[cherry-pick] add bn momentum variable (#21435)
* batch_norm momentum support variable. test=develop
1 parent 5c7c6b1 commit 9c63b7c

File tree

5 files changed: +139 -46 lines changed

paddle/fluid/operators/batch_norm_op.cc

Lines changed: 20 additions & 1 deletion
@@ -58,6 +58,13 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
   const DataLayout data_layout = framework::StringToDataLayout(
       ctx->Attrs().Get<std::string>("data_layout"));
 
+  if (ctx->IsRuntime() && ctx->HasInput("MomentumTensor")) {
+    auto mom = ctx->Inputs("MomentumTensor");
+    PADDLE_ENFORCE_EQ(mom.size(), 1,
+                      platform::errors::InvalidArgument(
+                          "Input(MomentumTensor) size must be 1"));
+  }
+
   PADDLE_ENFORCE_GE(
       x_dims.size(), 2,
       "ShapeError: the dimension of input X must greater than or equal to 2."

@@ -173,6 +180,11 @@ void BatchNormOpMaker::Make() {
   AddInput("Variance",
            "The global variance (for training) "
            "or estimated Variance (for testing)");
+  AddInput("MomentumTensor",
+           "(Tensor<float32>, optional) If provided, batch_norm will "
+           "use this as momentum, this has a higher priority than "
+           "attr(momentum), the shape of this tensor MUST BE [1].")
+      .AsDispensable();
   AddOutput("Y", "result after normalization");
   AddOutput("MeanOut",
             "Share memory with Mean. "

@@ -221,7 +233,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
     const float epsilon = ctx.Attr<float>("epsilon");
-    const float momentum = ctx.Attr<float>("momentum");
+    float momentum = ctx.Attr<float>("momentum");
     const bool is_test = ctx.Attr<bool>("is_test");
     const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
 
@@ -306,6 +318,13 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
         PADDLE_THROW("Unknown storage order: %s", data_layout_str);
     }
 
+    // if MomentumTensor is set, use MomentumTensor value, momentum
+    // is only used in this training branch
+    if (ctx.HasInput("MomentumTensor")) {
+      const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
+      momentum = mom_tensor->data<float>()[0];
+    }
+
     running_mean_arr =
         running_mean_arr * momentum + saved_mean_e * (1. - momentum);
     running_var_arr =
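Taken together, the InferShape check and the kernel read above mean the momentum input must be a single float32 value; any size other than 1 trips the PADDLE_ENFORCE_EQ. A minimal sketch of exercising this from Python, assuming the fluid 1.6-era API (the variable names below are illustrative, not part of the commit):

import numpy as np
import paddle.fluid as fluid

x = fluid.data(name='x', shape=[4, 16], dtype='float32')
# the momentum input must have shape [1]; anything larger fails the check
mom = fluid.data(name='mom', shape=[1], dtype='float32')
y = fluid.layers.batch_norm(input=x, momentum=mom)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
out, = exe.run(feed={'x': np.random.rand(4, 16).astype('float32'),
                     'mom': np.array([0.9], dtype='float32')},
               fetch_list=[y])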

paddle/fluid/operators/batch_norm_op.cu

Lines changed: 10 additions & 1 deletion
@@ -43,7 +43,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
     PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                    "It must use CUDAPlace.");
     double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
-    const float momentum = ctx.Attr<float>("momentum");
+    float momentum = ctx.Attr<float>("momentum");
     const bool is_test = ctx.Attr<bool>("is_test");
     const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
     const std::string data_layout_str = ctx.Attr<std::string>("data_layout");

@@ -133,6 +133,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
           est_mean->template data<BatchNormParamType<T>>(),
           est_var->template data<BatchNormParamType<T>>(), epsilon));
     } else {
+      // if MomentumTensor is set, use MomentumTensor value, momentum
+      // is only used in this training branch
+      if (ctx.HasInput("MomentumTensor")) {
+        const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
+        Tensor mom_cpu;
+        TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu);
+        momentum = mom_cpu.data<float>()[0];
+      }
+
       // Run training mode.
       // obtain running mean and running inv var, and see if we need to
       // initialize them.
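The CUDA path differs from the CPU path only in how the scalar is read: the momentum tensor may live in device memory, so the kernel first copies it to the host with TensorCopySync. The Python-side usage is unchanged; a hedged GPU variant of the sketch above, assuming a CUDA build (names illustrative):

import numpy as np
import paddle.fluid as fluid

x = fluid.data(name='x', shape=[4, 16], dtype='float32')
# a shape-[1] tensor built inside the program; the kernel copies it
# device-to-host before reading the scalar
mom = fluid.layers.fill_constant(shape=[1], dtype='float32', value=0.9)
y = fluid.layers.batch_norm(input=x, momentum=mom)

exe = fluid.Executor(fluid.CUDAPlace(0))  # requires a CUDA build
exe.run(fluid.default_startup_program())
out, = exe.run(feed={'x': np.random.rand(4, 16).astype('float32')},
               fetch_list=[y])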

python/paddle/fluid/layers/nn.py

Lines changed: 52 additions & 19 deletions
@@ -4176,13 +4176,14 @@ def batch_norm(input,
         sync_batch_norm automatically.
 
     Args:
-        input(variable): The rank of input variable can be 2, 3, 4, 5. The data type
+        input(Variable): The rank of input variable can be 2, 3, 4, 5. The data type
             is float16 or float32 or float64.
         act(string, Default None): Activation type, linear|relu|prelu|...
         is_test (bool, Default False): A flag indicating whether it is in
             test phrase or not.
-        momentum(float, Default 0.9): The value used for the moving_mean and
-            moving_var computation. The updated formula is:
+        momentum(float|Variable, Default 0.9): The value used for the moving_mean and
+            moving_var computation. This should be a float number or a Variable with
+            shape [1] and data type as float32. The updated formula is:
             :math:`moving\_mean = moving\_mean * momentum + new\_mean * (1. - momentum)`
             :math:`moving\_var = moving\_var * momentum + new\_var * (1. - momentum)`
             Default is 0.9.

@@ -4228,6 +4229,33 @@ def batch_norm(input,
             x = fluid.data(name='x', shape=[3, 7, 3, 7], dtype='float32')
             hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
             hidden2 = fluid.layers.batch_norm(input=hidden1)
+
+        .. code-block:: python
+
+            # batch_norm with momentum as Variable
+            import paddle.fluid as fluid
+            import paddle.fluid.layers.learning_rate_scheduler as lr_scheduler
+
+            def get_decay_momentum(momentum_init, decay_steps, decay_rate):
+                global_step = lr_scheduler._decay_step_counter()
+                momentum = fluid.layers.create_global_var(
+                    shape=[1],
+                    value=float(momentum_init),
+                    dtype='float32',
+                    # set persistable for save checkpoints and resume
+                    persistable=True,
+                    name="momentum")
+                div_res = global_step / decay_steps
+                decayed_momentum = momentum_init * (decay_rate**div_res)
+                fluid.layers.assign(decayed_momentum, momentum)
+
+                return momentum
+
+            x = fluid.data(name='x', shape=[3, 7, 3, 7], dtype='float32')
+            hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
+            momentum = get_decay_momentum(0.9, 1e5, 0.9)
+            hidden2 = fluid.layers.batch_norm(input=hidden1, momentum=momentum)
+
     """
     assert bias_attr is not False, "bias_attr should not be False in batch_norm."
     helper = LayerHelper('batch_norm', **locals())

@@ -4303,31 +4331,36 @@ def batch_norm(input,
     batch_norm_out = input if in_place else helper.create_variable_for_type_inference(
         dtype)
 
+    inputs = {
+        "X": input,
+        "Scale": scale,
+        "Bias": bias,
+        "Mean": mean,
+        "Variance": variance
+    }
+    attrs = {
+        "epsilon": epsilon,
+        "is_test": is_test,
+        "data_layout": data_layout,
+        "use_mkldnn": False,
+        "fuse_with_relu": False,
+        "use_global_stats": use_global_stats
+    }
+    if isinstance(momentum, Variable):
+        inputs['MomentumTensor'] = momentum
+    else:
+        attrs['momentum'] = momentum
     helper.append_op(
         type="batch_norm",
-        inputs={
-            "X": input,
-            "Scale": scale,
-            "Bias": bias,
-            "Mean": mean,
-            "Variance": variance
-        },
+        inputs=inputs,
         outputs={
             "Y": batch_norm_out,
             "MeanOut": mean_out,
             "VarianceOut": variance_out,
             "SavedMean": saved_mean,
             "SavedVariance": saved_variance
         },
-        attrs={
-            "momentum": momentum,
-            "epsilon": epsilon,
-            "is_test": is_test,
-            "data_layout": data_layout,
-            "use_mkldnn": False,
-            "fuse_with_relu": False,
-            "use_global_stats": use_global_stats
-        })
+        attrs=attrs)
 
     return helper.append_activation(batch_norm_out)
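The dispatch added above is the whole user-facing contract: a plain float still becomes attr(momentum), while a Variable is wired in as the MomentumTensor input and read by the kernels at runtime. A minimal side-by-side sketch, assuming the fluid 1.6-era API (illustrative only):

import paddle.fluid as fluid

x = fluid.data(name='x', shape=[3, 7, 3, 7], dtype='float32')

# float: stored as attr(momentum), exactly as before this commit
y_attr = fluid.layers.batch_norm(input=x, momentum=0.9)

# Variable of shape [1]: routed through the new MomentumTensor input
mom = fluid.layers.fill_constant(shape=[1], dtype='float32', value=0.9)
y_tensor = fluid.layers.batch_norm(input=x, momentum=mom)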

python/paddle/fluid/tests/unittests/test_batch_norm_op.py

Lines changed: 44 additions & 25 deletions
@@ -310,6 +310,7 @@ def setUp(self):
         self.fuse_with_relu = False
         self.data_formats = ["NCHW", "NHWC"]
         self.momentum = 0.9
+        self.use_momentum_variable = False
         self.epsilon = 0.00001
         self.init_kernel_type()
         self.init_test_case()

@@ -367,6 +368,7 @@ def test_with_place(place, data_layout, shape):
             bias = np.random.random_sample(scale_shape).astype(np.float32)
             mean, variance = self.set_mean_variance(scale_shape, x, data_layout)
             y_grad = np.random.random_sample(shape).astype(np.float32)
+            momentum_var = np.array([momentum]).astype(np.float32)
 
             y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad = self.ref_forward_backward(
                 x, y_grad, scale, bias, mean, variance, epsilon, momentum,

@@ -380,7 +382,7 @@ def test_with_place(place, data_layout, shape):
 
             var_names = [
                 'x', 'scale', 'bias', 'mean', 'variance', 'y', 'saved_mean',
-                'saved_variance'
+                'saved_variance', 'momentum_var'
             ]
             ground_truth = {name: var_dict[name] for name in var_names}

@@ -392,31 +394,36 @@ def test_with_place(place, data_layout, shape):
                         name=name,
                         dtype='float32',
                         shape=ground_truth[name].shape)
+                inputs = {
+                    "X": block.var('x'),
+                    "Scale": block.var('scale'),
+                    "Bias": block.var('bias'),
+                    "Mean": block.var('mean'),
+                    "Variance": block.var('variance')
+                }
+                attrs = {
+                    "epsilon": epsilon,
+                    "is_test": False,
+                    "data_layout": data_layout,
+                    "use_mkldnn": self.use_mkldnn,
+                    "fuse_with_relu": self.fuse_with_relu,
+                    "use_global_stats": self.use_global_stats
+                }
+                if self.use_momentum_variable:
+                    inputs['MomentumTensor'] = block.var('momentum_var')
+                else:
+                    attrs['momentum'] = momentum
                 bn_op = block.append_op(
                     type="batch_norm",
-                    inputs={
-                        "X": block.var('x'),
-                        "Scale": block.var('scale'),
-                        "Bias": block.var('bias'),
-                        "Mean": block.var('mean'),
-                        "Variance": block.var('variance')
-                    },
+                    inputs=inputs,
                     outputs={
                         "Y": block.var('y'),
                         "MeanOut": block.var('mean'),  # share memory
                         "VarianceOut": block.var('variance'),  # share memory
                         "SavedMean": block.var('saved_mean'),
                         "SavedVariance": block.var('saved_variance')
                     },
-                    attrs={
-                        "momentum": momentum,
-                        "epsilon": epsilon,
-                        "is_test": False,
-                        "data_layout": data_layout,
-                        "use_mkldnn": self.use_mkldnn,
-                        "fuse_with_relu": self.fuse_with_relu,
-                        "use_global_stats": self.use_global_stats
-                    })
+                    attrs=attrs)
                 block.create_var(name='y@GRAD', dtype='float32', shape=y.shape)
 
                 # generate backward op_desc

@@ -434,14 +441,15 @@ def test_with_place(place, data_layout, shape):
                     grad_var.set_dtype(core.VarDesc.VarType.FP32)
 
                 exe = fluid.Executor(place)
-                out = exe.run(
-                    program,
-                    feed={
-                        name: var_dict[name]
-                        for name in
-                        ['x', 'scale', 'bias', 'mean', 'variance', 'y@GRAD']
-                    },
-                    fetch_list=self.fetch_list)
+                out = exe.run(program,
+                              feed={
+                                  name: var_dict[name]
+                                  for name in [
+                                      'x', 'scale', 'bias', 'mean', 'variance',
+                                      'y@GRAD', 'momentum_var'
+                                  ]
+                              },
+                              fetch_list=self.fetch_list)
 
                 for id, name in enumerate(self.fetch_list):
                     if name == 'variance':

@@ -471,6 +479,17 @@ def init_test_case(self):
         self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD']
 
 
+class TestBatchNormOpTrainingMomentumVariable(TestBatchNormOpTraining):
+    def init_test_case(self):
+        self.use_momentum_variable = True
+        self.use_global_stats = False
+        self.no_grad_set = set()
+        self.fetch_list = [
+            'y', 'mean', 'variance', 'saved_mean', 'saved_variance', 'x@GRAD',
+            'scale@GRAD', 'bias@GRAD'
+        ]
+
+
 class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
     def init_test_case(self):
         self.use_global_stats = True
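For reference, the ground truth the test compares against uses the same moving-average update whether momentum came from the attribute or the tensor; only its source differs. A NumPy sketch of that update (illustrative, mirroring the formula in the docstring above):

import numpy as np

def update_running_stats(running_mean, running_var, batch_mean, batch_var,
                         momentum):
    # identical rule for attr(momentum) and a shape-[1] MomentumTensor
    new_mean = running_mean * momentum + batch_mean * (1.0 - momentum)
    new_var = running_var * momentum + batch_var * (1.0 - momentum)
    return new_mean, new_var

# scalar read out of a shape-[1] float32 tensor, as the kernels do
momentum = np.array([0.9], dtype=np.float32)[0]
m, v = update_running_stats(np.zeros(3), np.ones(3),
                            np.full(3, 0.5), np.full(3, 2.0), momentum)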

python/paddle/fluid/tests/unittests/test_layers.py

Lines changed: 13 additions & 0 deletions
@@ -2465,6 +2465,19 @@ def make_batch_norm(self):
             out = layers.batch_norm(data)
             return (out)
 
+    def make_batch_norm_momentum_variable(self):
+        with program_guard(fluid.default_main_program(),
+                           fluid.default_startup_program()):
+            data = self._get_data(
+                name='data', shape=[32, 128, 128], dtype="float32")
+            momentum = self._get_data(
+                name='momentum',
+                shape=[1],
+                dtype='float32',
+                append_batch_size=False)
+            out = layers.batch_norm(data, momentum=momentum)
+            return (out)
+
     def make_range(self):
         with program_guard(fluid.default_main_program(),
                            fluid.default_startup_program()):