
Commit 7f1e312

Merge pull request #13456 from sneaxiy/refine_sparse_adam
Fix sparse Adam and Gradient clip of SelectedRows
2 parents b758875 + a29b422 commit 7f1e312

File tree: 5 files changed, +103 -42 lines

paddle/fluid/operators/adam_op.h

Lines changed: 49 additions & 23 deletions
@@ -174,12 +174,13 @@ struct SparseAdamFunctor {

   const int64_t* rows_;
   int64_t row_numel_;
+  int64_t row_count_;

   SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
                     const T* beta2_pow, const T* mom1, T* mom1_out,
                     const T* mom2, T* mom2_out, const T* lr, const T* grad,
                     const T* param, T* param_out, const int64_t* rows,
-                    int64_t row_numel)
+                    int64_t row_numel, int64_t row_count)
       : beta1_(beta1),
         beta2_(beta2),
         epsilon_(epsilon),

@@ -194,28 +195,47 @@ struct SparseAdamFunctor {
         param_(param),
         param_out_(param_out),
         rows_(rows),
-        row_numel_(row_numel) {}
+        row_numel_(row_numel),
+        row_count_(row_count) {}
+
+  inline HOSTDEVICE int64_t BinarySearchInRows(int64_t row) const {
+    int64_t beg = 0, end = row_count_ - 1;
+    while (beg <= end) {
+      auto mid = ((beg + end) >> 1);
+      if (rows_[mid] == row)
+        return mid;
+      else if (rows_[mid] < row)
+        beg = mid + 1;
+      else
+        end = mid - 1;
+    }
+    return -1;
+  }

   inline HOSTDEVICE void operator()(size_t i) const {
+    int64_t row = i / row_numel_;
+    auto row_idx = BinarySearchInRows(row);
+    T g = row_idx >= 0 ? grad_[row_idx * row_numel_ + i % row_numel_] : 0;
+
+    // The following code is the same as dense
+    T mom1 = moment1_[i];
+    T mom2 = moment2_[i];
+    T lr = *lr_;
     T beta1_pow = *beta1_pow_;
     T beta2_pow = *beta2_pow_;
-    for (int64_t j = 0; j < row_numel_; ++j) {
-      T g = grad_[i * row_numel_ + j];
-      T mom1 = moment1_[rows_[i] * row_numel_ + j];
-      T mom2 = moment2_[rows_[i] * row_numel_ + j];
-      T lr = *lr_;
-      T p = param_[rows_[i] * row_numel_ + j];
-
-      lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
-
-      mom1 = beta1_ * mom1 + (1 - beta1_) * g;
-      mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
-      p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
-
-      moment1_out_[rows_[i] * row_numel_ + j] = mom1;
-      moment2_out_[rows_[i] * row_numel_ + j] = mom2;
-      param_out_[rows_[i] * row_numel_ + j] = p;
-    }  // for col id
+    T p = param_[i];
+
+    // Calculation
+    lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+
+    mom1 = beta1_ * mom1 + (1 - beta1_) * g;
+    mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
+    p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+
+    // Write back to global memory
+    moment1_out_[i] = mom1;
+    moment2_out_[i] = mom2;
+    param_out_[i] = p;
   }
 };

@@ -287,9 +307,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
         return;
       }
       // merge duplicated rows if any.
+      // The rows of grad_merge have been sorted inside MergeAdd functor
      scatter::MergeAdd<DeviceContext, T> merge_func;
-      auto grad_merge =
-          merge_func(ctx.template device_context<DeviceContext>(), grad);
+      auto& grad_merge = *(ctx.scope()
+                               .NewScope()
+                               .Var("sparse_adam_grad_merge")
+                               ->GetMutable<framework::SelectedRows>());
+      merge_func(ctx.template device_context<DeviceContext>(), grad,
+                 &grad_merge);
       auto& grad_tensor = grad_merge.value();
       const T* grad_data = grad_tensor.template data<T>();
       int64_t* rows = nullptr;

@@ -314,10 +339,11 @@ class AdamOpKernel : public framework::OpKernel<T> {
           mom2.template data<T>(),
           mom2_out.template mutable_data<T>(ctx.GetPlace()),
           lr.template data<T>(), grad_data, param.template data<T>(),
-          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel);
+          param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel,
+          grad_merge.rows().size());
       platform::ForRange<DeviceContext> for_range(
           static_cast<const DeviceContext&>(ctx.device_context()),
-          grad_merge.rows().size());
+          param.numel());
       for_range(functor);
     } else {
       PADDLE_THROW("Variable type not supported by adam_op");

paddle/fluid/operators/clip_op.h

Lines changed: 32 additions & 11 deletions
@@ -16,6 +16,7 @@ limitations under the License. */

 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/transform.h"

 namespace paddle {

@@ -61,14 +62,32 @@ class ClipKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto max = context.Attr<T>("max");
     auto min = context.Attr<T>("min");
-    auto* x = context.Input<Tensor>("X");
-    auto* out = context.Output<Tensor>("Out");
-    T* out_data = out->mutable_data<T>(context.GetPlace());
-    const T* x_data = x->data<T>();
-    int64_t numel = x->numel();
-    Transform<DeviceContext> trans;
-    trans(context.template device_context<DeviceContext>(), x_data,
-          x_data + numel, out_data, ClipFunctor<T>(min, max));
+    auto* x_var = context.InputVar("X");
+    if (x_var->IsType<framework::LoDTensor>()) {
+      auto* x = context.Input<framework::LoDTensor>("X");
+      auto* out = context.Output<framework::LoDTensor>("Out");
+      T* out_data = out->mutable_data<T>(context.GetPlace());
+      const T* x_data = x->data<T>();
+      int64_t numel = x->numel();
+      Transform<DeviceContext> trans;
+      trans(context.template device_context<DeviceContext>(), x_data,
+            x_data + numel, out_data, ClipFunctor<T>(min, max));
+    } else if (x_var->IsType<framework::SelectedRows>()) {
+      auto* x = context.Input<framework::SelectedRows>("X");
+      auto* out = context.Output<framework::SelectedRows>("Out");
+      PADDLE_ENFORCE_NE(x, out,
+                        "Inplace clip is not allowed when x is SelectedRows");
+      math::scatter::MergeAdd<DeviceContext, T> merge_func;
+      merge_func(context.template device_context<DeviceContext>(), *x, out);
+      auto* out_tensor = out->mutable_value();
+      auto* out_data = out_tensor->data<T>();
+      int64_t numel = out_tensor->numel();
+      Transform<DeviceContext> trans;
+      trans(context.template device_context<DeviceContext>(), out_data,
+            out_data + numel, out_data, ClipFunctor<T>(min, max));
+    } else {
+      PADDLE_THROW("ClipOp only supports LoDTensor and SelectedRows");
+    }
   }
 };

@@ -78,10 +97,12 @@ class ClipGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto max = context.Attr<T>("max");
     auto min = context.Attr<T>("min");
-    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* d_out =
+        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto* d_x =
+        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
     if (d_x != nullptr) {
-      auto* x = context.Input<Tensor>("X");
+      auto* x = context.Input<framework::LoDTensor>("X");
       int64_t numel = d_out->numel();
       auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
       const T* d_out_data = d_out->data<T>();
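The order of operations in the new SelectedRows branch matters: duplicate rows are merged (summed) by MergeAdd before the clamp is applied, because clipping is nonlinear and does not distribute over addition; that is also why in-place execution is rejected, since MergeAdd must write into a distinct output. A tiny self-contained C++ illustration with made-up values (not Paddle code):

    #include <algorithm>
    #include <cassert>

    float clip(float v, float lo, float hi) {
      return std::min(std::max(v, lo), hi);
    }

    int main() {
      // Two sparse entries that land on the same row.
      float a = 0.8f, b = 0.8f;
      float merge_then_clip = clip(a + b, -1.0f, 1.0f);                    // 1.0
      float clip_then_merge = clip(a, -1.0f, 1.0f) + clip(b, -1.0f, 1.0f); // 1.6
      assert(merge_then_clip != clip_then_merge);  // the results differ
      return 0;
    }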

paddle/fluid/operators/math/selected_rows_functor.cc

Lines changed: 8 additions & 1 deletion
@@ -199,6 +199,14 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
   framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
                                      const framework::SelectedRows& input) {
     framework::SelectedRows out;
+    (*this)(context, input, &out);
+    return out;
+  }
+
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* output) {
+    framework::SelectedRows& out = *output;
     auto input_rows = input.rows();
     std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
     std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());

@@ -223,7 +231,6 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
         out_data[out_i * input_width + j] += input_data[i * input_width + j];
       }
     }
-    return out;
   }
 };
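The merge itself is unchanged: row ids are collected into a std::set, so the output rows come back deduplicated and sorted ascending, with values for duplicate ids summed. A plain-C++ sketch of those semantics for a single-column input (illustrative, not Paddle code; an ordered std::map stands in for the set-plus-accumulate logic):

    #include <cassert>
    #include <cstdint>
    #include <map>
    #include <vector>

    int main() {
      // Input SelectedRows: row ids with one value per row (duplicates allowed).
      std::vector<int64_t> rows = {7, 1, 7, 4};
      std::vector<float> vals = {1.0f, 2.0f, 3.0f, 4.0f};
      std::map<int64_t, float> merged;  // ordered keys, like std::set of row ids
      for (size_t i = 0; i < rows.size(); ++i) merged[rows[i]] += vals[i];
      // Merged rows come out sorted: {1, 4, 7}; duplicates summed: 7 -> 1 + 3.
      assert(merged.size() == 3);
      assert(merged.begin()->first == 1);
      assert(merged[7] == 4.0f);
      return 0;
    }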

paddle/fluid/operators/math/selected_rows_functor.cu

Lines changed: 11 additions & 7 deletions
@@ -234,7 +234,7 @@ template <typename T, int block_size>
 __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
                                T* out, const int64_t* out_rows,
                                size_t out_rows_size, int64_t row_numel) {
-  const int ty = blockIdx.y;
+  const int ty = blockIdx.x;
   int tid = threadIdx.x;
   __shared__ size_t out_idx;

@@ -260,6 +260,14 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
   framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
                                      const framework::SelectedRows& input) {
     framework::SelectedRows out;
+    (*this)(context, input, &out);
+    return out;
+  }
+
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* output) {
+    framework::SelectedRows& out = *output;
     framework::Vector<int64_t> input_rows(input.rows());
     std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
     std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());

@@ -281,16 +289,12 @@ struct MergeAdd<platform::CUDADeviceContext, T> {

     const int block_size = 256;
     dim3 threads(block_size, 1);
-    dim3 grid1(1, input_rows.size());
+    dim3 grid1(input_rows.size(), 1);

-    MergeAddKernel<
-        T, 256><<<grid1, threads, 0,
-                  reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                      .stream()>>>(
+    MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
         input_data, input_rows.CUDAData(context.GetPlace()), out_data,
         out.mutable_rows()->CUDAMutableData(context.GetPlace()),
         out.rows().size(), input_width);
-    return out;
   }
 };
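On the CUDA side, the per-row block index moves from blockIdx.y to blockIdx.x and the launch grid is transposed to match (grid1(input_rows.size(), 1) instead of grid1(1, input_rows.size())); the reinterpret_cast is also dropped since the context parameter is already a CUDADeviceContext. Whatever the primary motivation, one practical effect is headroom: CUDA caps gridDim.y at 65535 blocks, while gridDim.x allows up to 2^31 - 1 on compute capability 3.0 and later. A back-of-the-envelope check in plain C++ (not Paddle code):

    #include <cstdint>
    #include <iostream>

    int main() {
      const int64_t max_grid_x = (int64_t{1} << 31) - 1;  // CC >= 3.0
      const int64_t max_grid_y = 65535;
      const int64_t merged_rows = 100000;  // hypothetical large sparse gradient
      std::cout << "fits along y: " << (merged_rows <= max_grid_y) << "\n";  // 0
      std::cout << "fits along x: " << (merged_rows <= max_grid_x) << "\n";  // 1
      return 0;
    }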

paddle/fluid/operators/math/selected_rows_functor.h

Lines changed: 3 additions & 0 deletions
@@ -65,6 +65,9 @@ struct MergeAdd {
   // the input SelectedRows object.
   framework::SelectedRows operator()(const DeviceContext& context,
                                      const framework::SelectedRows& input);
+  void operator()(const DeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* output);
 };

 template <typename DeviceContext, typename T>
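Both overloads now coexist: the value-returning form delegates to the out-parameter form, so existing call sites compile unchanged, while callers that manage the output's lifetime themselves (as adam_op.h now does with a scope variable) avoid constructing a temporary inside the functor. A generic, self-contained sketch of the pattern with placeholder names, not Paddle types:

    struct Result {
      int sum = 0;
    };

    struct AddFunctor {
      // Value-returning form: a thin wrapper over the out-parameter form.
      Result operator()(int a, int b) {
        Result out;
        (*this)(a, b, &out);
        return out;
      }
      // Out-parameter form: the caller supplies and owns the output.
      void operator()(int a, int b, Result* output) { output->sum = a + b; }
    };

    int main() {
      AddFunctor add;
      Result r1 = add(1, 2);  // old call style still works
      Result r2;
      add(3, 4, &r2);         // new call style
      return (r1.sum == 3 && r2.sum == 7) ? 0 : 1;
    }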
