Skip to content

Commit b6f61fa — "fix adam" (parent commit: 2d89849)

File tree

4 files changed: +68 / −25 lines

paddle/fluid/operators/adam_op.h

Lines changed: 49 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -174,12 +174,13 @@ struct SparseAdamFunctor {
174174

175175
const int64_t* rows_;
176176
int64_t row_numel_;
177+
int64_t row_count_;
177178

178179
SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
179180
const T* beta2_pow, const T* mom1, T* mom1_out,
180181
const T* mom2, T* mom2_out, const T* lr, const T* grad,
181182
const T* param, T* param_out, const int64_t* rows,
182-
int64_t row_numel)
183+
int64_t row_numel, int64_t row_count)
183184
: beta1_(beta1),
184185
beta2_(beta2),
185186
epsilon_(epsilon),
@@ -194,28 +195,47 @@ struct SparseAdamFunctor {
194195
param_(param),
195196
param_out_(param_out),
196197
rows_(rows),
197-
row_numel_(row_numel) {}
198+
row_numel_(row_numel),
199+
row_count_(row_count) {}
200+
201+
inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
202+
int64_t beg = 0, end = row_count_ - 1;
203+
while (beg <= end) {
204+
auto mid = ((beg + end) >> 1);
205+
if (rows_[mid] == row)
206+
return mid;
207+
else if (rows_[mid] < row)
208+
beg = mid + 1;
209+
else
210+
end = mid - 1;
211+
}
212+
return -1;
213+
}
198214

199215
inline HOSTDEVICE void operator()(size_t i) const {
216+
int64_t row = i / row_numel_;
217+
auto row_idx = BinarySearchInRows(row);
218+
T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
219+
220+
// The following code is the same as dense
221+
T mom1 = moment1_[i];
222+
T mom2 = moment2_[i];
223+
T lr = *lr_;
200224
T beta1_pow = *beta1_pow_;
201225
T beta2_pow = *beta2_pow_;
202-
for (int64_t j = 0; j < row_numel_; ++j) {
203-
T g = grad_[i * row_numel_ + j];
204-
T mom1 = moment1_[rows_[i] * row_numel_ + j];
205-
T mom2 = moment2_[rows_[i] * row_numel_ + j];
206-
T lr = *lr_;
207-
T p = param_[rows_[i] * row_numel_ + j];
208-
209-
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
210-
211-
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
212-
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
213-
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
214-
215-
moment1_out_[rows_[i] * row_numel_ + j] = mom1;
216-
moment2_out_[rows_[i] * row_numel_ + j] = mom2;
217-
param_out_[rows_[i] * row_numel_ + j] = p;
218-
} // for col id
226+
T p = param_[i];
227+
228+
// Calculation
229+
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
230+
231+
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
232+
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
233+
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
234+
235+
// Write back to global memory
236+
moment1_out_[i] = mom1;
237+
moment2_out_[i] = mom2;
238+
param_out_[i] = p;
219239
}
220240
};
221241

@@ -287,9 +307,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
287307
return;
288308
}
289309
// merge duplicated rows if any.
310+
// The rows of grad_merge have been sorted inside MergeAdd functor
290311
scatter::MergeAdd<DeviceContext, T> merge_func;
291-
auto grad_merge =
292-
merge_func(ctx.template device_context<DeviceContext>(), grad);
312+
auto& grad_merge = *(ctx.scope()
313+
.NewScope()
314+
.Var("sparse_adam_grad_merge")
315+
->GetMutable<framework::SelectedRows>());
316+
merge_func(ctx.template device_context<DeviceContext>(), grad,
317+
&grad_merge);
293318
auto& grad_tensor = grad_merge.value();
294319
const T* grad_data = grad_tensor.template data<T>();
295320
int64_t* rows = nullptr;
@@ -314,10 +339,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
314339
mom2.template data<T>(),
315340
mom2_out.template mutable_data<T>(ctx.GetPlace()),
316341
lr.template data<T>(), grad_data, param.template data<T>(),
317-
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel);
342+
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
343+
grad_merge.rows().size());
318344
platform::ForRange<DeviceContext> for_range(
319345
static_cast<const DeviceContext&>(ctx.device_context()),
320-
grad_merge.rows().size());
346+
param.numel());
321347
for_range(functor);
322348
} else {
323349
PADDLE_THROW("Variable type not supported by adam_op");

paddle/fluid/operators/math/selected_rows_functor.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,14 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
199199
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
200200
const framework::SelectedRows& input) {
201201
framework::SelectedRows out;
202+
(*this)(context, input, &out);
203+
return out;
204+
}
205+
206+
void operator()(const platform::CPUDeviceContext& context,
207+
const framework::SelectedRows& input,
208+
framework::SelectedRows* output) {
209+
framework::SelectedRows& out = *output;
202210
auto input_rows = input.rows();
203211
std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
204212
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
@@ -223,7 +231,6 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
223231
out_data[out_i * input_width + j] += input_data[i * input_width + j];
224232
}
225233
}
226-
return out;
227234
}
228235
};
229236

paddle/fluid/operators/math/selected_rows_functor.cu

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,14 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
262262
framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
263263
const framework::SelectedRows& input) {
264264
framework::SelectedRows out;
265+
(*this)(context, input, &out);
266+
return out;
267+
}
268+
269+
void operator()(const platform::CUDADeviceContext& context,
270+
const framework::SelectedRows& input,
271+
framework::SelectedRows* output) {
272+
framework::SelectedRows& out = *output;
265273
framework::Vector<int64_t> input_rows(input.rows());
266274
std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
267275
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
@@ -292,7 +300,6 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
292300
input_data, input_rows.CUDAData(context.GetPlace()), out_data,
293301
out.mutable_rows()->CUDAMutableData(context.GetPlace()),
294302
out.rows().size(), input_width);
295-
return out;
296303
}
297304
};
298305

paddle/fluid/operators/math/selected_rows_functor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ struct MergeAdd {
6565
// the input SelectedRows object.
6666
framework::SelectedRows operator()(const DeviceContext& context,
6767
const framework::SelectedRows& input);
68+
void operator()(const DeviceContext& context,
69+
const framework::SelectedRows& input,
70+
framework::SelectedRows* output);
6871
};
6972

7073
template <typename DeviceContext, typename T>

Comments (0)