Commit c975fe1

batch norm support matrix input (#5980)
* batch norm support matrix input
* update gpu code
* format code
1 parent 23b3fef commit c975fe1
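
In effect, this commit lets `batch_norm` consume a rank-2 tensor [N, C] directly: the operator treats a matrix as having unit spatial extents, i.e. [N, C, 1, 1] in NCHW or [N, 1, 1, C] in NHWC. A minimal NumPy sketch of the intended semantics (names here are illustrative, not from the patch):

import numpy as np

def batch_norm_matrix(x, scale, offset, epsilon=1e-5):
    # x is [N, C]; normalizing each column over the batch is equivalent
    # to NCHW batch norm on the lifted [N, C, 1, 1] tensor.
    mean = x.mean(axis=0)
    var = x.var(axis=0)
    normalized = (x - mean) / np.sqrt(var + epsilon)
    return normalized * scale + offset, mean, var

x = np.random.random_sample((2, 3)).astype(np.float32)
y, mean, var = batch_norm_matrix(x, np.ones(3), np.zeros(3))
assert y.shape == (2, 3)  # the matrix shape is preserved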

File tree

5 files changed: +93 −44

  paddle/operators/batch_norm_op.cc
  paddle/operators/batch_norm_op.cu.cc
  python/paddle/v2/fluid/tests/book/test_image_classification_train.py
  python/paddle/v2/fluid/tests/test_batch_norm_op.py
  python/paddle/v2/fluid/tests/test_image_classification_layer.py


paddle/operators/batch_norm_op.cc

Lines changed: 8 additions & 7 deletions
@@ -62,13 +62,14 @@ class BatchNormOp : public framework::OperatorWithKernel {
     const auto x_dims = ctx->GetInputDim("X");
     const TensorFormat tensor_format =
         StringToTensorFormat(ctx->Attrs().Get<std::string>("tensor_format"));
+
+    PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
+                   "Input X must have 2 to 5 dimensions.");
+
     const int C =
         (tensor_format == TensorFormat::NCHW ? x_dims[1]
                                              : x_dims[x_dims.size() - 1]);
 
-    PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
-                   "Input X must have 3 to 5 dimensions.");
-
     PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
     PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C);
     PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
@@ -146,8 +147,8 @@ class BatchNormKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
 
     const auto *x = ctx.Input<Tensor>("X");
     const auto &x_dims = x->dims();
-    PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
-                   "The Input dim size should be between 3 and 5");
+    PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
+                   "The Input dim size should be between 2 and 5");
     const int N = x_dims[0];
     const int C =
         (tensor_format == TensorFormat::NCHW ? x_dims[1]
@@ -339,8 +340,8 @@ class BatchNormGradKernel<platform::CPUPlace, T>
     // Get the size for each dimension.
     // NCHW [batch_size, in_channels, in_height, in_width]
     const auto &x_dims = x->dims();
-    PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
-                   "The Input dim size should be between 3 and 5");
+    PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
+                   "The Input dim size should be between 2 and 5");
     const int N = x_dims[0];
     const int C =
         (tensor_format == TensorFormat::NCHW ? x_dims[1]
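
Note that the relaxed rank check also moves above the computation of `C`, presumably so that the channel index (`x_dims[1]` for NCHW, `x_dims[x_dims.size() - 1]` for NHWC) is only read after the rank is known to be at least 2. A hedged Python analogue of the reordered validation (names illustrative):

def infer_channels(x_dims, tensor_format):
    # Validate the rank first; only then is indexing the channel
    # dimension (dims[1] for NCHW, dims[-1] for NHWC) known to be safe.
    assert 2 <= len(x_dims) <= 5, "Input X must have 2 to 5 dimensions."
    return x_dims[1] if tensor_format == "NCHW" else x_dims[-1]

print(infer_channels((4, 8), "NCHW"))  # matrix input: C == 8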

paddle/operators/batch_norm_op.cu.cc

Lines changed: 19 additions & 12 deletions
@@ -29,14 +29,21 @@ void ExtractNCWHD(const framework::DDim &dims,
                   const TensorFormat &tensor_format, int *N, int *C, int *H,
                   int *W, int *D) {
   *N = dims[0];
-  *C = tensor_format == TensorFormat::NCHW ? dims[1] : dims[dims.size() - 1];
-  *H = tensor_format == TensorFormat::NCHW ? dims[2] : dims[1];
-  *W = dims.size() > 3
-           ? (tensor_format == TensorFormat::NCHW ? dims[3] : dims[2])
-           : 1;
-  *D = dims.size() > 4
-           ? (tensor_format == TensorFormat::NCHW ? dims[4] : dims[3])
-           : 1;
+  if (dims.size() == 2) {
+    *C = dims[1];
+    *H = 1;
+    *W = 1;
+    *D = 1;
+  } else {
+    *C = tensor_format == TensorFormat::NCHW ? dims[1] : dims[dims.size() - 1];
+    *H = tensor_format == TensorFormat::NCHW ? dims[2] : dims[1];
+    *W = dims.size() > 3
+             ? (tensor_format == TensorFormat::NCHW ? dims[3] : dims[2])
+             : 1;
+    *D = dims.size() > 4
+             ? (tensor_format == TensorFormat::NCHW ? dims[4] : dims[3])
+             : 1;
+  }
 }
 
 template <typename T>
@@ -56,8 +63,8 @@ class BatchNormKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
     // NCHW [batch_size, in_channels, in_height, in_width]
     const auto *x = ctx.Input<Tensor>("X");
     const auto &x_dims = x->dims();
-    PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
-                   "The Input dim size should be between 3 and 5");
+    PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
+                   "The Input dim size should be between 2 and 5");
     int N, C, H, W, D;
     ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D);
 
@@ -180,8 +187,8 @@ class BatchNormGradKernel<platform::GPUPlace, T>
 
     const auto &x_dims = x->dims();
 
-    PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
-                   "The Input dim size should be between 3 and 5");
+    PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
+                   "The Input dim size should be between 2 and 5");
     int N, C, H, W, D;
     ExtractNCWHD(x_dims, tensor_format, &N, &C, &H, &W, &D);
 
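
For a rank-2 input there is no spatial layout to disambiguate, so `ExtractNCWHD` now short-circuits to N × C with H = W = D = 1 for either tensor format; downstream code can then treat the matrix as a degenerate 4-D tensor without further special-casing. A Python sketch of the same extraction logic (a sketch, not the C++ source):

def extract_ncwhd(dims, tensor_format):
    # Mirrors the updated ExtractNCWHD: a [N, C] matrix gets unit
    # spatial extents; ranks 3-5 keep the format-dependent indexing.
    n = dims[0]
    if len(dims) == 2:
        return n, dims[1], 1, 1, 1
    nchw = tensor_format == "NCHW"
    c = dims[1] if nchw else dims[-1]
    h = dims[2] if nchw else dims[1]
    w = (dims[3] if nchw else dims[2]) if len(dims) > 3 else 1
    d = (dims[4] if nchw else dims[3]) if len(dims) > 4 else 1
    return n, c, h, w, d

print(extract_ncwhd((2, 3), "NHWC"))  # (2, 3, 1, 1, 1)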

python/paddle/v2/fluid/tests/book/test_image_classification_train.py

Lines changed: 1 addition & 2 deletions
@@ -69,8 +69,7 @@ def conv_block(input, num_filter, groups, dropouts):
 
     drop = fluid.layers.dropout(x=conv5, dropout_prob=0.5)
     fc1 = fluid.layers.fc(input=drop, size=512, act=None)
-    reshape1 = fluid.layers.reshape(x=fc1, shape=list(fc1.shape + (1, 1)))
-    bn = fluid.layers.batch_norm(input=reshape1, act='relu')
+    bn = fluid.layers.batch_norm(input=fc1, act='relu')
     drop2 = fluid.layers.dropout(x=bn, dropout_prob=0.5)
     fc2 = fluid.layers.fc(input=drop2, size=512, act=None)
     return fc2
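
With matrix input supported, the image-classification training test drops its workaround: `fc1` no longer has to be padded to rank 4 with `reshape` before batch normalization; the rank-2 tensor is passed to `batch_norm` directly.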

python/paddle/v2/fluid/tests/test_batch_norm_op.py

Lines changed: 47 additions & 13 deletions
@@ -21,6 +21,13 @@ def get_backward_op(scope, op, no_grad_set):
 
 
 def _reference_training(x, scale, offset, epsilon, data_format):
+    x_shape = x.shape
+    if len(x_shape) == 2:
+        if data_format == "NCHW":
+            x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
+        else:
+            x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
+
     if data_format == "NCHW":
         n, c, h, w = x.shape
         x_square = x * x
@@ -39,6 +46,8 @@ def _reference_training(x, scale, offset, epsilon, data_format):
         offset_tile = np.reshape(offset, (1, c, 1, 1))
         offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
         y = normalized * scale_tile + offset_tile
+        if len(x_shape) == 2:
+            y = np.reshape(y, (y.shape[0], y.shape[1]))
         return y, mean, var
     elif data_format == "NHWC":
         x_square = x * x
@@ -48,7 +57,10 @@ def _reference_training(x, scale, offset, epsilon, data_format):
         mean = x_sum / element_count
         var = x_square_sum / element_count - mean * mean
         normalized = (x - mean) / np.sqrt(var + epsilon)
-        return (normalized * scale + offset), mean, var
+        y = normalized * scale + offset
+        if len(x_shape) == 2:
+            y = np.reshape(y, x_shape)
+        return y, mean, var
     else:
         raise ValueError("Unknown data order.")
 
@@ -65,6 +77,18 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
     # (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
 
     # transfer from (N, C, H, W) to (N, H, W, C) to simplify computation
+    x_shape = x.shape
+
+    if len(x_shape) == 2:
+        if data_format == "NCHW":
+            x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
+            grad_y = np.reshape(grad_y,
+                                (grad_y.shape[0], grad_y.shape[1], 1, 1))
+        else:
+            x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
+            grad_y = np.reshape(grad_y,
+                                (grad_y.shape[0], 1, 1, grad_y.shape[1]))
+
     if data_format == "NCHW":
         x = np.transpose(x, (0, 2, 3, 1))
         grad_y = np.transpose(grad_y, (0, 2, 3, 1))
@@ -83,6 +107,9 @@ def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
         grad_x = np.transpose(grad_x, (0, 3, 1, 2))
         x = np.transpose(x, (0, 3, 1, 2))
         grad_y = np.transpose(grad_y, (0, 3, 1, 2))
+
+    if len(x_shape) == 2:
+        grad_x = np.reshape(grad_x, x_shape)
     return grad_x, grad_scale, grad_offset
 
 
@@ -127,7 +154,7 @@ def test_python(self):
         momentum = 0.9
 
         # N, H, W, C: 2, 3, 4, 2
-        n, h, w, c = 2, 3, 4, 2
+        n, h, w, c = 2, 3, 4, 5
         x_shape = [n, h, w, c]
         scale_shape = [c]
 
@@ -184,20 +211,23 @@ def test_python(self):
         print 'python: NHWC, NCHW, backward checking passed'
 
     def test_forward_backward(self):
-        def test_with_place(place, tensor_format):
+        def test_with_place(place, tensor_format, shape):
             # attr
             epsilon = 0.00001
             momentum = 0.9
 
-            # N, H, W, C: 12, 3, 4, 2
-            n, h, w, c = 2, 3, 4, 2
-
-            if data_format == "NHWC":
-                x_shape = [n, h, w, c]
-            elif data_format == "NCHW":
-                x_shape = [n, c, h, w]
+            if len(shape) == 2:
+                x_shape = shape
+                c = shape[1]
             else:
-                raise ValueError("Unknown data type.")
+                # n, h, w, c = 2, 3, 4, 2
+                n, h, w, c = shape[0], shape[1], shape[2], shape[3]
+                if data_format == "NHWC":
+                    x_shape = [n, h, w, c]
+                elif data_format == "NCHW":
+                    x_shape = [n, c, h, w]
+                else:
+                    raise ValueError("Unknown data type.")
             scale_shape = [c]
 
             x_val = np.random.random_sample(x_shape).astype(np.float32)
@@ -219,7 +249,10 @@ def test_with_place(place, tensor_format):
             # for gradient test
            # y_grad = np.ones(x_shape).astype(np.float32)
             y_grad = np.zeros(x_shape).astype(np.float32)
-            y_grad[0, 0, 0, 0] = 1.
+            if len(y_grad.shape) == 2:
+                y_grad[0, 0] = 1.
+            else:
+                y_grad[0, 0, 0, 0] = 1.
             # y_grad = np.random.random_sample(x_shape).astype(np.float32)
             x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
                 x_val, y_grad, scale_val, saved_mean, var_ref, epsilon,
@@ -313,7 +346,8 @@ def test_with_place(place, tensor_format):
             places.append(core.GPUPlace(0))
         for place in places:
             for data_format in ["NCHW", "NHWC"]:
-                test_with_place(place, data_format)
+                test_with_place(place, data_format, [2, 3, 4, 5])
+                test_with_place(place, data_format, [2, 3])
 
 
 if __name__ == '__main__':
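
The reference implementations handle a matrix by lifting it to rank 4, running the existing 4-D math, and reshaping the result back, so the 2-D case is checked against the same ground truth. A quick standalone sanity check of the equivalence that idiom relies on:

import numpy as np

np.random.seed(0)
x = np.random.random_sample((2, 3)).astype(np.float32)
eps = 1e-5

# Direct per-column normalization of the [N, C] matrix ...
direct = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

# ... matches NCHW normalization of the lifted [N, C, 1, 1] tensor.
x4 = x.reshape(2, 3, 1, 1)
mean4 = x4.mean(axis=(0, 2, 3), keepdims=True)
var4 = x4.var(axis=(0, 2, 3), keepdims=True)
lifted = ((x4 - mean4) / np.sqrt(var4 + eps)).reshape(2, 3)

assert np.allclose(direct, lifted)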

python/paddle/v2/fluid/tests/test_image_classification_layer.py

Lines changed: 18 additions & 10 deletions
@@ -1,6 +1,6 @@
 import unittest
 
-import paddle.v2.fluid.layers as layers
+import paddle.v2.fluid as fluid
 import paddle.v2.fluid.nets as nets
 from paddle.v2.fluid.framework import Program
 
@@ -29,27 +29,35 @@ class TestLayer(unittest.TestCase):
     def test_batch_norm_layer(self):
         main_program = Program()
         startup_program = Program()
-        images = layers.data(
+        images = fluid.layers.data(
             name='pixel',
             shape=[3, 48, 48],
             dtype='float32',
             main_program=main_program)
-        layers.batch_norm(
+        hidden1 = fluid.layers.batch_norm(
             input=images,
             main_program=main_program,
             startup_program=startup_program)
+        hidden2 = fluid.layers.fc(input=hidden1,
+                                  size=128,
+                                  act='relu',
+                                  main_program=main_program)
+        hidden3 = fluid.layers.batch_norm(
+            input=hidden2,
+            main_program=main_program,
+            startup_program=startup_program)
 
-        # print str(main_program)
+        print str(main_program)
 
     def test_dropout_layer(self):
         main_program = Program()
         startup_program = Program()
-        images = layers.data(
+        images = fluid.layers.data(
             name='pixel',
             shape=[3, 48, 48],
             dtype='float32',
             main_program=main_program)
-        layers.dropout(
+        fluid.layers.dropout(
             x=images,
             dropout_prob=0.5,
             main_program=main_program,
@@ -61,7 +69,7 @@ def test_img_conv_group(self):
         main_program = Program()
         startup_program = Program()
 
-        images = layers.data(
+        images = fluid.layers.data(
             name='pixel',
             shape=[3, 48, 48],
             dtype='float32',
@@ -77,19 +85,19 @@ def test_elementwise_add_with_act(self):
     def test_elementwise_add_with_act(self):
         main_program = Program()
         startup_program = Program()
-        image1 = layers.data(
+        image1 = fluid.layers.data(
             name='pixel1',
             shape=[3, 48, 48],
             dtype='float32',
             main_program=main_program,
             startup_program=startup_program)
-        image2 = layers.data(
+        image2 = fluid.layers.data(
             name='pixel2',
             shape=[3, 48, 48],
             dtype='float32',
             main_program=main_program,
             startup_program=startup_program)
-        out = layers.elementwise_add(
+        out = fluid.layers.elementwise_add(
             x=image1,
             y=image2,
             act='relu',
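
The added `fc` → `batch_norm` chain is the substantive part of this test change (the rest is the `layers` → `fluid.layers` import refactor): `fc` emits a rank-2 tensor, so the second `batch_norm` exercises the new matrix-input path at the layer level.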
