
Commit 6330c14

fix sparse rmsprop

1 parent cb14b0d commit 6330c14

File tree

3 files changed: +276 -57 lines changed

paddle/fluid/operators/adam_op.h
paddle/fluid/operators/math/algorithm.h
paddle/fluid/operators/rmsprop_op.h

paddle/fluid/operators/adam_op.h

Lines changed: 3 additions & 16 deletions
@@ -18,6 +18,7 @@ limitations under the License. */
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
+#include "paddle/fluid/operators/math/algorithm.h"
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/for_range.h"
 
@@ -199,23 +200,9 @@ struct SparseAdamFunctor {
         row_numel_(row_numel),
         row_count_(row_count) {}
 
-  inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
-    int64_t beg = 0, end = row_count_ - 1;
-    while (beg <= end) {
-      auto mid = ((beg + end) >> 1);
-      if (rows_[mid] == row)
-        return mid;
-      else if (rows_[mid] < row)
-        beg = mid + 1;
-      else
-        end = mid - 1;
-    }
-    return -1;
-  }
-
   inline HOSTDEVICE void operator()(size_t i) const {
-    int64_t row = i / row_numel_;
-    auto row_idx = BinarySearchInRows(row);
+    auto row_idx =
+        math::BinarySearch<int64_t>(rows_, row_count_, i / row_numel_);
     T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
 
     // The following code is the same as dense
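
The hunk above replaces SparseAdamFunctor's private BinarySearchInRows with the shared math::BinarySearch. As a standalone illustration (plain C++, not Paddle code; all names below are hypothetical), the lookup pattern is: split the flat element index into a row and a column, binary-search the row id in the sorted rows list of the merged sparse gradient, and read the stored value if the row exists, otherwise treat the gradient as zero.

#include <cstdint>
#include <cstdio>

// Sorted-array lookup mirroring the semantics of math::BinarySearch:
// returns the position of `val` in `x[0..num)`, or -1 if it is absent.
static int64_t FindRow(const int64_t *x, int64_t num, int64_t val) {
  int64_t beg = 0, end = num - 1;
  while (beg <= end) {
    int64_t mid = (beg + end) >> 1;
    if (x[mid] == val) return mid;
    if (x[mid] < val)
      beg = mid + 1;
    else
      end = mid - 1;
  }
  return -1;
}

int main() {
  const int64_t rows[] = {1, 4, 7};   // rows that actually carry a gradient
  const double grad[] = {0.1, 0.2,    // row 1
                         0.3, 0.4,    // row 4
                         0.5, 0.6};   // row 7
  const int64_t row_numel = 2;        // elements per row
  const int64_t row_count = 3;

  int64_t i = 9;  // flat index into the dense parameter: row 4, column 1
  int64_t row_idx = FindRow(rows, row_count, i / row_numel);
  double g = row_idx >= 0 ? grad[row_idx * row_numel + i % row_numel] : 0.0;
  std::printf("g = %.1f\n", g);  // prints g = 0.4
  return 0;
}
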

paddle/fluid/operators/math/algorithm.h

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <cstdint>  // for int64_t
+#include <numeric>
+
+#include "paddle/fluid/platform/hostdevice.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+template <typename T>
+HOSTDEVICE inline int64_t BinarySearch(const T *x, int64_t num, const T &val) {
+  int64_t beg = 0, end = num - 1;
+  while (beg <= end) {
+    auto mid = ((beg + end) >> 1);
+    if (x[mid] == val)
+      return mid;
+    else if (x[mid] < val)
+      beg = mid + 1;
+    else
+      end = mid - 1;
+  }
+  return -1;
+}
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
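
A minimal host-side usage sketch of the new helper, assuming it is compiled inside the Paddle source tree (so the include path resolves and HOSTDEVICE expands to nothing on the host): the input array must be sorted ascending, and -1 signals that the value is absent.

#include <cstdint>
#include <cstdio>

#include "paddle/fluid/operators/math/algorithm.h"

int main() {
  const int64_t rows[] = {2, 5, 9, 11};  // must be sorted ascending
  auto hit = paddle::operators::math::BinarySearch<int64_t>(rows, 4, 9);
  auto miss = paddle::operators::math::BinarySearch<int64_t>(rows, 4, 3);
  std::printf("%lld %lld\n", static_cast<long long>(hit),
              static_cast<long long>(miss));  // prints "2 -1"
  return 0;
}
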

paddle/fluid/operators/rmsprop_op.h

Lines changed: 229 additions & 41 deletions
@@ -13,66 +13,254 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <math.h>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/algorithm.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
+#include "paddle/fluid/platform/for_range.h"
 
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
 
+template <typename T>
+struct DenseRmspropGradFunctor {
+  inline explicit DenseRmspropGradFunctor(const T *grad) : grad_(grad) {}
+
+  HOSTDEVICE inline T operator()(int64_t idx) const { return grad_[idx]; }
+
+  const T *grad_;
+};
+
+template <typename T>
+struct SparseRmspropGradFunctor {
+  inline SparseRmspropGradFunctor(const T *grad, const int64_t *rows,
+                                  int64_t row_numel, int64_t row_count)
+      : grad_(grad),
+        rows_(rows),
+        row_numel_(row_numel),
+        row_count_(row_count) {}
+
+  HOSTDEVICE inline T operator()(int64_t idx) const {
+    auto row_idx = math::BinarySearch(rows_, row_count_, idx / row_numel_);
+    return row_idx >= 0 ? grad_[row_idx * row_numel_ + idx % row_numel_] : 0;
+  }
+
+  const T *grad_;
+  const int64_t *rows_;
+  int64_t row_numel_;
+  int64_t row_count_;
+};
+
+template <typename T, typename GradFunctor>
+struct UncenteredRmspropFunctor {
+  UncenteredRmspropFunctor(T *param, T *ms, T *mom, const T *lr, T rho,
+                           T epsilon, T momentum,
+                           const GradFunctor &grad_functor)
+      : param_(param),
+        ms_(ms),
+        mom_(mom),
+        lr_(lr),
+        rho_(rho),
+        epsilon_(epsilon),
+        momentum_(momentum),
+        grad_functor_(grad_functor) {}
+
+  HOSTDEVICE inline void operator()(int64_t idx) const {
+    T g = grad_functor_(idx);
+    T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g;
+    T mom_out = momentum_ * mom_[idx] + lr_[0] * g / sqrt(ms_out + epsilon_);
+    param_[idx] -= mom_out;
+    ms_[idx] = ms_out;
+    mom_[idx] = mom_out;
+  }
+
+  T *param_;
+  T *ms_;
+  T *mom_;
+  const T *lr_;
+  T rho_;
+  T epsilon_;
+  T momentum_;
+  GradFunctor grad_functor_;
+};
+
+template <typename T, typename GradFunctor>
+struct CenteredRmspropFunctor {
+  CenteredRmspropFunctor(T *param, T *ms, T *mom, T *mean_grad, const T *lr,
+                         T rho, T epsilon, T momentum,
+                         const GradFunctor &grad_functor)
+      : param_(param),
+        ms_(ms),
+        mom_(mom),
+        mean_grad_(mean_grad),
+        lr_(lr),
+        rho_(rho),
+        epsilon_(epsilon),
+        momentum_(momentum),
+        grad_functor_(grad_functor) {}
+
+  HOSTDEVICE inline void operator()(int64_t idx) const {
+    T g = grad_functor_(idx);
+    T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g;
+    T mg_out = rho_ * mean_grad_[idx] + (1 - rho_) * g;
+    T mom_out = momentum_ * mom_[idx] +
+                lr_[0] * g / sqrt(ms_out - mg_out * mg_out + epsilon_);
+    param_[idx] -= mom_out;
+    ms_[idx] = ms_out;
+    mom_[idx] = mom_out;
+    mean_grad_[idx] = mg_out;
+  }
+
+  T *param_;
+  T *ms_;
+  T *mom_;
+  T *mean_grad_;
+  const T *lr_;
+  T rho_;
+  T epsilon_;
+  T momentum_;
+  GradFunctor grad_functor_;
+};
+
 template <typename DeviceContext, typename T>
 class RmspropOpKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* param_out = ctx.Output<Tensor>("ParamOut");
-    auto* moment_out = ctx.Output<Tensor>("MomentOut");
-    auto* mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    using Tensor = framework::LoDTensor;
+    auto *grad_var = ctx.InputVar("Grad");
+    auto *param_out = ctx.Output<Tensor>("ParamOut");
+    auto *moment_out = ctx.Output<Tensor>("MomentOut");
+    auto *mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
 
-    auto grad = ctx.Input<Tensor>("Grad");
+    auto epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
+    auto rho = static_cast<T>(ctx.Attr<float>("decay"));
+    auto momentum = static_cast<T>(ctx.Attr<float>("momentum"));
+    bool centered = ctx.Attr<bool>("centered");
 
-    param_out->mutable_data<T>(ctx.GetPlace());
-    moment_out->mutable_data<T>(ctx.GetPlace());
-    mean_square_out->mutable_data<T>(ctx.GetPlace());
+    auto &p_tensor = *ctx.Input<Tensor>("Param");
+    auto &ms_tensor = *ctx.Input<Tensor>("MeanSquare");
+    auto &lr_tensor = *ctx.Input<Tensor>("LearningRate");
+    auto &mom_tensor = *ctx.Input<Tensor>("Moment");
 
-    float epsilon = ctx.Attr<float>("epsilon");
-    float rho = ctx.Attr<float>("decay");
-    float momentum = ctx.Attr<float>("momentum");
-    bool centered = ctx.Attr<bool>("centered");
+    PADDLE_ENFORCE_EQ(&p_tensor, param_out,
+                      "Param and ParamOut must be the same Tensor");
+    PADDLE_ENFORCE_EQ(&mom_tensor, moment_out,
+                      "Moment and MomentOut must be the same Tensor");
+    PADDLE_ENFORCE_EQ(&ms_tensor, mean_square_out,
+                      "MeanSquare and MeanSquareOut must be the same Tensor");
+
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+    size_t limit = static_cast<size_t>(ms_tensor.numel());
+
+    if (grad_var->IsType<Tensor>()) {
+      auto &grad_tensor = grad_var->Get<Tensor>();
+
+      if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value) {
+        auto &place =
+            *ctx.template device_context<DeviceContext>().eigen_device();
+        auto lr_value = lr_tensor.data<T>()[0];
+
+        auto p = EigenVector<T>::Flatten(p_tensor);
+        auto ms = EigenVector<T>::Flatten(ms_tensor);
+        auto g = EigenVector<T>::Flatten(grad_tensor);
+        auto mom = EigenVector<T>::Flatten(mom_tensor);
+
+        auto p_out = EigenVector<T>::Flatten(*param_out);
+        auto mom_out = EigenVector<T>::Flatten(*moment_out);
+        auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
+
+        ms_out.device(place) = rho * ms + (1 - rho) * g * g;
+        if (centered) {
+          auto &mg_tensor = *ctx.Input<Tensor>("MeanGrad");
+          auto mg = EigenVector<T>::Flatten(mg_tensor);
+          auto *mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
+          PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
+                         "MeanGrad and MeanGradOut must be the same Tensor");
+          auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
+
+          mg_out.device(place) = rho * mg + (1 - rho) * g;
+          mom_out.device(place) =
+              momentum * mom +
+              lr_value * g / (ms_out - mg_out.square() + epsilon).sqrt();
+        } else {
+          mom_out.device(place) =
+              momentum * mom + lr_value * g / (ms_out + epsilon).sqrt();
+        }
+        p_out.device(place) = p - mom_out;
+      } else {
+        DenseRmspropGradFunctor<T> grad_func(grad_tensor.data<T>());
+        platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
+        if (centered) {
+          auto &mg_tensor = *ctx.Input<Tensor>("MeanGrad");
+          auto *mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
+          PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
+                         "MeanGrad and MeanGradOut must be the same Tensor");
+          for_range(CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
+              param_out->mutable_data<T>(ctx.GetPlace()),
+              mean_square_out->mutable_data<T>(ctx.GetPlace()),
+              moment_out->mutable_data<T>(ctx.GetPlace()),
+              mean_grad_out->mutable_data<T>(ctx.GetPlace()),
+              lr_tensor.data<T>(), rho, epsilon, momentum, grad_func));
+        } else {
+          for_range(UncenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
+              param_out->mutable_data<T>(ctx.GetPlace()),
+              mean_square_out->mutable_data<T>(ctx.GetPlace()),
+              moment_out->mutable_data<T>(ctx.GetPlace()), lr_tensor.data<T>(),
+              rho, epsilon, momentum, grad_func));
+        }
+      }
+    } else if (grad_var->IsType<framework::SelectedRows>()) {
+      auto &grad = grad_var->Get<framework::SelectedRows>();
+      auto *merged_grad = const_cast<framework::Scope &>(ctx.scope())
+                              .Var()
+                              ->GetMutable<framework::SelectedRows>();
+
+      math::scatter::MergeAdd<DeviceContext, T> merge_func;
+      merge_func(dev_ctx, grad, merged_grad);
+
+      platform::ForRange<DeviceContext> for_range(dev_ctx, limit);
+      const int64_t *rows;
+#ifdef PADDLE_WITH_CUDA
+      if (platform::is_gpu_place(ctx.GetPlace())) {
+        rows = merged_grad->rows().CUDAData(ctx.GetPlace());
+      } else {
+#endif
+        rows = merged_grad->rows().data();
+#ifdef PADDLE_WITH_CUDA
+      }
+#endif
+      auto &merged_tensor = merged_grad->value();
+      int64_t row_count = merged_grad->rows().size();
+      int64_t row_numel = merged_tensor.numel() / row_count;
+      SparseRmspropGradFunctor<T> grad_func(merged_tensor.data<T>(), rows,
+                                            row_numel, row_count);
 
-    auto p = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Param"));
-    auto ms = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanSquare"));
-    auto lr = EigenVector<T>::Flatten(*ctx.Input<Tensor>("LearningRate"));
-    auto g = EigenVector<T>::Flatten(*grad);
-    auto mom = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Moment"));
-
-    auto p_out = EigenVector<T>::Flatten(*param_out);
-    auto mom_out = EigenVector<T>::Flatten(*moment_out);
-    auto ms_out = EigenVector<T>::Flatten(*mean_square_out);
-    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
-
-    Eigen::DSizes<int, 1> grad_dsize(static_cast<int>(grad->numel()));
-
-    ms_out.device(place) = rho * ms + (1 - rho) * g * g;
-    if (centered) {
-      auto mg = EigenVector<T>::Flatten(*ctx.Input<Tensor>("MeanGrad"));
-      auto* mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
-      mean_grad_out->mutable_data<T>(ctx.GetPlace());
-      auto mg_out = EigenVector<T>::Flatten(*mean_grad_out);
-
-      mg_out.device(place) = rho * mg + (1 - rho) * g;
-      mom_out.device(place) = momentum * mom +
-                              lr.broadcast(grad_dsize) * g /
-                                  (ms_out - mg_out.square() + epsilon).sqrt();
+      if (centered) {
+        auto &mg_tensor = *ctx.Input<Tensor>("MeanGrad");
+        auto *mean_grad_out = ctx.Output<Tensor>("MeanGradOut");
+        PADDLE_ENFORCE(&mg_tensor, mean_grad_out,
+                       "MeanGrad and MeanGradOut must be the same Tensor");
+        for_range(CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
+            param_out->mutable_data<T>(ctx.GetPlace()),
+            mean_square_out->mutable_data<T>(ctx.GetPlace()),
+            moment_out->mutable_data<T>(ctx.GetPlace()),
+            mean_grad_out->mutable_data<T>(ctx.GetPlace()), lr_tensor.data<T>(),
+            rho, epsilon, momentum, grad_func));
+      } else {
+        for_range(UncenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
+            param_out->mutable_data<T>(ctx.GetPlace()),
+            mean_square_out->mutable_data<T>(ctx.GetPlace()),
+            moment_out->mutable_data<T>(ctx.GetPlace()), lr_tensor.data<T>(),
+            rho, epsilon, momentum, grad_func));
+      }
     } else {
-      mom_out.device(place) =
-          momentum * mom +
-          lr.broadcast(grad_dsize) * g / (ms_out + epsilon).sqrt();
+      PADDLE_THROW("RMSProp only supports LoDTensor or SelectedRows gradient");
     }
-    p_out.device(place) = p - mom_out;
   }
 };
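
For reference, a standalone sketch (plain C++, not Paddle code; names are hypothetical) of the per-element update that UncenteredRmspropFunctor and CenteredRmspropFunctor apply, with rho taken from the "decay" attribute and lr read from the LearningRate tensor:

#include <cmath>
#include <cstdio>

// One RMSProp step for a single element, mirroring the functors above.
// Uncentered: ms = rho*ms + (1-rho)*g^2; mom = momentum*mom + lr*g/sqrt(ms+eps)
// Centered adds mg = rho*mg + (1-rho)*g and divides by sqrt(ms - mg^2 + eps).
struct RmspropState {
  double param, ms, mom, mg;
};

static void Step(RmspropState *s, double g, double lr, double rho,
                 double epsilon, double momentum, bool centered) {
  s->ms = rho * s->ms + (1 - rho) * g * g;
  double denom = s->ms + epsilon;
  if (centered) {
    s->mg = rho * s->mg + (1 - rho) * g;
    denom = s->ms - s->mg * s->mg + epsilon;
  }
  s->mom = momentum * s->mom + lr * g / std::sqrt(denom);
  s->param -= s->mom;
}

int main() {
  RmspropState s{1.0, 0.0, 0.0, 0.0};
  Step(&s, /*g=*/0.5, /*lr=*/0.1, /*rho=*/0.9, /*epsilon=*/1e-6,
       /*momentum=*/0.9, /*centered=*/false);
  std::printf("param=%.6f ms=%.6f mom=%.6f\n", s.param, s.ms, s.mom);
  return 0;
}
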
