
Commit 57be5c6

"fix double type error" (#10322)
* "fix double type error" * "fix ci"
1 parent faebadd commit 57be5c6

4 files changed: +20 −9 lines changed


paddle/fluid/operators/batch_norm_op.cc

Lines changed: 10 additions & 5 deletions
@@ -87,9 +87,13 @@ class BatchNormOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type =
         framework::ToDataType(ctx.Input<Tensor>("X")->type());
-    // For float or float16 input tensor, the type of the scale, bias, mean,
-    // and var tensors should both be float.
+    // By default, the type of the scale, bias, mean,
+    // and var tensors should be float (for float or float16 input tensor)
+    // or double (for double input tensor).
     auto bn_param_type = framework::proto::VarType::FP32;
+    if (input_data_type == framework::proto::VarType::FP64) {
+      bn_param_type = framework::proto::VarType::FP64;
+    }
     PADDLE_ENFORCE_EQ(bn_param_type,
                       framework::ToDataType(ctx.Input<Tensor>("Scale")->type()),
                       "Scale input should be of float type");
@@ -492,8 +496,9 @@ REGISTER_OPERATOR(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
 REGISTER_OPERATOR(batch_norm_grad, ops::BatchNormGradOp);
 
 REGISTER_OP_CPU_KERNEL(
-    batch_norm,
-    ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>);
+    batch_norm, ops::BatchNormKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::BatchNormKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     batch_norm_grad,
-    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::BatchNormGradKernel<paddle::platform::CPUDeviceContext, double>);
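The new FP64 branch only decides which parameter dtype the enforce check expects; the run-time dispatch itself goes through the kernel registry that REGISTER_OP_CPU_KERNEL populates. The following is a hypothetical, framework-agnostic sketch (the g_registry map and the Register/Run helpers are made up for illustration, not Paddle internals) of why a missing double instantiation surfaces as the "double type error" this commit fixes: kernels are looked up by (op name, dtype), and a dtype that was never registered has nothing to run.

// Hypothetical sketch, not Paddle's actual registry: kernels are keyed by
// (op name, element dtype), mirroring what REGISTER_OP_CPU_KERNEL builds up.
#include <functional>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

enum class DType { FP32, FP64 };

using Kernel = std::function<void()>;
static std::map<std::pair<std::string, DType>, Kernel> g_registry;

void Register(const std::string &op, DType t, Kernel k) {
  g_registry[{op, t}] = std::move(k);
}

void Run(const std::string &op, DType t) {
  auto it = g_registry.find({op, t});
  if (it == g_registry.end()) {
    // The situation this commit fixes: a float64 input asks for a kernel
    // instantiation that was never registered.
    throw std::runtime_error(op + ": no kernel registered for this dtype");
  }
  it->second();
}

int main() {
  Register("batch_norm", DType::FP32, [] { std::cout << "float kernel\n"; });
  Register("batch_norm", DType::FP64, [] { std::cout << "double kernel\n"; });
  Run("batch_norm", DType::FP64);  // prints "double kernel"
  return 0;
}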

paddle/fluid/operators/batch_norm_op.cu.cc

Lines changed: 3 additions & 1 deletion
@@ -287,6 +287,8 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
+    ops::BatchNormKernel<plat::CUDADeviceContext, double>,
     ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(
-    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>);
+    batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>,
+    ops::BatchNormGradKernel<plat::CUDADeviceContext, double>);

paddle/fluid/operators/mul_op.cc

Lines changed: 4 additions & 2 deletions
@@ -204,6 +204,8 @@ REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(mul_grad, ops::MulGradOp);
 REGISTER_OP_CPU_KERNEL(
-    mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>);
+    mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MulKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
-    mul_grad, ops::MulGradKernel<paddle::platform::CPUDeviceContext, float>);
+    mul_grad, ops::MulGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MulGradKernel<paddle::platform::CPUDeviceContext, double>);

paddle/fluid/operators/mul_op.cu.cc

Lines changed: 3 additions & 1 deletion
@@ -18,6 +18,8 @@ limitations under the License. */
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(mul, ops::MulKernel<plat::CUDADeviceContext, float>,
+                        ops::MulKernel<plat::CUDADeviceContext, double>,
                         ops::MulKernel<plat::CUDADeviceContext, plat::float16>);
 REGISTER_OP_CUDA_KERNEL(mul_grad,
-                        ops::MulGradKernel<plat::CUDADeviceContext, float>);
+                        ops::MulGradKernel<plat::CUDADeviceContext, float>,
+                        ops::MulGradKernel<plat::CUDADeviceContext, double>);
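The registered kernels are class templates parameterized on <DeviceContext, T>, so double support is not automatic: each element type used at run time needs its own explicit instantiation, which is exactly what the extra arguments to REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL provide across these four files. Below is a minimal sketch of that pattern; the ScaleKernel name, its computation, and the plain CPUDeviceContext struct are illustrative stand-ins, not Paddle's actual classes.

#include <cstddef>
#include <iostream>
#include <vector>

struct CPUDeviceContext {};  // stand-in for paddle::platform::CPUDeviceContext

// Illustrative kernel template in the same <DeviceContext, T> shape as
// ops::MulKernel / ops::BatchNormKernel; the computation here is made up.
template <typename DeviceContext, typename T>
struct ScaleKernel {
  void Compute(const std::vector<T> &x, T factor, std::vector<T> *out) const {
    out->resize(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) (*out)[i] = x[i] * factor;
  }
};

int main() {
  // Listing only the float variant in a registration macro would leave this
  // double variant uninstantiated; naming both types is what lets float64
  // inputs find a kernel.
  ScaleKernel<CPUDeviceContext, double> kernel;
  std::vector<double> x{1.0, 2.5, 4.0}, out;
  kernel.Compute(x, 2.0, &out);
  for (double v : out) std::cout << v << " ";  // 2 5 8
  std::cout << "\n";
  return 0;
}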

0 commit comments

Comments
 (0)