@@ -987,10 +987,9 @@ PD_REGISTER_KERNEL(batch_norm_grad,
987
987
double ,
988
988
phi::dtype::float16) {
989
989
if (kernel_key.dtype () == phi::DataType::FLOAT16) {
990
- kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32);
991
- kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32);
992
- kernel->OutputAt (3 ).SetDataType (phi::DataType::FLOAT32);
993
- kernel->OutputAt (4 ).SetDataType (phi::DataType::FLOAT32);
990
+ kernel->OutputAt (0 ).SetDataType (phi::DataType::FLOAT32); // x_grad
991
+ kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32); // scale_grad
992
+ kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32); // bias_grad
994
993
}
995
994
}
996
995
@@ -1002,10 +1001,9 @@ PD_REGISTER_KERNEL(batch_norm_grad_raw,
1002
1001
double ,
1003
1002
phi::dtype::float16) {
1004
1003
if (kernel_key.dtype () == phi::DataType::FLOAT16) {
1005
- kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32);
1006
- kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32);
1007
- kernel->OutputAt (3 ).SetDataType (phi::DataType::FLOAT32);
1008
- kernel->OutputAt (4 ).SetDataType (phi::DataType::FLOAT32);
1004
+ kernel->OutputAt (0 ).SetDataType (phi::DataType::FLOAT32); // x_grad
1005
+ kernel->OutputAt (1 ).SetDataType (phi::DataType::FLOAT32); // scale_grad
1006
+ kernel->OutputAt (2 ).SetDataType (phi::DataType::FLOAT32); // bias_grad
1009
1007
}
1010
1008
}
1011
1009
@@ -1018,7 +1016,6 @@ PD_REGISTER_KERNEL(batch_norm_grad_grad,
1018
1016
phi::BatchNormDoubleGradKernel,
1019
1017
float ,
1020
1018
double ) {}
1021
-
1022
1019
#else
1023
1020
PD_REGISTER_KERNEL (batch_norm_grad_grad,
1024
1021
GPU,
0 commit comments