@@ -12,41 +12,125 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+ #include " paddle/fluid/operators/optimizers/adam_op.h"
15
16
#include " paddle/fluid/framework/op_version_registry.h"
16
-
17
- #include " paddle/fluid/framework/infershape_utils.h"
18
- #include " paddle/fluid/framework/op_registry.h"
19
- #include " paddle/phi/core/infermeta_utils.h"
20
- #include " paddle/phi/infermeta/multiary.h"
17
+ #include " paddle/fluid/operators/optimizers/adamw_op.h"
21
18
22
19
namespace paddle {
23
20
namespace operators {
24
21
25
22
using Tensor = framework::Tensor;
26
23
27
- class AdamOp : public framework ::OperatorWithKernel {
28
- public:
29
- using framework::OperatorWithKernel::OperatorWithKernel;
24
+void AdamOp::InferShape(framework::InferShapeContext *ctx) const {
+  PADDLE_ENFORCE_EQ(
+      ctx->HasInput("Param"), true,
+      platform::errors::NotFound("Input(Param) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(
+      ctx->HasInput("Grad"), true,
+      platform::errors::NotFound("Input(Grad) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("Moment1"), true,
+                    platform::errors::NotFound(
+                        "Input(Moment1) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("Moment2"), true,
+                    platform::errors::NotFound(
+                        "Input(Moment2) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("LearningRate"), true,
+                    platform::errors::NotFound(
+                        "Input(LearningRate) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("Beta1Pow"), true,
+                    platform::errors::NotFound(
+                        "Input(Beta1Pow) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("Beta2Pow"), true,
+                    platform::errors::NotFound(
+                        "Input(Beta2Pow) of AdamOp should not be null."));
+
+  PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
+                    platform::errors::NotFound(
+                        "Output(ParamOut) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment1Out"), true,
+                    platform::errors::NotFound(
+                        "Output(Moment1Out) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment2Out"), true,
+                    platform::errors::NotFound(
+                        "Output(Moment2Out) of AdamOp should not be null."));
 
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const {
-    auto input_data_type =
-        OperatorWithKernel::IndicateVarDataType(ctx, "Param");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  auto lr_dims = ctx->GetInputDim("LearningRate");
+  PADDLE_ENFORCE_NE(
+      phi::product(lr_dims), 0,
+      platform::errors::InvalidArgument(
+          "The number of LearningRate shall not be 0, but received %d. Maybe "
+          "the Input variable LearningRate has not "
+          "been initialized. You may need to confirm "
+          "if you put exe.run(startup_program) "
+          "after optimizer.minimize function.",
+          phi::product(lr_dims)));
+  PADDLE_ENFORCE_EQ(
+      phi::product(lr_dims), 1,
+      platform::errors::InvalidArgument(
+          "Learning rate should have 1 dimension, but received %d",
+          phi::product(lr_dims)));
+  auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
+  VLOG(3) << "dims of Beta1Pow : [" << beta1_pow_dims << "]";
+  PADDLE_ENFORCE_GE(phi::product(beta1_pow_dims), 1,
+                    platform::errors::InvalidArgument(
+                        "The size of Beta1 power accumulator should be greater "
+                        "than 0, but received %d.",
+                        phi::product(beta1_pow_dims)));
+  auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
+  VLOG(3) << "dims of Beta2Pow : [" << beta2_pow_dims << "]";
+  PADDLE_ENFORCE_GE(phi::product(beta2_pow_dims), 1,
+                    platform::errors::InvalidArgument(
+                        "The size of Beta2 power accumulator should be greater "
+                        "than 0, but received %d.",
+                        phi::product(beta2_pow_dims)));
+
+  auto param_dims = ctx->GetInputDim("Param");
+  if (ctx->GetInputsVarType("Grad")[0] ==
+      framework::proto::VarType::LOD_TENSOR) {
+    PADDLE_ENFORCE_EQ(
+        param_dims, ctx->GetInputDim("Grad"),
+        platform::errors::InvalidArgument(
+            "Param and Grad input of AdamOp should have same dimension. But "
+            "received Param dims: [%s], Grad dims: [%s].",
+            param_dims, ctx->GetInputDim("Grad")));
   }
+  PADDLE_ENFORCE_EQ(
+      param_dims, ctx->GetInputDim("Moment1"),
+      platform::errors::InvalidArgument(
+          "Param and Moment1 input of AdamOp should have same dimension. But "
+          "received Param dims: [%s], Moment1 dims: [%s].",
+          param_dims, ctx->GetInputDim("Moment1")));
+  PADDLE_ENFORCE_EQ(
+      param_dims, ctx->GetInputDim("Moment2"),
+      platform::errors::InvalidArgument(
+          "Param and Moment2 input of AdamOp should have same dimension. But "
+          "received Param dims: [%s], Moment2 dims: [%s].",
+          param_dims, ctx->GetInputDim("Moment2")));
+
+  ctx->SetOutputDim("ParamOut", param_dims);
+  ctx->SetOutputDim("Moment1Out", param_dims);
+  ctx->SetOutputDim("Moment2Out", param_dims);
+  ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims);
+  ctx->SetOutputDim("Beta2PowOut", beta2_pow_dims);
+}
 
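The restored InferShape runs at graph-build time: it verifies that every input and output the Adam update needs is wired up, that LearningRate holds exactly one element, and that Grad/Moment1/Moment2 match Param's shape before propagating that shape to the outputs. As a minimal standalone sketch of the scalar learning-rate check, with a hypothetical Product helper over plain std::vector standing in for phi::product over a DDim:

    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Hypothetical stand-in for phi::product(lr_dims); not Paddle API.
    int64_t Product(const std::vector<int64_t>& dims) {
      return std::accumulate(dims.begin(), dims.end(), int64_t{1},
                             std::multiplies<int64_t>());
    }

    int main() {
      std::vector<int64_t> lr_dims{1};  // LearningRate holds one scalar
      assert(Product(lr_dims) != 0);    // mirrors the PADDLE_ENFORCE_NE above
      assert(Product(lr_dims) == 1);    // mirrors the PADDLE_ENFORCE_EQ above
      return 0;
    }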
-  framework::OpKernelType GetKernelTypeForVar(
-      const std::string &var_name, const framework::Tensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const {
-    if (var_name == "Beta1Pow" || var_name == "Beta2Pow" ||
-        var_name == "SkipUpdate") {
-      return expected_kernel_type;
-    } else {
-      return framework::OpKernelType(expected_kernel_type.data_type_,
-                                     tensor.place(), tensor.layout());
-    }
+framework::OpKernelType AdamOp::GetExpectedKernelType(
+    const framework::ExecutionContext &ctx) const {
+  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
+  return framework::OpKernelType(input_data_type, ctx.GetPlace());
+}
+
+framework::OpKernelType AdamOp::GetKernelTypeForVar(
+    const std::string &var_name, const framework::Tensor &tensor,
+    const framework::OpKernelType &expected_kernel_type) const {
+  if (var_name == "Beta1Pow" || var_name == "Beta2Pow" ||
+      var_name == "SkipUpdate") {
+    return expected_kernel_type;
+  } else {
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(), tensor.layout());
   }
-};
+}
 
 class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
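A note on the special case in GetKernelTypeForVar: Beta1Pow, Beta2Pow and SkipUpdate are tiny scalar tensors that Paddle allows to stay CPU-resident even when the Adam kernel itself runs on a GPU, so returning expected_kernel_type unchanged tells the framework not to force a device/layout transform for them; every other input is converted to the kernel's place and layout. A condensed sketch of that dispatch decision (illustrative only, not the framework API):

    #include <string>
    #include <unordered_set>

    // Illustrative only: which Adam inputs skip the place/layout transform.
    bool SkipsTransform(const std::string& var_name) {
      static const std::unordered_set<std::string> cpu_resident{
          "Beta1Pow", "Beta2Pow", "SkipUpdate"};
      return cpu_resident.count(var_name) > 0;
    }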
@@ -148,10 +232,6 @@ param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsilon}
   }
 };
 
-class AdamWOp : public AdamOp {
-  using AdamOp::AdamOp;
-};
-
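The empty AdamWOp wrapper disappears from this file; given the adamw_op.h include added at the top of the diff, its definition presumably moves into that header. A presumed shape of the relocated declaration (an assumption inferred from the include, not shown in this diff):

    // Assumed contents of adamw_op.h (inferred, not confirmed here):
    class AdamWOp : public AdamOp {
      using AdamOp::AdamOp;
    };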
 class AdamWOpMaker : public AdamOpMaker {
  public:
   void Make() {
@@ -175,23 +255,13 @@ class AdamWOpMaker : public AdamOpMaker {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker);
+
+REGISTER_OP_WITHOUT_GRADIENT(adamw, ops::AdamWOp, ops::AdamWOpMaker);
 
-DECLARE_INFER_SHAPE_FUNCTOR(adam, AdamInferMetaFunctor,
-                            PD_INFER_META(phi::AdamInferMeta));
-
-REGISTER_OPERATOR(
-    adam, ops::AdamOp, ops::AdamOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    AdamInferMetaFunctor);
-
-DECLARE_INFER_SHAPE_FUNCTOR(adamw, AdamwInferMetaFunctor,
-                            PD_INFER_META(phi::AdamwInferMeta));
-REGISTER_OPERATOR(
-    adamw, ops::AdamWOp, ops::AdamWOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    AdamwInferMetaFunctor);
+REGISTER_OP_CPU_KERNEL(
+    adam, ops::AdamOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::AdamOpKernel<paddle::platform::CPUDeviceContext, double>);
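The patch swaps the phi-based registration (DECLARE_INFER_SHAPE_FUNCTOR plus REGISTER_OPERATOR bound to phi::AdamInferMeta) back to the classic pair: REGISTER_OP_WITHOUT_GRADIENT registers the op and its maker with no gradient op (optimizers have none), and REGISTER_OP_CPU_KERNEL binds the float and double CPU kernels declared in adam_op.h. For reference, the update those kernels perform, per the formula in the op's comment block (a self-contained plain-C++ sketch, not the Paddle kernel):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Sketch of one Adam step over flat buffers; the real AdamOpKernel in
    // adam_op.h handles tensors, selected rows, and more options.
    template <typename T>
    void AdamUpdate(std::vector<T>* param, const std::vector<T>& grad,
                    std::vector<T>* moment1, std::vector<T>* moment2, T lr,
                    T beta1, T beta2, T epsilon, T beta1_pow, T beta2_pow) {
      // Bias-corrected learning rate, as in the op's documentation.
      T lr_t = lr * std::sqrt(1 - beta2_pow) / (1 - beta1_pow);
      for (std::size_t i = 0; i < param->size(); ++i) {
        (*moment1)[i] = beta1 * (*moment1)[i] + (1 - beta1) * grad[i];
        (*moment2)[i] = beta2 * (*moment2)[i] + (1 - beta2) * grad[i] * grad[i];
        (*param)[i] -=
            lr_t * (*moment1)[i] / (std::sqrt((*moment2)[i]) + epsilon);
      }
    }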
 
 REGISTER_OP_VERSION(adam)
     .AddCheckpoint(