@@ -174,12 +174,13 @@ struct SparseAdamFunctor {
   const int64_t* rows_;
   int64_t row_numel_;
+  int64_t row_count_;

   SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
                     const T* beta2_pow, const T* mom1, T* mom1_out,
                     const T* mom2, T* mom2_out, const T* lr, const T* grad,
                     const T* param, T* param_out, const int64_t* rows,
-                    int64_t row_numel)
+                    int64_t row_numel, int64_t row_count)
       : beta1_(beta1),
         beta2_(beta2),
         epsilon_(epsilon),
@@ -194,28 +195,47 @@ struct SparseAdamFunctor {
         param_(param),
         param_out_(param_out),
         rows_(rows),
-        row_numel_(row_numel) {}
+        row_numel_(row_numel),
+        row_count_(row_count) {}
+
+  inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
+    int64_t beg = 0, end = row_count_ - 1;
+    while (beg <= end) {
+      auto mid = ((beg + end) >> 1);
+      if (rows_[mid] == row)
+        return mid;
+      else if (rows_[mid] < row)
+        beg = mid + 1;
+      else
+        end = mid - 1;
+    }
+    return -1;
+  }

   inline HOSTDEVICE void operator()(size_t i) const {
+    int64_t row = i / row_numel_;
+    auto row_idx = BinarySearchInRows(row);
+    T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
+
+    // The following code is the same as dense
+    T mom1 = moment1_[i];
+    T mom2 = moment2_[i];
+    T lr = *lr_;
     T beta1_pow = *beta1_pow_;
     T beta2_pow = *beta2_pow_;
-    for (int64_t j = 0; j < row_numel_; ++j) {
-      T g = grad_[i * row_numel_ + j];
-      T mom1 = moment1_[rows_[i] * row_numel_ + j];
-      T mom2 = moment2_[rows_[i] * row_numel_ + j];
-      T lr = *lr_;
-      T p = param_[rows_[i] * row_numel_ + j];
-
-      lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
-
-      mom1 = beta1_ * mom1 + (1 - beta1_) * g;
-      mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
-      p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
-
-      moment1_out_[rows_[i] * row_numel_ + j] = mom1;
-      moment2_out_[rows_[i] * row_numel_ + j] = mom2;
-      param_out_[rows_[i] * row_numel_ + j] = p;
-    }  // for col id
+    T p = param_[i];
+
+    // Calculation
+    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+
+    mom1 = beta1_ * mom1 + (1 - beta1_) * g;
+    mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
+    p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+
+    // Write back to global memory
+    moment1_out_[i] = mom1;
+    moment2_out_[i] = mom2;
+    param_out_[i] = p;
   }
 };
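The rewritten functor is invoked once per parameter element instead of once per merged gradient row: the flat index i is mapped back to its row (i / row_numel_), that row is looked up with a binary search over the sorted merged row list, and rows that are absent from the sparse gradient fall back to a zero gradient so the plain momentum decay still applies to them. Below is a minimal, standalone host-side sketch of that index mapping (plain C++, not PaddlePaddle code; the names rows, grad and row_numel simply mirror the functor's members):

// Standalone illustration of the sparse-to-dense index mapping used above.
#include <cstdint>
#include <iostream>
#include <vector>

// Same search as SparseAdamFunctor::BinarySearchInRows, over a std::vector.
int64_t BinarySearchInRows(const std::vector<int64_t>& rows, int64_t row) {
  int64_t beg = 0, end = static_cast<int64_t>(rows.size()) - 1;
  while (beg <= end) {
    int64_t mid = (beg + end) >> 1;
    if (rows[mid] == row) return mid;
    if (rows[mid] < row)
      beg = mid + 1;
    else
      end = mid - 1;
  }
  return -1;  // row has no entry in the sparse gradient
}

int main() {
  // A parameter with 5 rows of 3 elements; only rows 1 and 4 have gradients.
  const int64_t row_numel = 3;
  const std::vector<int64_t> rows = {1, 4};  // sorted, de-duplicated row ids
  const std::vector<float> grad = {0.1f, 0.2f, 0.3f,   // gradient of row 1
                                   0.4f, 0.5f, 0.6f};  // gradient of row 4

  for (int64_t i = 0; i < 5 * row_numel; ++i) {
    int64_t row = i / row_numel;
    int64_t row_idx = BinarySearchInRows(rows, row);
    // Elements of rows absent from the sparse gradient see g == 0.
    float g = row_idx >= 0 ? grad[row_idx * row_numel + i % row_numel] : 0.0f;
    std::cout << "element " << i << " (row " << row << ") -> g = " << g << "\n";
  }
  return 0;
}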
@@ -287,9 +307,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
         return;
       }
       // merge duplicated rows if any.
+      // The rows of grad_merge have been sorted inside MergeAdd functor
       scatter::MergeAdd<DeviceContext, T> merge_func;
-      auto grad_merge =
-          merge_func(ctx.template device_context<DeviceContext>(), grad);
+      auto& grad_merge = *(ctx.scope()
+                               .NewScope()
+                               .Var("sparse_adam_grad_merge")
+                               ->GetMutable<framework::SelectedRows>());
+      merge_func(ctx.template device_context<DeviceContext>(), grad,
+                 &grad_merge);
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
       int64_t* rows = nullptr;
@@ -314,10 +339,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
           mom2.template data<T>(),
           mom2_out.template mutable_data<T>(ctx.GetPlace()),
           lr.template data<T>(), grad_data, param.template data<T>(),
-          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel);
+          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
+          grad_merge.rows().size());
       platform::ForRange<DeviceContext> for_range(
          static_cast<const DeviceContext&>(ctx.device_context()),
-          grad_merge.rows().size());
+          param.numel());
       for_range(functor);
     } else {
       PADDLE_THROW("Variable type not supported by adam_op");