
Commit e760641

Merge pull request #6233 from qingqing01/momentum_op

Refine and speed up the momentum operator. The change templates MomentumOpKernel on the element type only, registers float and double kernels for both CPU and GPU, and replaces the Eigen-based GPU path with a dedicated CUDA kernel.

2 parents 3644446 + 62acf79

File tree: 3 files changed, +76 −19 lines


paddle/operators/momentum_op.cc

Lines changed: 8 additions & 4 deletions
@@ -71,8 +71,12 @@ class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
              "(Tensor, default Tensor<float>) "
              "Input learning rate");

-    AddOutput("ParamOut", "(Tensor) Output updated parameter");
-    AddOutput("VelocityOut", "(Tensor) Output updated velocity");
+    AddOutput("ParamOut",
+              "(Tensor) This output is updated parameter. "
+              "It shared memory with Input(Param).");
+    AddOutput("VelocityOut",
+              "(Tensor) This output is updated velocity. "
+              "It shared memory with Input(Velocity).");

     AddAttr<float>("mu", "(float) Momentum coefficient");
     AddAttr<bool>("use_nesterov",
@@ -101,5 +105,5 @@ else: \\

 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(momentum, ops::MomentumOp, ops::MomentumOpMaker);
-REGISTER_OP_CPU_KERNEL(
-    momentum, ops::MomentumOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(momentum, ops::MomentumOpKernel<float>,
+                       ops::MomentumOpKernel<double>);
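The CPU registration now covers both float and double. As a rough mental model of what per-dtype registration buys — this is a simplified, self-contained sketch, not Paddle's actual REGISTER_OP_CPU_KERNEL machinery, and every name in it is hypothetical — kernels instantiated per element type can be looked up and dispatched at runtime:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <utility>

enum class DType { kFloat32, kFloat64 };

// Map (op name, dtype) -> type-erased kernel functor.
using Kernel = std::function<void(const void* in, void* out, int64_t n)>;
std::map<std::pair<std::string, DType>, Kernel> g_registry;

// A toy elementwise kernel, instantiated once per element type.
template <typename T>
void ScaleKernel(const void* in, void* out, int64_t n) {
  auto* x = static_cast<const T*>(in);
  auto* y = static_cast<T*>(out);
  for (int64_t i = 0; i < n; ++i) y[i] = x[i] * T(2);
}

int main() {
  // "Register" the kernel for two element types, analogous to
  // REGISTER_OP_CPU_KERNEL(momentum, ...<float>, ...<double>).
  g_registry[{"scale", DType::kFloat32}] = ScaleKernel<float>;
  g_registry[{"scale", DType::kFloat64}] = ScaleKernel<double>;

  double in[3] = {1.0, 2.0, 3.0}, out[3];
  g_registry[{"scale", DType::kFloat64}](in, out, 3);  // runtime dispatch
  std::cout << out[2] << "\n";                         // prints 6
}
```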

paddle/operators/momentum_op.cu

Lines changed: 62 additions & 4 deletions
@@ -12,9 +12,67 @@
 See the License for the specific language governing permissions and
 limitations under the License. */

-#define EIGEN_USE_GPU
-#include "paddle/operators/momentum_op.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+__global__ void MomentumKernel(const T* p, const T* g, const T* v,
+                               const T* learning_rate, const T mu,
+                               const int64_t num, bool use_nesterov, T* p_out,
+                               T* v_out) {
+  T lr = learning_rate[0];
+  if (use_nesterov) {
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+         i += blockDim.x * gridDim.x) {
+      T g_val = g[i];
+      T v_new = v[i] * mu + g_val;
+      v_out[i] = v_new;
+      p_out[i] = p[i] - (g_val - v_new * mu) * lr;
+    }
+  } else {
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+         i += blockDim.x * gridDim.x) {
+      T v_new = v[i] * mu + g[i];
+      v_out[i] = v_new;
+      p_out[i] = p[i] - lr * v_new;
+    }
+  }
+}
+
+template <typename T>
+class MomentumOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
+    auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
+    auto param = ctx.Input<framework::Tensor>("Param");
+    auto velocity = ctx.Input<framework::Tensor>("Velocity");
+    auto grad = ctx.Input<framework::Tensor>("Grad");
+    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+
+    T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
+    T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
+
+    T mu = static_cast<T>(ctx.Attr<float>("mu"));
+    bool use_nesterov = ctx.Attr<bool>("use_nesterov");
+
+    auto* p = param->data<T>();
+    auto* v = velocity->data<T>();
+    auto* g = grad->data<T>();
+    auto* lr = learning_rate->data<T>();
+
+    int block = 512;
+    int grid = (param->numel() + block - 1) / block;
+    MomentumKernel<T><<<grid, block, 0, ctx.cuda_device_context().stream()>>>(
+        p, g, v, lr, mu, param->numel(), use_nesterov, p_out, v_out);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(
-    momentum, ops::MomentumOpKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(momentum, ops::MomentumOpCUDAKernel<float>,
+                       ops::MomentumOpCUDAKernel<double>);
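The new CUDA kernel uses a grid-stride loop, which stays correct even if the launch uses fewer blocks than elements, and it runs on the device context's stream instead of forcing a synchronization. To see the pattern and the launch arithmetic in isolation, here is a minimal standalone sketch — not part of the commit; the `MomentumToy` name, toy data, and constants are all illustrative — applying the plain (non-Nesterov) update:

```cuda
#include <cstdint>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

template <typename T>
__global__ void MomentumToy(const T* p, const T* g, const T* v, T lr, T mu,
                            int64_t num, T* p_out, T* v_out) {
  // Grid-stride loop: each thread handles indices i, i + stride, ...
  for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x) {
    T v_new = v[i] * mu + g[i];
    v_out[i] = v_new;
    p_out[i] = p[i] - lr * v_new;  // plain momentum update
  }
}

int main() {
  const int64_t n = 1 << 20;
  std::vector<float> h_p(n, 1.0f), h_g(n, 0.1f), h_v(n, 0.0f);
  float *d_p, *d_g, *d_v, *d_p_out, *d_v_out;
  cudaMalloc(&d_p, n * sizeof(float));
  cudaMalloc(&d_g, n * sizeof(float));
  cudaMalloc(&d_v, n * sizeof(float));
  cudaMalloc(&d_p_out, n * sizeof(float));
  cudaMalloc(&d_v_out, n * sizeof(float));
  cudaMemcpy(d_p, h_p.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_g, h_g.data(), n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(d_v, h_v.data(), n * sizeof(float), cudaMemcpyHostToDevice);

  // Same launch arithmetic as the commit: ceiling division of elements
  // by the block size, so every element is covered.
  int block = 512;
  int grid = static_cast<int>((n + block - 1) / block);
  MomentumToy<float><<<grid, block>>>(d_p, d_g, d_v, 0.01f, 0.9f, n,
                                      d_p_out, d_v_out);
  cudaDeviceSynchronize();

  cudaMemcpy(h_p.data(), d_p_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("p[0] after one step: %f\n", h_p[0]);  // 1.0 - 0.01 * 0.1 = 0.999
  cudaFree(d_p); cudaFree(d_g); cudaFree(d_v);
  cudaFree(d_p_out); cudaFree(d_v_out);
  return 0;
}
```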

paddle/operators/momentum_op.h

Lines changed: 6 additions & 11 deletions
@@ -19,7 +19,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-template <typename Place, typename T>
+template <typename T>
 class MomentumOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -33,7 +33,7 @@ class MomentumOpKernel : public framework::OpKernel<T> {
     param_out->mutable_data<T>(ctx.GetPlace());
     velocity_out->mutable_data<T>(ctx.GetPlace());

-    float mu = ctx.Attr<float>("mu");
+    T mu = static_cast<T>(ctx.Attr<float>("mu"));
     bool use_nesterov = ctx.Attr<bool>("use_nesterov");

     auto p_out = framework::EigenVector<T>::Flatten(*param_out);
@@ -42,18 +42,13 @@ class MomentumOpKernel : public framework::OpKernel<T> {
     auto p = framework::EigenVector<T>::Flatten(*param);
     auto v = framework::EigenVector<T>::Flatten(*velocity);
     auto g = framework::EigenVector<T>::Flatten(*grad);
-    auto lr = framework::EigenVector<T>::Flatten(*learning_rate);
+    auto* lr = learning_rate->data<T>();

-    auto place = ctx.GetEigenDevice<Place>();
-
-    Eigen::DSizes<int, 1> grad_dsize(grad->numel());
-
-    v_out.device(place) = v * mu + g;
+    v_out = v * mu + g;
     if (use_nesterov) {
-      p_out.device(place) = p - g * lr.broadcast(grad_dsize) +
-                            v_out * mu * lr.broadcast(grad_dsize);
+      p_out = p - (g - v_out * mu) * lr[0];
     } else {
-      p_out.device(place) = p - lr[0] * v_out;
+      p_out = p - lr[0] * v_out;
     }
   }
 };
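For reference, with the scalar learning rate η = lr[0], both the CPU and CUDA paths implement:

```latex
\begin{aligned}
v_{\text{out}} &= \mu\,v + g \\
p_{\text{out}} &= p - \eta\,v_{\text{out}} && \text{(plain momentum)} \\
p_{\text{out}} &= p - \eta\,(g - \mu\,v_{\text{out}})
               = p - \eta\,g + \eta\,\mu\,v_{\text{out}} && \text{(use\_nesterov)}
\end{aligned}
```

The expanded form on the right is exactly what the old Eigen code computed (`p - g * lr + v_out * mu * lr`), so the Nesterov branch is an algebraic regrouping, not a behavioral change. The speedup comes from reading the learning rate once as a scalar (`lr[0]`) instead of broadcasting it into a tensor of `grad->numel()` elements on every update.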
