Skip to content

Commit 0a95a44

Browse files
committed
add python batch norm inference test
1 parent 39c676e commit 0a95a44

File tree

2 files changed

+70
-3
lines changed

2 files changed

+70
-3
lines changed

paddle/fluid/operators/batch_norm_op.cu.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
125125

126126
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
127127
math::SetConstant<platform::CUDADeviceContext, T> functor;
128-
functor(dev_ctx, saved_mean, 0);
129-
functor(dev_ctx, saved_variance, 0);
128+
functor(dev_ctx, saved_mean, static_cast<T>(0));
129+
functor(dev_ctx, saved_variance, static_cast<T>(0));
130130

131131
auto handle = dev_ctx.cudnn_handle();
132132

python/paddle/fluid/tests/unittests/test_batch_norm_op.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,37 @@ def get_backward_op(scope, op, no_grad_set):
3131
return backward_op
3232

3333

34+
def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
35+
x_shape = x.shape
36+
if len(x_shape) == 2:
37+
if data_format == "NCHW":
38+
x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
39+
else:
40+
x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
41+
42+
if data_format == "NCHW":
43+
n, c, h, w = x.shape
44+
mean_tile = np.reshape(mean, (1, c, 1, 1))
45+
mean_tile = np.tile(mean_tile, (n, 1, h, w))
46+
var_tile = np.reshape(var, (1, c, 1, 1))
47+
var_tile = np.tile(var_tile, (n, 1, h, w))
48+
normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
49+
scale_tile = np.reshape(scale, (1, c, 1, 1))
50+
scale_tile = np.tile(scale_tile, (n, 1, h, w))
51+
offset_tile = np.reshape(offset, (1, c, 1, 1))
52+
offset_tile = np.reshape(offset_tile, (1, c, 1, 1))
53+
y = normalized * scale_tile + offset_tile
54+
elif data_format == "NHWC":
55+
normalized = (x - mean) / np.sqrt(var + epsilon)
56+
y = normalized * scale + offset
57+
else:
58+
raise ValueError("Unknown data order.")
59+
60+
if len(x_shape) == 2:
61+
y = np.reshape(y, x_shape)
62+
return y
63+
64+
3465
def _reference_training(x, scale, offset, epsilon, data_format):
3566
x_shape = x.shape
3667
if len(x_shape) == 2:
@@ -155,7 +186,43 @@ def __set_tensor__(name, data=None):
155186
__set_tensor__(output, data)
156187

157188

158-
class TestBatchNormOp(OpTest):
class TestBatchNormOpInference(OpTest):
    """Checks that the NumPy batch-norm inference reference gives the same
    answer for NHWC and NCHW layouts of the same data."""

    def setUp(self):
        # Precision used for all test tensors.
        self.dtype = np.float32

    def __assert_close(self, a, b, msg, atol=1e-4):
        # Fix: the original called self.__assert_close, which name-mangles
        # to _TestBatchNormOpInference__assert_close — but that helper was
        # defined only on the training test class, so the call raised
        # AttributeError. Define the helper on this class.
        self.assertTrue(np.allclose(np.array(a), b, atol=atol), msg)

    def test_python(self):
        epsilon = 0.00001

        n, h, w, c = 2, 3, 4, 5
        x_shape = [n, h, w, c]
        scale_shape = [c]

        x_val = np.random.random_sample(x_shape).astype(self.dtype)
        scale_val = np.random.random_sample(scale_shape).astype(self.dtype)
        bias_val = np.random.random_sample(scale_shape).astype(self.dtype)

        # Inference mode: zero mean, unit variance running statistics.
        mean = np.zeros(scale_shape).astype(self.dtype)
        variance = np.ones(scale_shape).astype(self.dtype)

        # Forward pass in NHWC layout.
        y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
                                   epsilon, "NHWC")

        # The same data transposed to NCHW must produce identical results.
        x_val2 = np.transpose(x_val, (0, 3, 1, 2))
        y_out2 = _reference_testing(x_val2, scale_val, bias_val, mean,
                                    variance, epsilon, "NCHW")

        # Transpose (N, C, H, W) back to (N, H, W, C) for comparison.
        y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
        self.__assert_close(y_out, y_out2_trans, "inference output")
        print('python: NHWC, NCHW, inference checking passed')
class TestBatchNormOpTraining(OpTest):
159226
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
160227
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
161228

0 commit comments

Comments
 (0)