
Commit a9f5f82

committed
use binary search. test=develop
1 parent 3861269 commit a9f5f82

3 files changed: +335 -171 lines


paddle/fluid/operators/momentum_op.cc

Lines changed: 8 additions & 3 deletions
@@ -74,9 +74,13 @@ class MomentumOpInferVarType : public framework::VarTypeInference {
         framework::proto::VarType::SELECTED_ROWS) {
       block->FindRecursiveOrCreateVar(out_var).SetType(
           framework::proto::VarType::SELECTED_ROWS);
-    } else {
+    } else if (block->FindRecursiveOrCreateVar(input_var).GetType() ==
+               framework::proto::VarType::LOD_TENSOR) {
       block->FindRecursiveOrCreateVar(out_var).SetType(
           framework::proto::VarType::LOD_TENSOR);
+    } else {
+      PADDLE_THROW(
+          "Only support LodTensor and SelectedRows, Unexpected Input Type.");
     }
   }
 }
@@ -135,5 +139,6 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(momentum, ops::MomentumOp, ops::MomentumOpMaker,
                   paddle::framework::EmptyGradOpMaker,
                   ops::MomentumOpInferVarType);
-REGISTER_OP_CPU_KERNEL(momentum, ops::MomentumOpKernel<float>,
-                       ops::MomentumOpKernel<double>);
+REGISTER_OP_CPU_KERNEL(
+    momentum, ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::MomentumOpKernel<paddle::platform::CPUDeviceContext, double>);
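
The CPU registration now names the device context explicitly, so momentum_op.cc and momentum_op.cu both register the single ops::MomentumOpKernel<DeviceContext, T> defined in momentum_op.h instead of keeping separate per-device kernels (the CUDA-only kernel is deleted in the next file). For orientation, the dense update that the deleted CUDA code below performs, and that the shared kernel has to reproduce, is plain (optionally Nesterov) momentum. The following is a minimal standalone sketch, not Paddle code; the name DenseMomentumUpdate is made up for illustration:

// Minimal sketch (not Paddle code): per-element dense momentum update,
// mirroring the math of the deleted MomentumKernel below. The function name
// DenseMomentumUpdate is hypothetical.
#include <cstdint>

template <typename T>
void DenseMomentumUpdate(const T* p, const T* g, const T* v, T lr, T mu,
                         int64_t num, bool use_nesterov, T* p_out, T* v_out) {
  for (int64_t i = 0; i < num; ++i) {
    T v_new = v[i] * mu + g[i];  // velocity_out = mu * velocity + grad
    v_out[i] = v_new;
    if (use_nesterov) {
      // Nesterov: step using the gradient plus the "looked ahead" velocity.
      p_out[i] = p[i] - (g[i] + v_new * mu) * lr;
    } else {
      p_out[i] = p[i] - lr * v_new;  // classic momentum step
    }
  }
}

The shared kernel in momentum_op.h presumably applies this same arithmetic in a device-agnostic way, which is why only the template arguments of the registration need to change here.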

paddle/fluid/operators/momentum_op.cu

Lines changed: 3 additions & 121 deletions
@@ -15,125 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/momentum_op.h"
 
-namespace paddle {
-namespace operators {
-
-template <typename T>
-__global__ void MomentumKernel(const T* p, const T* g, const T* v,
-                               const T* learning_rate, const T mu,
-                               const int64_t num, bool use_nesterov, T* p_out,
-                               T* v_out) {
-  T lr = learning_rate[0];
-  if (use_nesterov) {
-    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
-         i += blockDim.x * gridDim.x) {
-      T g_val = g[i];
-      T v_new = v[i] * mu + g_val;
-      v_out[i] = v_new;
-      p_out[i] = p[i] - (g_val + v_new * mu) * lr;
-    }
-  } else {
-    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
-         i += blockDim.x * gridDim.x) {
-      T v_new = v[i] * mu + g[i];
-      v_out[i] = v_new;
-      p_out[i] = p[i] - lr * v_new;
-    }
-  }
-}
-
-template <typename T>
-__global__ void SparseMomentumKernel(const T* p, const T* g, const T* v,
-                                     const T* lr, const T mu,
-                                     const int64_t* grad_rows,
-                                     const size_t grad_row_numel,
-                                     const size_t grad_row_size,
-                                     const T use_nesterov, T* p_out, T* v_out) {
-  for (int i = blockIdx.x; i < grad_row_size; i += gridDim.x) {
-    for (int j = threadIdx.x; j < grad_row_numel; j += blockDim.x) {
-      size_t p_i = grad_rows[i] * grad_row_numel + j;
-      size_t g_i = i * grad_row_numel + j;
-      v_out[g_i] = v[g_i] * mu + g[g_i];
-      if (use_nesterov) {
-        p_out[p_i] = p[p_i] - (g[g_i] + v_out[g_i] * mu) * lr[0];
-      } else {
-        p_out[p_i] = p[p_i] - v_out[g_i] * lr[0];
-      }
-    }
-  }
-}
-
-template <typename T>
-class MomentumOpCUDAKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    T mu = static_cast<T>(ctx.Attr<float>("mu"));
-    bool use_nesterov = ctx.Attr<bool>("use_nesterov");
-
-    auto learning_rate = ctx.Input<framework::Tensor>("LearningRate");
-    auto param = ctx.Input<framework::Tensor>("Param");
-    auto param_out = ctx.Output<framework::Tensor>("ParamOut");
-    auto* velocity_var = ctx.InputVar("Velocity");
-    auto* grad_var = ctx.InputVar("Grad");
-
-    if (grad_var->IsType<framework::LoDTensor>()) {
-      PADDLE_ENFORCE(velocity_var->IsType<framework::LoDTensor>(),
-                     "Unmatched Type of Param and Grad");
-      auto velocity = ctx.Input<framework::Tensor>("Velocity");
-      auto grad = ctx.Input<framework::Tensor>("Grad");
-      auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
-      T* p_out = param_out->mutable_data<T>(ctx.GetPlace());
-      T* v_out = velocity_out->mutable_data<T>(ctx.GetPlace());
-      auto* p = param->data<T>();
-      auto* v = velocity->data<T>();
-      auto* g = grad->data<T>();
-      auto* lr = learning_rate->data<T>();
-
-      const int kThreadPerBlock = 256;
-      int grid = (param->numel() + kThreadPerBlock - 1) / kThreadPerBlock;
-      MomentumKernel<
-          T><<<grid, kThreadPerBlock, 0, ctx.cuda_device_context().stream()>>>(
-          p, g, v, lr, mu, param->numel(), use_nesterov, p_out, v_out);
-    } else if (grad_var->IsType<framework::SelectedRows>()) {
-      // sparse update embedding with selectedrows
-      PADDLE_ENFORCE(velocity_var->IsType<framework::SelectedRows>(),
-                     "Unmatched Type of Param and Grad");
-      auto velocity = ctx.Input<framework::SelectedRows>("Velocity");
-      auto grad = ctx.Input<framework::SelectedRows>("Grad");
-      auto velocity_out = ctx.Output<framework::SelectedRows>("VelocityOut");
-
-      // sparse update maybe empty.
-      if (grad->rows().size() == 0) {
-        return;
-      }
-      PADDLE_ENFORCE(grad->height() == velocity->height(),
-                     "Unmatched gradient and velocity.");
-      auto* p_out = param_out->mutable_data<T>(ctx.GetPlace());
-      auto* v_out =
-          velocity_out->mutable_value()->mutable_data<T>(ctx.GetPlace());
-      auto* lr = learning_rate->data<T>();
-      auto* p = param->data<T>();
-      auto* g = grad->value().data<T>();
-      auto* v = velocity->value().data<T>();
-      size_t grad_row_numel = grad->value().numel() / grad->rows().size();
-      size_t grad_row_size = grad->rows().size();
-      framework::Vector<int64_t> rows(grad->rows());
-
-      const int kThreadPerBlock = 256;
-      int grid = (param->numel() + kThreadPerBlock - 1) / kThreadPerBlock;
-      SparseMomentumKernel<
-          T><<<grid, kThreadPerBlock, 0, ctx.cuda_device_context().stream()>>>(
-          p, g, v, lr, mu, rows.CUDAData(ctx.GetPlace()), grad_row_numel,
-          grad->rows().size(), use_nesterov, p_out, v_out);
-    } else {
-      PADDLE_THROW("Unsupported Variable Type of Grad");
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(momentum, ops::MomentumOpCUDAKernel<float>,
-                        ops::MomentumOpCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(
+    momentum, ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::MomentumOpKernel<paddle::platform::CUDADeviceContext, double>);
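
The commit message ("use binary search") refers to the sparse SelectedRows path, which now lives in the shared momentum_op.h (the third changed file, not shown in this view). The deleted SparseMomentumKernel above indexed gradient rows directly; a shared dense/sparse kernel instead has to find, for each parameter row, the matching entry in the sparse row list, and a binary search over sorted row indices is the natural way to do that. The sketch below only illustrates that lookup under the assumption that the rows are sorted ascending; BinarySearchRow is a made-up name, not the actual momentum_op.h code:

// Hypothetical illustration (not the actual momentum_op.h code): locate a
// parameter row in the SelectedRows row-index list, assumed sorted ascending.
// Returns the position of the matching gradient/velocity row, or -1 if this
// parameter row received no sparse gradient. In the CUDA build such a helper
// would be marked __host__ __device__.
#include <cstdint>

inline int64_t BinarySearchRow(const int64_t* rows, int64_t row_count,
                               int64_t target) {
  int64_t lo = 0;
  int64_t hi = row_count - 1;
  while (lo <= hi) {
    int64_t mid = lo + (hi - lo) / 2;
    if (rows[mid] == target) {
      return mid;
    } else if (rows[mid] < target) {
      lo = mid + 1;
    } else {
      hi = mid - 1;
    }
  }
  return -1;
}

With such a lookup each parameter element can be processed independently, which is presumably how a single element-wise kernel can serve both the dense LoDTensor case and the sparse SelectedRows case on CPU and GPU.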
