
Commit 8663376

Revert "[Phi] Migrate Adam and AdamW into Phi (#40351)" (#41712)
* Revert "[Phi] Migrate Adam and AdamW into Phi (#40351)"
  This reverts commit 56cd340.
* add infermeta
1 parent 79ceef9 commit 8663376

32 files changed: +1916 −3036 lines

paddle/fluid/framework/operator.cc

Lines changed: 2 additions & 0 deletions

@@ -58,6 +58,8 @@ class DenseTensor;
 DECLARE_bool(benchmark);
 DECLARE_bool(check_nan_inf);
 DECLARE_bool(enable_unused_var_check);
+PADDLE_DEFINE_EXPORTED_int32(inner_op_parallelism, 0,
+                             "number of threads for inner op");
 DECLARE_bool(run_kp_kernel);
 DECLARE_bool(enable_host_event_recorder_hook);
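Note: PADDLE_DEFINE_EXPORTED_int32 makes inner_op_parallelism an exported runtime flag rather than a compile-time constant. A minimal sketch of how such an exported flag is typically set from Python, assuming the usual FLAGS_* environment-variable convention (the value 4 is only an example):

import os

# Exported Paddle flags can generally be picked up from FLAGS_* environment
# variables; set the variable before importing paddle so it takes effect.
os.environ["FLAGS_inner_op_parallelism"] = "4"  # example: 4 inner-op threads

import paddle  # imported after the flag is set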

paddle/fluid/operators/optimizers/adam_op.cc

Lines changed: 114 additions & 44 deletions

@@ -12,41 +12,125 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include "paddle/fluid/operators/optimizers/adam_op.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/core/infermeta_utils.h"
-#include "paddle/phi/infermeta/multiary.h"
+#include "paddle/fluid/operators/optimizers/adamw_op.h"

 namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;

-class AdamOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
+void AdamOp::InferShape(framework::InferShapeContext *ctx) const {
+  PADDLE_ENFORCE_EQ(
+      ctx->HasInput("Param"), true,
+      platform::errors::NotFound("Input(Param) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(
+      ctx->HasInput("Grad"), true,
+      platform::errors::NotFound("Input(Grad) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("Moment1"), true,
+                    platform::errors::NotFound(
+                        "Input(Moment1) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("Moment2"), true,
+                    platform::errors::NotFound(
+                        "Input(Moment2) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("LearningRate"), true,
+                    platform::errors::NotFound(
+                        "Input(LearningRate) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("Beta1Pow"), true,
+                    platform::errors::NotFound(
+                        "Input(Beta1Pow) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasInput("Beta2Pow"), true,
+                    platform::errors::NotFound(
+                        "Input(Beta2Pow) of AdamOp should not be null."));
+
+  PADDLE_ENFORCE_EQ(ctx->HasOutput("ParamOut"), true,
+                    platform::errors::NotFound(
+                        "Output(ParamOut) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment1Out"), true,
+                    platform::errors::NotFound(
+                        "Output(Moment1Out) of AdamOp should not be null."));
+  PADDLE_ENFORCE_EQ(ctx->HasOutput("Moment2Out"), true,
+                    platform::errors::NotFound(
+                        "Output(Moment2Out) of AdamOp should not be null."));

-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const {
-    auto input_data_type =
-        OperatorWithKernel::IndicateVarDataType(ctx, "Param");
-    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  auto lr_dims = ctx->GetInputDim("LearningRate");
+  PADDLE_ENFORCE_NE(
+      phi::product(lr_dims), 0,
+      platform::errors::InvalidArgument(
+          "The number of LearningRate shall not be 0, but received %d. Maybe "
+          "the Input variable LearningRate has not "
+          "been initialized. You may need to confirm "
+          "if you put exe.run(startup_program) "
+          "after optimizer.minimize function.",
+          phi::product(lr_dims)));
+  PADDLE_ENFORCE_EQ(
+      phi::product(lr_dims), 1,
+      platform::errors::InvalidArgument(
+          "Learning rate should have 1 dimension, but received %d",
+          phi::product(lr_dims)));
+  auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
+  VLOG(3) << "dims of Beta1Pow : [" << beta1_pow_dims << "]";
+  PADDLE_ENFORCE_GE(phi::product(beta1_pow_dims), 1,
+                    platform::errors::InvalidArgument(
+                        "The size of Beta1 power accumulator should be greater "
+                        "than 0, but received %d.",
+                        phi::product(beta1_pow_dims)));
+  auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
+  VLOG(3) << "dims of Beta2Pow : [" << beta2_pow_dims << "]";
+  PADDLE_ENFORCE_GE(phi::product(beta2_pow_dims), 1,
+                    platform::errors::InvalidArgument(
+                        "The size of Beta2 power accumulator should be greater "
+                        "than 0, but received %d.",
+                        phi::product(beta2_pow_dims)));
+
+  auto param_dims = ctx->GetInputDim("Param");
+  if (ctx->GetInputsVarType("Grad")[0] ==
+      framework::proto::VarType::LOD_TENSOR) {
+    PADDLE_ENFORCE_EQ(
+        param_dims, ctx->GetInputDim("Grad"),
+        platform::errors::InvalidArgument(
+            "Param and Grad input of AdamOp should have same dimension. But "
+            "received Param dims: [%s], Grad dims: [%s].",
+            param_dims, ctx->GetInputDim("Grad")));
   }
+  PADDLE_ENFORCE_EQ(
+      param_dims, ctx->GetInputDim("Moment1"),
+      platform::errors::InvalidArgument(
+          "Param and Moment1 input of AdamOp should have same dimension. But "
+          "received Param dims: [%s], Moment1 dims: [%s].",
+          param_dims, ctx->GetInputDim("Moment1")));
+  PADDLE_ENFORCE_EQ(
+      param_dims, ctx->GetInputDim("Moment2"),
+      platform::errors::InvalidArgument(
+          "Param and Moment2 input of AdamOp should have same dimension. But "
+          "received Param dims: [%s], Moment2 dims: [%s].",
+          param_dims, ctx->GetInputDim("Moment2")));
+
+  ctx->SetOutputDim("ParamOut", param_dims);
+  ctx->SetOutputDim("Moment1Out", param_dims);
+  ctx->SetOutputDim("Moment2Out", param_dims);
+  ctx->SetOutputDim("Beta1PowOut", beta1_pow_dims);
+  ctx->SetOutputDim("Beta2PowOut", beta2_pow_dims);
+}

-  framework::OpKernelType GetKernelTypeForVar(
-      const std::string &var_name, const framework::Tensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const {
-    if (var_name == "Beta1Pow" || var_name == "Beta2Pow" ||
-        var_name == "SkipUpdate") {
-      return expected_kernel_type;
-    } else {
-      return framework::OpKernelType(expected_kernel_type.data_type_,
-                                     tensor.place(), tensor.layout());
-    }
+framework::OpKernelType AdamOp::GetExpectedKernelType(
+    const framework::ExecutionContext &ctx) const {
+  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
+  return framework::OpKernelType(input_data_type, ctx.GetPlace());
+}
+
+framework::OpKernelType AdamOp::GetKernelTypeForVar(
+    const std::string &var_name, const framework::Tensor &tensor,
+    const framework::OpKernelType &expected_kernel_type) const {
+  if (var_name == "Beta1Pow" || var_name == "Beta2Pow" ||
+      var_name == "SkipUpdate") {
+    return expected_kernel_type;
+  } else {
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(), tensor.layout());
   }
-};
+}

 class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
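As an aside, the shape contract enforced by AdamOp::InferShape above can be summarized with a plain NumPy sketch (the array sizes are arbitrary examples, not anything required by the op):

import numpy as np

param = np.zeros((128, 256), dtype=np.float32)
grad = np.zeros_like(param)      # dense Grad must match Param's dims
moment1 = np.zeros_like(param)   # Moment1 must match Param's dims
moment2 = np.zeros_like(param)   # Moment2 must match Param's dims
learning_rate = np.array([1e-3], dtype=np.float32)  # exactly one element
beta1_pow = np.array([0.9], dtype=np.float32)       # size >= 1
beta2_pow = np.array([0.999], dtype=np.float32)     # size >= 1
# Outputs: ParamOut/Moment1Out/Moment2Out take Param's dims,
# Beta1PowOut/Beta2PowOut take the corresponding Beta*Pow dims.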
@@ -148,10 +232,6 @@ param\_out = param - learning\_rate * \frac{moment\_1}{\sqrt{moment\_2} + \epsil
   }
 };

-class AdamWOp : public AdamOp {
-  using AdamOp::AdamOp;
-};
-
 class AdamWOpMaker : public AdamOpMaker {
  public:
   void Make() {
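The truncated context line in the hunk header above comes from the AdamOpMaker comment that documents the update rule. For reference, the standard Adam update it describes is (with t the current step and \epsilon a small constant):

$$
moment\_1\_out = \beta_1 \cdot moment\_1 + (1 - \beta_1) \cdot grad \\
moment\_2\_out = \beta_2 \cdot moment\_2 + (1 - \beta_2) \cdot grad^2 \\
learning\_rate = learning\_rate \cdot \frac{\sqrt{1 - \beta_2^t}}{1 - \beta_1^t} \\
param\_out = param - learning\_rate \cdot \frac{moment\_1\_out}{\sqrt{moment\_2\_out} + \epsilon}
$$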
@@ -175,23 +255,13 @@ class AdamWOpMaker : public AdamOpMaker {
 }  // namespace paddle

 namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker);
+
+REGISTER_OP_WITHOUT_GRADIENT(adamw, ops::AdamWOp, ops::AdamWOpMaker);

-DECLARE_INFER_SHAPE_FUNCTOR(adam, AdamInferMetaFunctor,
-                            PD_INFER_META(phi::AdamInferMeta));
-
-REGISTER_OPERATOR(
-    adam, ops::AdamOp, ops::AdamOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    AdamInferMetaFunctor);
-
-DECLARE_INFER_SHAPE_FUNCTOR(adamw, AdamwInferMetaFunctor,
-                            PD_INFER_META(phi::AdamwInferMeta));
-REGISTER_OPERATOR(
-    adamw, ops::AdamWOp, ops::AdamWOpMaker,
-    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
-    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    AdamwInferMetaFunctor);
+REGISTER_OP_CPU_KERNEL(
+    adam, ops::AdamOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::AdamOpKernel<paddle::platform::CPUDeviceContext, double>);

 REGISTER_OP_VERSION(adam)
     .AddCheckpoint(
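With the Phi infermeta/functor registration reverted, adam and adamw are again registered through REGISTER_OP_WITHOUT_GRADIENT and served by AdamOpKernel on CPU. A hypothetical end-to-end sketch from the Python side, using the public paddle.optimizer.Adam API (the model and data here are placeholders, nothing specific to this commit):

import paddle

net = paddle.nn.Linear(10, 1)
opt = paddle.optimizer.Adam(learning_rate=1e-3, parameters=net.parameters())

x = paddle.randn([4, 10])
loss = net(x).mean()
loss.backward()   # produces the Grad inputs of the adam op
opt.step()        # runs the adam op: updates Param, Moment1/2 and Beta1/2Pow
opt.clear_grad()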
