
Commit 3cb8da9

picked momentum fix. test=release/1.0.0 (#13955)
1 parent 229d4bb commit 3cb8da9

File tree

4 files changed: +448 -105 lines

paddle/fluid/operators/momentum_op.cc

Lines changed: 40 additions & 13 deletions
@@ -24,7 +24,7 @@ class MomentumOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Param"),
                    "Input(param) of Momentum should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Grad"),
@@ -45,26 +45,50 @@ class MomentumOp : public framework::OperatorWithKernel {
                    "Output(VelocityOut) of Momentum should not be null.");
 
     auto param_dim = ctx->GetInputDim("Param");
-    PADDLE_ENFORCE_EQ(
-        param_dim, ctx->GetInputDim("Grad"),
-        "Param and Grad input of MomentumOp should have the same dimension.");
-    PADDLE_ENFORCE_EQ(
-        param_dim, ctx->GetInputDim("Velocity"),
-        "Param and Velocity of MomentumOp should have the same dimension.");
+    if (ctx->GetInputsVarType("Grad")[0] ==
+        framework::proto::VarType::LOD_TENSOR) {
+      PADDLE_ENFORCE_EQ(
+          param_dim, ctx->GetInputDim("Grad"),
+          "Param and Grad input of MomentumOp should have the same dimension.");
+      PADDLE_ENFORCE_EQ(
+          param_dim, ctx->GetInputDim("Velocity"),
+          "Param and Velocity of MomentumOp should have the same dimension.");
+    }
     PADDLE_ENFORCE_EQ(framework::product(ctx->GetInputDim("LearningRate")), 1,
                       "Learning_rate should be a scalar");
 
     ctx->SetOutputDim("ParamOut", param_dim);
     ctx->SetOutputDim("VelocityOut", param_dim);
   }
   framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type =
-        framework::ToDataType(ctx.Input<Tensor>("Param")->type());
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
 
+class MomentumOpInferVarType : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
+    auto input_var = op_desc.Input("Param")[0];
+    for (auto& out_var : op_desc.Output("ParamOut")) {
+      if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
+          framework::proto::VarType::SELECTED_ROWS) {
+        block->FindRecursiveOrCreateVar(out_var).SetType(
+            framework::proto::VarType::SELECTED_ROWS);
+      } else if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
+                 framework::proto::VarType::LOD_TENSOR) {
+        block->FindRecursiveOrCreateVar(out_var).SetType(
+            framework::proto::VarType::LOD_TENSOR);
+      } else {
+        PADDLE_THROW(
+            "Only support LodTensor and SelectedRows, Unexpected Input Type.");
+      }
+    }
+  }
+};
+
 class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -115,6 +139,9 @@ else: \\
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_WITHOUT_GRADIENT(momentum, ops::MomentumOp, ops::MomentumOpMaker);
-REGISTER_OP_CPU_KERNEL(momentum, ops::MomentumOpKernel<float>,
-                       ops::MomentumOpKernel<double>);
+REGISTER_OPERATOR(momentum, ops::MomentumOp, ops::MomentumOpMaker,
+                  paddle::framework::EmptyGradOpMaker,
+                  ops::MomentumOpInferVarType);
+REGISTER_OP_CPU_KERNEL(
+    momentum, ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, double>);
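
The new MomentumOpInferVarType makes ParamOut's variable type follow Param's, so a sparse (SELECTED_ROWS) parameter stays sparse through the graph and a dense (LOD_TENSOR) one stays dense; together with the relaxed InferShape check, which now skips the dimension comparison when Grad is not a LOD_TENSOR, this is what lets the momentum op accept sparse gradients. Below is a minimal standalone sketch of that dispatch, assuming a plain enum and map in place of Paddle's proto::VarType and BlockDesc; VarType, InferMomentumVarType, and the variable names are illustrative, not Paddle APIs.

#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Illustrative stand-in for framework::proto::VarType.
enum class VarType { LOD_TENSOR, SELECTED_ROWS, OTHER };

// ParamOut inherits Param's type: dense in, dense out; sparse in, sparse out.
void InferMomentumVarType(const std::string& param,
                          const std::vector<std::string>& param_outs,
                          std::map<std::string, VarType>* vars) {
  VarType in_type = vars->at(param);
  if (in_type != VarType::LOD_TENSOR && in_type != VarType::SELECTED_ROWS) {
    throw std::runtime_error(
        "Only support LodTensor and SelectedRows, Unexpected Input Type.");
  }
  for (const auto& out : param_outs) {
    (*vars)[out] = in_type;  // Propagate the input type to every output.
  }
}

int main() {
  std::map<std::string, VarType> vars{{"w", VarType::SELECTED_ROWS},
                                      {"w_out", VarType::LOD_TENSOR}};
  InferMomentumVarType("w", {"w_out"}, &vars);
  // vars["w_out"] is now SELECTED_ROWS, matching its parameter.
}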

paddle/fluid/operators/momentum_op.cu

Lines changed: 3 additions & 72 deletions
@@ -15,76 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/momentum_op.h"
 
-namespace paddle {
-namespace operators {
-
-template <typename T>
-__global__ void MomentumKernel(const T* p, const T* g, const T* v,
-                               const T* learning_rate, const T mu,
-                               const int64_t num, bool use_nesterov, T* p_out,
-                               T* v_out) {
-  T lr = learning_rate[0];
-  if (use_nesterov) {
-    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
-         i += blockDim.x * gridDim.x) {
-      T g_val = g[i];
-      T v_new = v[i] * mu + g_val;
-      v_out[i] = v_new;
-      p_out[i] = p[i] - (g_val + v_new * mu) * lr;
-    }
-  } else {
-    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
-         i += blockDim.x * gridDim.x) {
-      T v_new = v[i] * mu + g[i];
-      v_out[i] = v_new;
-      p_out[i] = p[i] - lr * v_new;
-    }
-  }
-}
-
-template <typename T>
-class MomentumOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const auto* param_var = ctx.InputVar("Param");
-    PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
-                   "The Var(%s)'s type should be LoDTensor, "
-                   "but the received is %s",
-                   ctx.Inputs("Param").front(), param_var->Type().name());
-    const auto* grad_var = ctx.InputVar("Grad");
-    PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
-                   "The Var(%s)'s type should be LoDTensor, "
-                   "but the received is %s",
-                   ctx.Inputs("Grad").front(), grad_var->Type().name());
-
-    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
-    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
-    auto param = ctx.Input<framework::Tensor>("Param");
-    auto velocity = ctx.Input<framework::Tensor>("Velocity");
-    auto grad = ctx.Input<framework::Tensor>("Grad");
-    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
-
-    T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
-    T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
-
-    T mu = static_cast<T>(ctx.Attr<float>("mu"));
-    bool use_nesterov = ctx.Attr<bool>("use_nesterov");
-
-    auto* p = param->data<T>();
-    auto* v = velocity->data<T>();
-    auto* g = grad->data<T>();
-    auto* lr = learning_rate->data<T>();
-
-    int block = 512;
-    int grid = (param->numel() + block - 1) / block;
-    MomentumKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
-        p, g, v, lr, mu, param->numel(), use_nesterov, p_out, v_out);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(momentum, ops::MomentumOpCUDAKernel<float>,
-                        ops::MomentumOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    momentum, ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>);
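
The device-specific CUDA kernel deleted above duplicated math that ops::MomentumOpKernel, now templated on the device context in momentum_op.h, already provides, so a single registration line per device suffices. For reference, the removed kernel computed the standard per-element momentum update: v_new = mu * v + g, then p_new = p - lr * v_new (plain) or p_new = p - lr * (g + mu * v_new) (Nesterov). Below is a minimal CPU replay of that rule; MomentumStep is an illustrative helper, not Paddle code.

#include <cstdio>

// One element of the momentum update from the deleted kernel:
//   v_new = mu * v + g
//   plain:    p_new = p - lr * v_new
//   Nesterov: p_new = p - lr * (g + mu * v_new)
template <typename T>
void MomentumStep(T p, T g, T v, T lr, T mu, bool use_nesterov, T* p_out,
                  T* v_out) {
  T v_new = v * mu + g;
  *v_out = v_new;
  *p_out = use_nesterov ? p - (g + v_new * mu) * lr : p - lr * v_new;
}

int main() {
  float p_out, v_out;
  MomentumStep(1.0f, 0.1f, 0.5f, 0.01f, 0.9f, /*use_nesterov=*/false, &p_out,
               &v_out);
  // v_out = 0.9 * 0.5 + 0.1 = 0.55; p_out = 1.0 - 0.01 * 0.55 = 0.9945.
  std::printf("p_out=%f v_out=%f\n", p_out, v_out);
}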
