Commit 57bbee6

Merge branch 'develop' into cmake_speed

2 parents 0968c7c + d7bf372

7 files changed: +386 −38 lines

paddle/operators/CMakeLists.txt

Lines changed: 7 additions & 2 deletions

@@ -183,15 +183,20 @@ set(DEPS_OPS
     array_to_lod_tensor_op
     lstm_op
     tensor_array_read_write_op
-    gru_op)
+    gru_op
+    adagrad_op
+    sgd_op)
+
 
 op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
 op_library(cross_entropy_op DEPS cross_entropy)
 op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
 op_library(softmax_op DEPS softmax)
 op_library(sequence_softmax_op DEPS softmax)
+op_library(sum_op DEPS selected_rows_functor)
+op_library(sgd_op DEPS selected_rows_functor)
+op_library(adagrad_op DEPS selected_rows_functor)
 op_library(conv_op DEPS vol2col)
-op_library(sum_op DEPS net_op selected_rows_functor)
 op_library(pool_op DEPS pooling)
 op_library(pool_with_index_op DEPS pooling)
 op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table)
paddle/operators/adagrad_op.cc

Lines changed: 85 additions & 5 deletions

@@ -14,14 +14,19 @@ limitations under the License. */
 
 #include "paddle/operators/adagrad_op.h"
 
+#include <cmath>
+
+#include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/selected_rows_functor.h"
+
 namespace paddle {
 namespace operators {
 
 class AdagradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Param"),
                    "Input(Param) of AdagradOp should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Grad"),
@@ -54,8 +59,8 @@ class AdagradOp : public framework::OperatorWithKernel {
 
 class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  AdagradOpMaker(framework::OpProto *proto,
-                 framework::OpAttrChecker *op_checker)
+  AdagradOpMaker(framework::OpProto* proto,
+                 framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Param", "(Tensor) Input parameter");
     AddInput("Grad", "(Tensor) Input gradient");
@@ -87,10 +92,85 @@ for numerical stability to avoid the division by zero error.
 )DOC");
   }
 };
+
+namespace {
+size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
+  return std::find(rows.begin(), rows.end(), value) - rows.begin();
+}
+}  // namespace
+
+template <typename T>
+struct SparseAdagradFunctor<platform::CPUPlace, T> {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::SelectedRows& grad,
+                  const framework::Tensor& learning_rate, T epsilon,
+                  framework::Tensor* moment, framework::Tensor* param) {
+    // 1. g_m.rows = set(g.rows)
+    auto grad_rows = grad.rows();
+    std::set<int64_t> row_set(grad_rows.begin(), grad_rows.end());
+    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
+
+    auto grad_width = grad.value().dims()[1];
+    std::unique_ptr<framework::SelectedRows> grad_merge{
+        new framework::SelectedRows()};
+    grad_merge->set_rows(merge_rows);
+    grad_merge->set_height(grad.height());
+    grad_merge->mutable_value()->mutable_data<T>(
+        framework::make_ddim(
+            {static_cast<int64_t>(merge_rows.size()), grad_width}),
+        context.GetPlace());
+
+    math::SetConstant<platform::CPUPlace, T> constant_functor;
+    constant_functor(context, grad_merge->mutable_value(), 0.0);
+
+    auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
+    auto* grad_data = grad.value().data<T>();
+
+    for (size_t i = 0; i < grad_rows.size(); i++) {
+      size_t grad_merge_i = FindPos(merge_rows, grad_rows[i]);
+      for (int64_t j = 0; j < grad_width; j++) {
+        grad_merge_data[grad_merge_i * grad_width + j] +=
+            grad_data[i * grad_width + j];
+      }
+    }
+
+    // 2. m += g_m * g_m
+    std::unique_ptr<framework::SelectedRows> grad_square{
+        new framework::SelectedRows()};
+    grad_square->set_rows(grad_merge->rows());
+    grad_square->set_height(grad_merge->height());
+    grad_square->mutable_value()->mutable_data<T>(grad_merge->value().dims(),
+                                                  context.GetPlace());
+    auto gs =
+        framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
+    auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
+    gs.device(*context.GetEigenDevice<platform::CPUPlace>()) = gm * gm;
+
+    math::SelectedRowsAddToTensor<platform::CPUPlace, T> functor;
+    functor(context, *grad_square, moment);
+
+    // 3. update parameter
+    auto* lr = learning_rate.data<T>();
+    auto* param_data = param->data<T>();
+    auto* moment_data = moment->data<T>();
+
+    for (size_t i = 0; i < merge_rows.size(); i++) {
+      for (int64_t j = 0; j < grad_width; j++) {
+        param_data[merge_rows[i] * grad_width + j] -=
+            lr[0] * grad_merge_data[i * grad_width + j] /
+            (std::sqrt(moment_data[merge_rows[i] * grad_width + j]) + epsilon);
+      }
+    }
+  }
+};
+
+template struct SparseAdagradFunctor<platform::CPUPlace, float>;
+template struct SparseAdagradFunctor<platform::CPUPlace, double>;
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(adagrad, ops::AdagradOp, ops::AdagradOpMaker);
-REGISTER_OP_CPU_KERNEL(adagrad,
-                       ops::AdagradOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    adagrad, ops::AdagradOpKernel<paddle::platform::CPUPlace, float>,
+    ops::AdagradOpKernel<paddle::platform::CPUPlace, double>);
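
The CPU functor above works in three steps: merge duplicate sparse rows into a compact gradient, accumulate the squared merged gradient into the moment, then apply the scaled update only to the rows that were touched. A minimal standalone sketch of the same flow on plain std::vector storage; names are hypothetical and no Paddle types are involved:

#include <cmath>
#include <cstdint>
#include <map>
#include <vector>

// Sketch of the three-step sparse Adagrad update, assuming row-major
// (height x width) param/moment buffers and a compact gradient whose i-th
// slice of `width` values belongs to grad_rows[i] (rows may repeat).
void SparseAdagradStep(const std::vector<int64_t>& grad_rows,
                       const std::vector<double>& grad,
                       std::vector<double>* param, std::vector<double>* moment,
                       double lr, double epsilon, int64_t width) {
  // 1. g_m.rows = set(g.rows): sum duplicate rows into one entry each.
  std::map<int64_t, std::vector<double>> merged;
  for (size_t i = 0; i < grad_rows.size(); ++i) {
    auto& row = merged[grad_rows[i]];
    row.resize(width, 0.0);
    for (int64_t j = 0; j < width; ++j) row[j] += grad[i * width + j];
  }
  for (const auto& kv : merged) {
    const int64_t r = kv.first;
    for (int64_t j = 0; j < width; ++j) {
      const double g = kv.second[j];
      // 2. m += g_m * g_m, only for rows that received a gradient.
      (*moment)[r * width + j] += g * g;
      // 3. p -= lr * g_m / (sqrt(m) + epsilon).
      (*param)[r * width + j] -=
          lr * g / (std::sqrt((*moment)[r * width + j]) + epsilon);
    }
  }
}

One design note: FindPos in the diff does a linear scan over merge_rows for every gradient row, so the merge is O(rows * unique_rows); a map as in this sketch, or a binary search over the sorted merge_rows, would bring that down.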

paddle/operators/adagrad_op.cu

Lines changed: 133 additions & 2 deletions

@@ -14,7 +14,138 @@
 
 #define EIGEN_USE_GPU
 #include "paddle/operators/adagrad_op.h"
+#include "paddle/operators/math/selected_rows_functor.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+
+namespace {
+
+template <typename T, int block_size>
+__global__ void MergeGradKernel(const T* grad, const int64_t* grad_rows,
+                                T* grad_merge, const int64_t* grad_merge_rows,
+                                size_t grad_merge_rows_size,
+                                int64_t row_numel) {
+  const int ty = blockIdx.y;
+  int tid = threadIdx.x;
+  __shared__ size_t grad_merge_idx;
+
+  if (tid == 0) {
+    for (size_t i = 0; i < grad_merge_rows_size; i++) {
+      if (grad_rows[ty] == grad_merge_rows[i]) {
+        grad_merge_idx = i;
+      }
+    }
+  }
+
+  __syncthreads();
+
+  grad += ty * row_numel;
+  grad_merge += grad_merge_idx * row_numel;
+  for (int index = tid; index < row_numel; index += block_size) {
+    paddle::platform::CudaAtomicAdd(grad_merge + index, grad[index]);
+  }
+}
+
+template <typename T, int block_size>
+__global__ void SparseAdagradFunctorKernel(const T* grad, const int64_t* rows,
+                                           const T* learning_rate, T* param,
+                                           T* moment, int64_t row_numel,
+                                           T epsilon) {
+  const int ty = blockIdx.y;
+  int tid = threadIdx.x;
+
+  grad += ty * row_numel;
+  param += rows[ty] * row_numel;
+  moment += rows[ty] * row_numel;
+
+  for (int index = tid; index < row_numel; index += block_size) {
+    // Since index in rows of SelectedRows can be duplicate, we have to use
+    // Atomic Operation to avoid concurrent write error.
+    paddle::platform::CudaAtomicAdd(param + index,
+                                    -1.0 * learning_rate[0] * grad[index] /
+                                        (sqrt(moment[index]) + epsilon));
+  }
+}
+}  // namespace
+
+template <typename T>
+struct SparseAdagradFunctor<platform::GPUPlace, T> {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::SelectedRows& grad,
+                  const framework::Tensor& learning_rate, T epsilon,
+                  framework::Tensor* moment, framework::Tensor* param) {
+    // 1. g_m.rows = set(g.rows)
+    auto grad_rows = grad.rows();
+    std::set<int64_t> row_set(grad_rows.begin(), grad_rows.end());
+    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
+
+    auto grad_width = grad.value().dims()[1];
+    std::unique_ptr<framework::SelectedRows> grad_merge{
+        new framework::SelectedRows()};
+    grad_merge->set_rows(merge_rows);
+    grad_merge->set_height(grad.height());
+    grad_merge->mutable_value()->mutable_data<T>(
+        framework::make_ddim(
+            {static_cast<int64_t>(merge_rows.size()), grad_width}),
+        context.GetPlace());
+
+    math::SetConstant<platform::GPUPlace, T> constant_functor;
+    constant_functor(context, grad_merge->mutable_value(), 0.0);
+
+    auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
+    auto* grad_data = grad.value().data<T>();
+
+    const int block_size = 256;
+    dim3 threads(block_size, 1);
+    dim3 grid1(1, grad_rows.size());
+
+    MergeGradKernel<
+        T, 256><<<grid1, threads, 0,
+                  reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                      .stream()>>>(grad_data, grad.rows().data(),
+                                   grad_merge_data, grad_merge->rows().data(),
+                                   grad_merge->rows().size(), grad_width);
+
+    // 2. m += g_m * g_m
+    std::unique_ptr<framework::SelectedRows> grad_square{
+        new framework::SelectedRows()};
+    grad_square->set_rows(grad_merge->rows());
+    grad_square->set_height(grad_merge->height());
+    grad_square->mutable_value()->mutable_data<T>(grad_merge->value().dims(),
+                                                  context.GetPlace());
+    auto gs =
+        framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
+    auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
+    gs.device(*context.GetEigenDevice<platform::GPUPlace>()) = gm * gm;
+
+    math::SelectedRowsAddToTensor<platform::GPUPlace, T> functor;
+    functor(context, *grad_square, moment);
+
+    // 3. update parameter
+    auto* lr = learning_rate.data<T>();
+    auto* param_data = param->data<T>();
+    auto* moment_data = moment->data<T>();
+
+    dim3 grid2(1, merge_rows.size());
+    SparseAdagradFunctorKernel<
+        T, 256><<<grid2, threads, 0,
+                  reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                      .stream()>>>(grad_merge_data, grad_merge->rows().data(),
+                                   lr, param_data,
+                                   moment_data, grad_width, epsilon);
+  }
+};
+
+template struct SparseAdagradFunctor<platform::GPUPlace, float>;
+template struct SparseAdagradFunctor<platform::GPUPlace, double>;
+
+}  // namespace operators
+}  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(adagrad,
-                       ops::AdagradOpKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    adagrad, ops::AdagradOpKernel<paddle::platform::GPUPlace, float>,
+    ops::AdagradOpKernel<paddle::platform::GPUPlace, double>);
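
Both kernels above share one launch layout: the grid is (1, num_rows) so blockIdx.y selects a sparse row, and block_size threads stride across that row's columns, with CudaAtomicAdd guarding against duplicate destination rows. A stand-alone sketch of the pattern under those assumptions (RowScatterAdd is hypothetical, not a Paddle symbol; plain atomicAdd is shown where the diff uses Paddle's CudaAtomicAdd wrapper):

#include <cstdint>

// One block per compact source row (blockIdx.y); threads stride the columns.
// Destination rows may repeat across blocks, so the add must be atomic.
template <typename T, int block_size>
__global__ void RowScatterAdd(const T* src, const int64_t* rows, T* dst,
                              int64_t row_numel) {
  const int ty = blockIdx.y;    // which compact source row this block handles
  src += ty * row_numel;        // start of this block's source row
  dst += rows[ty] * row_numel;  // scattered destination row
  for (int i = threadIdx.x; i < row_numel; i += block_size) {
    atomicAdd(dst + i, src[i]);
  }
}

// Hypothetical launch, mirroring the diff:
//   dim3 threads(256, 1), grid(1, num_rows);
//   RowScatterAdd<float, 256><<<grid, threads>>>(src, rows, dst, width);

Note also that MergeGradKernel resolves the merged row index with a single-thread linear search per block before __syncthreads(); a precomputed row mapping would avoid that serialization.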

paddle/operators/adagrad_op.h

Lines changed: 45 additions & 21 deletions

@@ -19,35 +19,59 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+template <typename Place, typename T>
+struct SparseAdagradFunctor {
+  void operator()(const platform::DeviceContext& context,
+                  const framework::SelectedRows& grad,
+                  const framework::Tensor& learning_rate, T epsilon,
+                  framework::Tensor* moment, framework::Tensor* param);
+};
+
 template <typename Place, typename T>
 class AdagradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
-    auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
+    auto* param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
+    auto* moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
 
     param_out_tensor->mutable_data<T>(ctx.GetPlace());
     moment_out_tensor->mutable_data<T>(ctx.GetPlace());
 
-    float epsilon = ctx.Attr<float>("epsilon");
-
-    auto param = framework::EigenVector<T>::Flatten(
-        *ctx.Input<framework::Tensor>("Param"));
-    auto grad = framework::EigenVector<T>::Flatten(
-        *ctx.Input<framework::Tensor>("Grad"));
-    auto moment = framework::EigenVector<T>::Flatten(
-        *ctx.Input<framework::Tensor>("Moment"));
-    auto lr = framework::EigenVector<T>::Flatten(
-        *ctx.Input<framework::Tensor>("LearningRate"));
-
-    auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
-    auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
-    auto place = ctx.GetEigenDevice<Place>();
-
-    moment_out.device(place) = moment + grad * grad;
-    Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
-    param_out.device(place) =
-        param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
+    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
+
+    auto* grad_var = ctx.InputVar("Grad");
+    if (grad_var->IsType<framework::LoDTensor>()) {
+      auto param = framework::EigenVector<T>::Flatten(
+          *ctx.Input<framework::Tensor>("Param"));
+      auto grad = framework::EigenVector<T>::Flatten(
+          *ctx.Input<framework::Tensor>("Grad"));
+      auto moment = framework::EigenVector<T>::Flatten(
+          *ctx.Input<framework::Tensor>("Moment"));
+      auto lr = framework::EigenVector<T>::Flatten(
+          *ctx.Input<framework::Tensor>("LearningRate"));
+
+      auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
+      auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
+      auto place = ctx.GetEigenDevice<Place>();
+
+      moment_out.device(place) = moment + grad * grad;
+      Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
+      param_out.device(place) =
+          param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
+    } else if (grad_var->IsType<framework::SelectedRows>()) {
+      auto* param_tensor = ctx.Input<framework::Tensor>("Param");
+      PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor);
+
+      auto* moment_tensor = ctx.Input<framework::Tensor>("Moment");
+      PADDLE_ENFORCE_EQ(moment_tensor, moment_out_tensor);
+
+      SparseAdagradFunctor<Place, T> functor;
+      functor(ctx.device_context(), *ctx.Input<framework::SelectedRows>("Grad"),
+              *ctx.Input<framework::Tensor>("LearningRate"), epsilon,
+              moment_out_tensor, param_out_tensor);
+    } else {
+      PADDLE_THROW("Unsupported Variable Type of Grad");
+    }
   }
 };
 
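
For reference, the dense branch's Eigen expressions implement the standard Adagrad recurrence, with epsilon added for numerical stability as the operator's DOC string notes:

\[
m_t = m_{t-1} + g_t \odot g_t,
\qquad
p_t = p_{t-1} - \frac{\eta \, g_t}{\sqrt{m_t} + \epsilon}
\]

where g_t is the gradient, m_t the accumulated squared-gradient moment, p_t the parameter, \eta the learning rate, and all operations are elementwise. The SelectedRows branch applies the same recurrence restricted to the rows present in the sparse gradient, writing through to the Param and Moment buffers in place; that is why the PADDLE_ENFORCE_EQ checks insist that ParamOut and MomentOut alias their corresponding inputs.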

paddle/operators/sgd_op.cu

Lines changed: 8 additions & 7 deletions

@@ -20,11 +20,11 @@ namespace paddle {
 namespace operators {
 
 namespace {
-template <typename T>
+template <typename T, int block_size>
 __global__ void SparseSGDFunctorKernel(const T* selected_rows,
                                        const int64_t* rows,
                                        const T* learning_rate, T* tensor_out,
-                                       int64_t row_numel, int block_size) {
+                                       int64_t row_numel) {
   const int ty = blockIdx.y;
   int tid = threadIdx.x;
 
@@ -59,14 +59,15 @@ struct SparseSGDFunctor<platform::GPUPlace, T> {
     auto* in_data = in_value.data<T>();
     auto* out_data = output->data<T>();
 
-    int block_size = 256;
+    const int block_size = 256;
     dim3 threads(block_size, 1);
     dim3 grid(1, in_rows.size());
     SparseSGDFunctorKernel<
-        T><<<grid, threads, 0,
-             reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                 .stream()>>>(in_data, in_rows.data(), learning_rate.data<T>(),
-                              out_data, in_row_numel, block_size);
+        T, 256><<<grid, threads, 0,
+                  reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                      .stream()>>>(in_data, in_rows.data(),
+                                   learning_rate.data<T>(), out_data,
+                                   in_row_numel);
   }
 };
 
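
The sgd_op.cu change mirrors the new adagrad kernels: block_size moves from a runtime kernel argument to a template parameter. A plausible motivation, assumed here rather than stated in the commit, is that a compile-time stride lets the compiler unroll the column loop, besides keeping the sparse-optimizer kernel signatures consistent. A contrast sketch with hypothetical kernels:

// Runtime stride: the compiler cannot assume anything about `stride`.
template <typename T>
__global__ void ScaleRowRuntime(T* x, int n, T s, int stride) {
  for (int i = threadIdx.x; i < n; i += stride) x[i] *= s;
}

// Compile-time stride: the step is a known constant, so the loop can unroll.
template <typename T, int block_size>
__global__ void ScaleRowCompileTime(T* x, int n, T s) {
  for (int i = threadIdx.x; i < n; i += block_size) x[i] *= s;
}

One caveat visible in the diff itself: the call sites instantiate the kernels with a literal 256 while also defining const int block_size = 256 for the thread dimension, so the two constants must be kept in sync by hand; passing block_size as the template argument would remove the duplication.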
