Skip to content

Commit a574586

Browse files
[cherry-pick 2.3] fix bug of batch_norm_grad kernel with fp16 (#42461)
* fix bug of batch_norm_grad kernel with fp16 * format code
1 parent 87e6149 commit a574586

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

paddle/phi/kernels/gpu/batch_norm_grad_kernel.cu

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -987,10 +987,9 @@ PD_REGISTER_KERNEL(batch_norm_grad,
987987
double,
988988
phi::dtype::float16) {
989989
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
990-
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
991-
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
992-
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
993-
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
990+
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
991+
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad
992+
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
994993
}
995994
}
996995

@@ -1002,10 +1001,9 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw,
10021001
double,
10031002
phi::dtype::float16) {
10041003
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
1005-
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
1006-
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
1007-
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
1008-
kernel->OutputAt(4).SetDataType(phi::DataType::FLOAT32);
1004+
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32); // x_grad
1005+
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32); // scale_grad
1006+
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32); // bias_grad
10091007
}
10101008
}
10111009

@@ -1018,7 +1016,6 @@ PD_REGISTER_KERNEL(batch_norm_grad_grad,
10181016
phi::BatchNormDoubleGradKernel,
10191017
float,
10201018
double) {}
1021-
10221019
#else
10231020
PD_REGISTER_KERNEL(batch_norm_grad_grad,
10241021
GPU,

0 commit comments

Comments
 (0)