Skip to content

Commit a28dffb

Browse files
authored
Fix/adam float64 (#10407)
* "optimizer op support float64" * "fix ci" * "fix ftrl op"
1 parent 6418c42 commit a28dffb

File tree

8 files changed

+56
-0
lines changed

8 files changed

+56
-0
lines changed

paddle/fluid/operators/adadelta_op.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
namespace paddle {
1818
namespace operators {
1919

20+
using Tensor = framework::Tensor;
2021
class AdadeltaOp : public framework::OperatorWithKernel {
2122
public:
2223
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -55,6 +56,12 @@ class AdadeltaOp : public framework::OperatorWithKernel {
5556
ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
5657
ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
5758
}
59+
framework::OpKernelType GetExpectedKernelType(
60+
const framework::ExecutionContext &ctx) const override {
61+
auto input_data_type =
62+
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
63+
return framework::OpKernelType(input_data_type, ctx.GetPlace());
64+
}
5865
};
5966

6067
class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {

paddle/fluid/operators/adagrad_op.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ limitations under the License. */
2323
namespace paddle {
2424
namespace operators {
2525

26+
using Tensor = framework::Tensor;
2627
class AdagradOp : public framework::OperatorWithKernel {
2728
public:
2829
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -56,6 +57,12 @@ class AdagradOp : public framework::OperatorWithKernel {
5657
ctx->SetOutputDim("ParamOut", param_dims);
5758
ctx->SetOutputDim("MomentOut", param_dims);
5859
}
60+
framework::OpKernelType GetExpectedKernelType(
61+
const framework::ExecutionContext& ctx) const override {
62+
auto input_data_type =
63+
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
64+
return framework::OpKernelType(input_data_type, ctx.GetPlace());
65+
}
5966
};
6067

6168
class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {

paddle/fluid/operators/adam_op.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
namespace paddle {
1818
namespace operators {
1919

20+
using Tensor = framework::Tensor;
2021
class AdamOp : public framework::OperatorWithKernel {
2122
public:
2223
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -69,6 +70,12 @@ class AdamOp : public framework::OperatorWithKernel {
6970
ctx->SetOutputDim("Moment1Out", param_dims);
7071
ctx->SetOutputDim("Moment2Out", param_dims);
7172
}
73+
framework::OpKernelType GetExpectedKernelType(
74+
const framework::ExecutionContext &ctx) const override {
75+
auto input_data_type =
76+
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
77+
return framework::OpKernelType(input_data_type, ctx.GetPlace());
78+
}
7279
};
7380

7481
class AdamOpMaker : public framework::OpProtoAndCheckerMaker {

paddle/fluid/operators/adamax_op.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
namespace paddle {
1818
namespace operators {
1919

20+
using Tensor = framework::Tensor;
2021
class AdamaxOp : public framework::OperatorWithKernel {
2122
public:
2223
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -63,6 +64,12 @@ class AdamaxOp : public framework::OperatorWithKernel {
6364
ctx->SetOutputDim("MomentOut", param_dims);
6465
ctx->SetOutputDim("InfNormOut", param_dims);
6566
}
67+
framework::OpKernelType GetExpectedKernelType(
68+
const framework::ExecutionContext &ctx) const override {
69+
auto input_data_type =
70+
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
71+
return framework::OpKernelType(input_data_type, ctx.GetPlace());
72+
}
6673
};
6774

6875
class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {

paddle/fluid/operators/decayed_adagrad_op.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
namespace paddle {
1818
namespace operators {
1919

20+
using Tensor = framework::Tensor;
2021
class DecayedAdagradOp : public framework::OperatorWithKernel {
2122
public:
2223
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -51,6 +52,12 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
5152
ctx->SetOutputDim("ParamOut", param_dims);
5253
ctx->SetOutputDim("MomentOut", param_dims);
5354
}
55+
framework::OpKernelType GetExpectedKernelType(
56+
const framework::ExecutionContext &ctx) const override {
57+
auto input_data_type =
58+
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
59+
return framework::OpKernelType(input_data_type, ctx.GetPlace());
60+
}
5461
};
5562

5663
class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker {

paddle/fluid/operators/ftrl_op.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
namespace paddle {
1818
namespace operators {
1919

20+
using Tensor = framework::Tensor;
2021
class FTRLOp : public framework::OperatorWithKernel {
2122
public:
2223
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -53,6 +54,12 @@ class FTRLOp : public framework::OperatorWithKernel {
5354
ctx->SetOutputDim("SquaredAccumOut", param_dim);
5455
ctx->SetOutputDim("LinearAccumOut", param_dim);
5556
}
57+
framework::OpKernelType GetExpectedKernelType(
58+
const framework::ExecutionContext &ctx) const override {
59+
auto input_data_type =
60+
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
61+
return framework::OpKernelType(input_data_type, ctx.GetPlace());
62+
}
5663
};
5764

5865
class FTRLOpMaker : public framework::OpProtoAndCheckerMaker {

paddle/fluid/operators/proximal_adagrad_op.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
namespace paddle {
1818
namespace operators {
1919

20+
using Tensor = framework::Tensor;
2021
class ProximalAdagradOp : public framework::OperatorWithKernel {
2122
public:
2223
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -55,6 +56,12 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
5556
ctx->SetOutputDim("ParamOut", param_dim);
5657
ctx->SetOutputDim("MomentOut", param_dim);
5758
}
59+
framework::OpKernelType GetExpectedKernelType(
60+
const framework::ExecutionContext &ctx) const override {
61+
auto input_data_type =
62+
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
63+
return framework::OpKernelType(input_data_type, ctx.GetPlace());
64+
}
5865
};
5966

6067
class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker {

paddle/fluid/operators/proximal_gd_op.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
namespace paddle {
1818
namespace operators {
1919

20+
using Tensor = framework::Tensor;
2021
class ProximalGDOp : public framework::OperatorWithKernel {
2122
public:
2223
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -43,6 +44,12 @@ class ProximalGDOp : public framework::OperatorWithKernel {
4344

4445
ctx->SetOutputDim("ParamOut", param_dim);
4546
}
47+
framework::OpKernelType GetExpectedKernelType(
48+
const framework::ExecutionContext &ctx) const override {
49+
auto input_data_type =
50+
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
51+
return framework::OpKernelType(input_data_type, ctx.GetPlace());
52+
}
4653
};
4754

4855
class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker {

0 commit comments

Comments
 (0)