@@ -13,66 +13,254 @@ See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
15
#pragma once
16
+ #include < math.h>
16
17
#include " paddle/fluid/framework/eigen.h"
17
18
#include " paddle/fluid/framework/op_registry.h"
19
+ #include " paddle/fluid/operators/math/algorithm.h"
20
+ #include " paddle/fluid/operators/math/selected_rows_functor.h"
21
+ #include " paddle/fluid/platform/for_range.h"
18
22
19
23
namespace paddle {
20
24
namespace operators {
21
25
22
- using Tensor = framework::Tensor;
23
26
template <typename T, int MajorType = Eigen::RowMajor,
24
27
typename IndexType = Eigen::DenseIndex>
25
28
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
26
29
30
+ template <typename T>
31
+ struct DenseRmspropGradFunctor {
32
+ inline explicit DenseRmspropGradFunctor (const T *grad) : grad_(grad) {}
33
+
34
+ HOSTDEVICE inline T operator ()(int64_t idx) const { return grad_[idx]; }
35
+
36
+ const T *grad_;
37
+ };
38
+
39
+ template <typename T>
40
+ struct SparseRmspropGradFunctor {
41
+ inline SparseRmspropGradFunctor (const T *grad, const int64_t *rows,
42
+ int64_t row_numel, int64_t row_count)
43
+ : grad_(grad),
44
+ rows_(rows),
45
+ row_numel_(row_numel),
46
+ row_count_(row_count) {}
47
+
48
+ HOSTDEVICE inline T operator ()(int64_t idx) const {
49
+ auto row_idx = math::BinarySearch (rows_, row_count_, idx / row_numel_);
50
+ return row_idx >= 0 ? grad_[row_idx * row_numel_ + idx % row_numel_] : 0 ;
51
+ }
52
+
53
+ const T *grad_;
54
+ const int64_t *rows_;
55
+ int64_t row_numel_;
56
+ int64_t row_count_;
57
+ };
58
+
59
+ template <typename T, typename GradFunctor>
60
+ struct UncenteredRmspropFunctor {
61
+ UncenteredRmspropFunctor (T *param, T *ms, T *mom, const T *lr, T rho,
62
+ T epsilon, T momentum,
63
+ const GradFunctor &grad_functor)
64
+ : param_(param),
65
+ ms_ (ms),
66
+ mom_(mom),
67
+ lr_(lr),
68
+ rho_(rho),
69
+ epsilon_(epsilon),
70
+ momentum_(momentum),
71
+ grad_functor_(grad_functor) {}
72
+
73
+ HOSTDEVICE inline void operator ()(int64_t idx) const {
74
+ T g = grad_functor_ (idx);
75
+ T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g;
76
+ T mom_out = momentum_ * mom_[idx] + lr_[0 ] * g / sqrt (ms_out + epsilon_);
77
+ param_[idx] -= mom_out;
78
+ ms_[idx] = ms_out;
79
+ mom_[idx] = mom_out;
80
+ }
81
+
82
+ T *param_;
83
+ T *ms_;
84
+ T *mom_;
85
+ const T *lr_;
86
+ T rho_;
87
+ T epsilon_;
88
+ T momentum_;
89
+ GradFunctor grad_functor_;
90
+ };
91
+
92
+ template <typename T, typename GradFunctor>
93
+ struct CenteredRmspropFunctor {
94
+ CenteredRmspropFunctor (T *param, T *ms, T *mom, T *mean_grad, const T *lr,
95
+ T rho, T epsilon, T momentum,
96
+ const GradFunctor &grad_functor)
97
+ : param_(param),
98
+ ms_ (ms),
99
+ mom_(mom),
100
+ mean_grad_(mean_grad),
101
+ lr_(lr),
102
+ rho_(rho),
103
+ epsilon_(epsilon),
104
+ momentum_(momentum),
105
+ grad_functor_(grad_functor) {}
106
+
107
+ HOSTDEVICE inline void operator ()(int64_t idx) const {
108
+ T g = grad_functor_ (idx);
109
+ T ms_out = rho_ * ms_[idx] + (1 - rho_) * g * g;
110
+ T mg_out = rho_ * mean_grad_[idx] + (1 - rho_) * g;
111
+ T mom_out = momentum_ * mom_[idx] +
112
+ lr_[0 ] * g / sqrt (ms_out - mg_out * mg_out + epsilon_);
113
+ param_[idx] -= mom_out;
114
+ ms_[idx] = ms_out;
115
+ mom_[idx] = mom_out;
116
+ mean_grad_[idx] = mg_out;
117
+ }
118
+
119
+ T *param_;
120
+ T *ms_;
121
+ T *mom_;
122
+ T *mean_grad_;
123
+ const T *lr_;
124
+ T rho_;
125
+ T epsilon_;
126
+ T momentum_;
127
+ GradFunctor grad_functor_;
128
+ };
129
+
27
130
template <typename DeviceContext, typename T>
28
131
class RmspropOpKernel : public framework ::OpKernel<T> {
29
132
public:
30
- void Compute (const framework::ExecutionContext& ctx) const override {
31
- auto * param_out = ctx.Output <Tensor>(" ParamOut" );
32
- auto * moment_out = ctx.Output <Tensor>(" MomentOut" );
33
- auto * mean_square_out = ctx.Output <Tensor>(" MeanSquareOut" );
133
+ void Compute (const framework::ExecutionContext &ctx) const override {
134
+ using Tensor = framework::LoDTensor;
135
+ auto *grad_var = ctx.InputVar (" Grad" );
136
+ auto *param_out = ctx.Output <Tensor>(" ParamOut" );
137
+ auto *moment_out = ctx.Output <Tensor>(" MomentOut" );
138
+ auto *mean_square_out = ctx.Output <Tensor>(" MeanSquareOut" );
34
139
35
- auto grad = ctx.Input <Tensor>(" Grad" );
140
+ auto epsilon = static_cast <T>(ctx.Attr <float >(" epsilon" ));
141
+ auto rho = static_cast <T>(ctx.Attr <float >(" decay" ));
142
+ auto momentum = static_cast <T>(ctx.Attr <float >(" momentum" ));
143
+ bool centered = ctx.Attr <bool >(" centered" );
36
144
37
- param_out->mutable_data <T>(ctx.GetPlace ());
38
- moment_out->mutable_data <T>(ctx.GetPlace ());
39
- mean_square_out->mutable_data <T>(ctx.GetPlace ());
145
+ auto &p_tensor = *ctx.Input <Tensor>(" Param" );
146
+ auto &ms_tensor = *ctx.Input <Tensor>(" MeanSquare" );
147
+ auto &lr_tensor = *ctx.Input <Tensor>(" LearningRate" );
148
+ auto &mom_tensor = *ctx.Input <Tensor>(" Moment" );
40
149
41
- float epsilon = ctx.Attr <float >(" epsilon" );
42
- float rho = ctx.Attr <float >(" decay" );
43
- float momentum = ctx.Attr <float >(" momentum" );
44
- bool centered = ctx.Attr <bool >(" centered" );
150
+ PADDLE_ENFORCE_EQ (&p_tensor, param_out,
151
+ " Param and ParamOut must be the same Tensor" );
152
+ PADDLE_ENFORCE_EQ (&mom_tensor, moment_out,
153
+ " Moment and MomentOut must be the same Tensor" );
154
+ PADDLE_ENFORCE_EQ (&ms_tensor, mean_square_out,
155
+ " MeanSquare and MeanSquareOut must be the same Tensor" );
156
+
157
+ auto &dev_ctx = ctx.template device_context <DeviceContext>();
158
+ size_t limit = static_cast <size_t >(ms_tensor.numel ());
159
+
160
+ if (grad_var->IsType <Tensor>()) {
161
+ auto &grad_tensor = grad_var->Get <Tensor>();
162
+
163
+ if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value) {
164
+ auto &place =
165
+ *ctx.template device_context <DeviceContext>().eigen_device ();
166
+ auto lr_value = lr_tensor.data <T>()[0 ];
167
+
168
+ auto p = EigenVector<T>::Flatten (p_tensor);
169
+ auto ms = EigenVector<T>::Flatten (ms_tensor);
170
+ auto g = EigenVector<T>::Flatten (grad_tensor);
171
+ auto mom = EigenVector<T>::Flatten (mom_tensor);
172
+
173
+ auto p_out = EigenVector<T>::Flatten (*param_out);
174
+ auto mom_out = EigenVector<T>::Flatten (*moment_out);
175
+ auto ms_out = EigenVector<T>::Flatten (*mean_square_out);
176
+
177
+ ms_out.device (place) = rho * ms + (1 - rho) * g * g;
178
+ if (centered) {
179
+ auto &mg_tensor = *ctx.Input <Tensor>(" MeanGrad" );
180
+ auto mg = EigenVector<T>::Flatten (mg_tensor);
181
+ auto *mean_grad_out = ctx.Output <Tensor>(" MeanGradOut" );
182
+ PADDLE_ENFORCE (&mg_tensor, mean_grad_out,
183
+ " MeanGrad and MeanGradOut must be the same Tensor" );
184
+ auto mg_out = EigenVector<T>::Flatten (*mean_grad_out);
185
+
186
+ mg_out.device (place) = rho * mg + (1 - rho) * g;
187
+ mom_out.device (place) =
188
+ momentum * mom +
189
+ lr_value * g / (ms_out - mg_out.square () + epsilon).sqrt ();
190
+ } else {
191
+ mom_out.device (place) =
192
+ momentum * mom + lr_value * g / (ms_out + epsilon).sqrt ();
193
+ }
194
+ p_out.device (place) = p - mom_out;
195
+ } else {
196
+ DenseRmspropGradFunctor<T> grad_func (grad_tensor.data <T>());
197
+ platform::ForRange<DeviceContext> for_range (dev_ctx, limit);
198
+ if (centered) {
199
+ auto &mg_tensor = *ctx.Input <Tensor>(" MeanGrad" );
200
+ auto *mean_grad_out = ctx.Output <Tensor>(" MeanGradOut" );
201
+ PADDLE_ENFORCE (&mg_tensor, mean_grad_out,
202
+ " MeanGrad and MeanGradOut must be the same Tensor" );
203
+ for_range (CenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
204
+ param_out->mutable_data <T>(ctx.GetPlace ()),
205
+ mean_square_out->mutable_data <T>(ctx.GetPlace ()),
206
+ moment_out->mutable_data <T>(ctx.GetPlace ()),
207
+ mean_grad_out->mutable_data <T>(ctx.GetPlace ()),
208
+ lr_tensor.data <T>(), rho, epsilon, momentum, grad_func));
209
+ } else {
210
+ for_range (UncenteredRmspropFunctor<T, DenseRmspropGradFunctor<T>>(
211
+ param_out->mutable_data <T>(ctx.GetPlace ()),
212
+ mean_square_out->mutable_data <T>(ctx.GetPlace ()),
213
+ moment_out->mutable_data <T>(ctx.GetPlace ()), lr_tensor.data <T>(),
214
+ rho, epsilon, momentum, grad_func));
215
+ }
216
+ }
217
+ } else if (grad_var->IsType <framework::SelectedRows>()) {
218
+ auto &grad = grad_var->Get <framework::SelectedRows>();
219
+ auto *merged_grad = const_cast <framework::Scope &>(ctx.scope ())
220
+ .Var ()
221
+ ->GetMutable <framework::SelectedRows>();
222
+
223
+ math::scatter::MergeAdd<DeviceContext, T> merge_func;
224
+ merge_func (dev_ctx, grad, merged_grad);
225
+
226
+ platform::ForRange<DeviceContext> for_range (dev_ctx, limit);
227
+ const int64_t *rows;
228
+ #ifdef PADDLE_WITH_CUDA
229
+ if (platform::is_gpu_place (ctx.GetPlace ())) {
230
+ rows = merged_grad->rows ().CUDAData (ctx.GetPlace ());
231
+ } else {
232
+ #endif
233
+ rows = merged_grad->rows ().data ();
234
+ #ifdef PADDLE_WITH_CUDA
235
+ }
236
+ #endif
237
+ auto &merged_tensor = merged_grad->value ();
238
+ int64_t row_count = merged_grad->rows ().size ();
239
+ int64_t row_numel = merged_tensor.numel () / row_count;
240
+ SparseRmspropGradFunctor<T> grad_func (merged_tensor.data <T>(), rows,
241
+ row_numel, row_count);
45
242
46
- auto p = EigenVector<T>::Flatten (*ctx.Input <Tensor>(" Param" ));
47
- auto ms = EigenVector<T>::Flatten (*ctx.Input <Tensor>(" MeanSquare" ));
48
- auto lr = EigenVector<T>::Flatten (*ctx.Input <Tensor>(" LearningRate" ));
49
- auto g = EigenVector<T>::Flatten (*grad);
50
- auto mom = EigenVector<T>::Flatten (*ctx.Input <Tensor>(" Moment" ));
51
-
52
- auto p_out = EigenVector<T>::Flatten (*param_out);
53
- auto mom_out = EigenVector<T>::Flatten (*moment_out);
54
- auto ms_out = EigenVector<T>::Flatten (*mean_square_out);
55
- auto & place = *ctx.template device_context <DeviceContext>().eigen_device ();
56
-
57
- Eigen::DSizes<int , 1 > grad_dsize (static_cast <int >(grad->numel ()));
58
-
59
- ms_out.device (place) = rho * ms + (1 - rho) * g * g;
60
- if (centered) {
61
- auto mg = EigenVector<T>::Flatten (*ctx.Input <Tensor>(" MeanGrad" ));
62
- auto * mean_grad_out = ctx.Output <Tensor>(" MeanGradOut" );
63
- mean_grad_out->mutable_data <T>(ctx.GetPlace ());
64
- auto mg_out = EigenVector<T>::Flatten (*mean_grad_out);
65
-
66
- mg_out.device (place) = rho * mg + (1 - rho) * g;
67
- mom_out.device (place) = momentum * mom +
68
- lr.broadcast (grad_dsize) * g /
69
- (ms_out - mg_out.square () + epsilon).sqrt ();
243
+ if (centered) {
244
+ auto &mg_tensor = *ctx.Input <Tensor>(" MeanGrad" );
245
+ auto *mean_grad_out = ctx.Output <Tensor>(" MeanGradOut" );
246
+ PADDLE_ENFORCE (&mg_tensor, mean_grad_out,
247
+ " MeanGrad and MeanGradOut must be the same Tensor" );
248
+ for_range (CenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
249
+ param_out->mutable_data <T>(ctx.GetPlace ()),
250
+ mean_square_out->mutable_data <T>(ctx.GetPlace ()),
251
+ moment_out->mutable_data <T>(ctx.GetPlace ()),
252
+ mean_grad_out->mutable_data <T>(ctx.GetPlace ()), lr_tensor.data <T>(),
253
+ rho, epsilon, momentum, grad_func));
254
+ } else {
255
+ for_range (UncenteredRmspropFunctor<T, SparseRmspropGradFunctor<T>>(
256
+ param_out->mutable_data <T>(ctx.GetPlace ()),
257
+ mean_square_out->mutable_data <T>(ctx.GetPlace ()),
258
+ moment_out->mutable_data <T>(ctx.GetPlace ()), lr_tensor.data <T>(),
259
+ rho, epsilon, momentum, grad_func));
260
+ }
70
261
} else {
71
- mom_out.device (place) =
72
- momentum * mom +
73
- lr.broadcast (grad_dsize) * g / (ms_out + epsilon).sqrt ();
262
+ PADDLE_THROW (" RMSProp only supports LoDTensor or SelectedRows gradient" );
74
263
}
75
- p_out.device (place) = p - mom_out;
76
264
}
77
265
};
78
266
0 commit comments