Skip to content

Commit 151cfff

Browse files
committed
add more tests
1 parent 0a95a44 commit 151cfff

File tree

1 file changed

+80
-14
lines changed

1 file changed

+80
-14
lines changed

python/paddle/fluid/tests/unittests/test_batch_norm_op.py

Lines changed: 80 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -188,14 +188,27 @@ def __set_tensor__(name, data=None):
188188

189189
class TestBatchNormOpInference(OpTest):
190190
def setUp(self):
191+
self.op_type = "conv2d"
192+
self.is_test = True
191193
self.dtype = np.float32
194+
self.data_layout = "NCHW"
195+
init_dtype()
196+
init_data_layout()
197+
init_test_case()
192198

193-
def test_python(self):
194-
data_format = "NHWC"
195199
epsilon = 0.00001
196-
197-
n, h, w, c = 2, 3, 4, 5
198-
x_shape = [n, h, w, c]
200+
shape = self.shape
201+
if len(shape) == 2:
202+
x_shape = shape
203+
c = x_shape[1]
204+
else:
205+
n, h, w, c = shape[0], shape[1], shape[2], shape[3]
206+
if self.data_layout == "NHWC":
207+
x_shape = [n, h, w, c]
208+
elif self.data_layout == "NCHW":
209+
x_shape = [n, c, h, w]
210+
else:
211+
raise ValueError("Unknown data layout.")
199212
scale_shape = [c]
200213

201214
x_val = np.random.random_sample(x_shape).astype(self.dtype)
@@ -205,7 +218,64 @@ def test_python(self):
205218
mean = np.zeros(scale_shape).astype(self.dtype)
206219
variance = np.ones(scale_shape).astype(self.dtype)
207220

208-
# run forward
221+
saved_mean = np.zeros(scale_shape).astype(self.dtype)
222+
saved_variance = np.ones(scale_shape).astype(self.dtype)
223+
224+
y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
225+
epsilon, self.data_layout).astype(self.dtype)
226+
227+
self.inputs = {
228+
'X': OpTest.np_dtype_to_fluid_dtype(x_val),
229+
'Scale': OpTest.np_dtype_to_fluid_dtype(scale_val),
230+
'Bias': OpTest.np_dtype_to_fluid_dtype(bias_val),
231+
'Mean': OpTest.np_dtype_to_fluid_dtype(mean),
232+
'Variance': OpTest.np_dtype_to_fluid_dtype(variance)
233+
}
234+
self.attrs = {
235+
'is_test': self.is_test,
236+
'epsilon': epsilon,
237+
'data_layout': self.data_layout
238+
}
239+
self.outputs = {
240+
'Y': y_out,
241+
'MeanOut': mean,
242+
'VarianceOut': variance,
243+
'SavedMean': saved_mean,
244+
'SavedVariance': saved_variance
245+
}
246+
247+
def test_check_output(self):
248+
self.check_output()
249+
250+
def init_dtype(self):
251+
pass
252+
253+
def init_data_layout(self):
254+
pass
255+
256+
def init_test_case(self):
257+
self.shape = [2, 3, 4, 5]
258+
259+
260+
class TestBatchNormOpTraining(OpTest):
261+
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
262+
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
263+
264+
def test_python_testing(self):
265+
data_format = "NHWC"
266+
epsilon = 0.00001
267+
268+
n, h, w, c = 2, 3, 4, 5
269+
x_shape = [n, h, w, c]
270+
scale_shape = [c]
271+
272+
x_val = np.random.random_sample(x_shape).astype(np.float32)
273+
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
274+
bias_val = np.random.random_sample(scale_shape).astype(np.float32)
275+
276+
mean = np.zeros(scale_shape).astype(np.float32)
277+
variance = np.ones(scale_shape).astype(np.float32)
278+
209279
y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
210280
epsilon, "NHWC")
211281

@@ -218,15 +288,11 @@ def test_python(self):
218288

219289
# transfer (N, C, H, W) back to (N, H, W, C)
220290
y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
221-
self.__assert_close(y_out, y_out2_trans, "inference output")
291+
self.__assert_close(y_out, y_out2_trans,
292+
"inference outputs of two formats have differences")
222293
print 'python: NHWC, NCHW, inference checking passed'
223294

224-
225-
class TestBatchNormOpTraining(OpTest):
226-
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
227-
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
228-
229-
def test_python(self):
295+
def test_python_training(self):
230296
data_format = "NHWC"
231297
epsilon = 0.00001
232298
momentum = 0.9
@@ -264,7 +330,7 @@ def test_python(self):
264330

265331
# transfer (N, C, H, W) back to (N, H, W, C)
266332
y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
267-
self.__assert_close(y_out, y_out2_trans, "batch variance")
333+
self.__assert_close(y_out, y_out2_trans, "batch output")
268334
print 'python: NHWC, NCHW, forward checking passed'
269335

270336
# test backward now

0 commit comments

Comments
 (0)