Skip to content

Commit f63ff90

Browse files
authored
Fix/fp64 (#10346)
* "fix double type error" * "fix ci" * "softmax fp64" * "fix momentum" * "fix ci"
1 parent 1ae086e commit f63ff90

File tree

6 files changed

+24
-12
lines changed

6 files changed

+24
-12
lines changed

paddle/fluid/operators/momentum_op.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ limitations under the License. */
1717
namespace paddle {
1818
namespace operators {
1919

20+
using Tensor = framework::Tensor;
21+
2022
class MomentumOp : public framework::OperatorWithKernel {
2123
public:
2224
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -50,6 +52,12 @@ class MomentumOp : public framework::OperatorWithKernel {
5052
ctx->SetOutputDim("ParamOut", param_dim);
5153
ctx->SetOutputDim("VelocityOut", param_dim);
5254
}
55+
framework::OpKernelType GetExpectedKernelType(
56+
const framework::ExecutionContext &ctx) const override {
57+
auto input_data_type =
58+
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
59+
return framework::OpKernelType(input_data_type, ctx.GetPlace());
60+
}
5361
};
5462

5563
class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {

paddle/fluid/operators/scale_op.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ class ScaleOp : public framework::OperatorWithKernel {
3535
}
3636
};
3737

38-
template <typename AttrType>
3938
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
4039
public:
4140
ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
@@ -47,9 +46,9 @@ Scale operator
4746
4847
$$Out = scale*X$$
4948
)DOC");
50-
AddAttr<AttrType>("scale",
51-
"(float, default 1.0)"
52-
"The scaling factor of the scale operator.")
49+
AddAttr<float>("scale",
50+
"(float, default 1.0)"
51+
"The scaling factor of the scale operator.")
5352
.SetDefault(1.0);
5453
}
5554
};
@@ -73,8 +72,7 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
7372

7473
namespace ops = paddle::operators;
7574

76-
REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker<float>,
77-
ops::ScaleGradMaker);
75+
REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker);
7876
REGISTER_OP_CPU_KERNEL(
7977
scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
8078
ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,

paddle/fluid/operators/softmax_op.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,9 @@ REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
164164
paddle::framework::DefaultGradOpDescMaker<true>);
165165
REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad);
166166
REGISTER_OP_CPU_KERNEL(
167-
softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>);
167+
softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
168+
ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
168169
REGISTER_OP_CPU_KERNEL(
169170
softmax_grad,
170-
ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>);
171+
ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
172+
ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);

paddle/fluid/operators/softmax_op.cu.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ namespace ops = paddle::operators;
1919
namespace plat = paddle::platform;
2020
REGISTER_OP_CUDA_KERNEL(
2121
softmax, ops::SoftmaxKernel<plat::CUDADeviceContext, float>,
22+
ops::SoftmaxKernel<plat::CUDADeviceContext, double>,
2223
ops::SoftmaxKernel<plat::CUDADeviceContext, plat::float16>);
23-
REGISTER_OP_CUDA_KERNEL(softmax_grad,
24-
ops::SoftmaxGradKernel<plat::CUDADeviceContext, float>);
24+
REGISTER_OP_CUDA_KERNEL(
25+
softmax_grad, ops::SoftmaxGradKernel<plat::CUDADeviceContext, float>,
26+
ops::SoftmaxGradKernel<plat::CUDADeviceContext, double>);

paddle/fluid/operators/top_k_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,5 @@ namespace ops = paddle::operators;
7575
REGISTER_OPERATOR(top_k, ops::TopkOp, ops::TopkOpMaker,
7676
paddle::framework::EmptyGradOpMaker);
7777
REGISTER_OP_CPU_KERNEL(top_k,
78-
ops::TopkKernel<paddle::platform::CPUPlace, float>);
78+
ops::TopkKernel<paddle::platform::CPUPlace, float>,
79+
ops::TopkKernel<paddle::platform::CPUPlace, double>);

paddle/fluid/operators/top_k_op.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,4 +318,5 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
318318
} // namespace operators
319319
} // namespace paddle
320320

321-
REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>);
321+
REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>,
322+
paddle::operators::TopkOpCUDAKernel<double>);

0 commit comments

Comments
 (0)