Skip to content

Commit ad91bfe

Browse files
luotao1qingqing01
authored andcommitted
fix a bug in test_batch_norm_op.py (#10094)
1 parent eb8e14c commit ad91bfe

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

python/paddle/fluid/tests/unittests/test_batch_norm_op.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ def _reference_grad(x, y_grad, scale, mean, var, epsilon, data_format):
100100
# (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
101101

102102
# transfer from (N, C, H, W) to (N, H, W, C) to simplify computation
103+
if data_format != "NCHW" and data_format != "NHWC":
104+
raise ValueError("Unknown data order.")
105+
103106
if data_format == "NCHW":
104107
x = np.transpose(x, (0, 2, 3, 1))
105108
y_grad = np.transpose(y_grad, (0, 2, 3, 1))
@@ -304,7 +307,7 @@ def test_with_place(place, data_layout, shape):
304307
# run backward
305308
y_grad = np.random.random_sample(shape).astype(np.float32)
306309
x_grad, scale_grad, bias_grad = _reference_grad(
307-
x, y_grad, scale, saved_mean, var_ref, epsilon, data_format)
310+
x, y_grad, scale, saved_mean, var_ref, epsilon, data_layout)
308311

309312
var_dict = locals()
310313
var_dict['y@GRAD'] = y_grad

0 commit comments

Comments
 (0)