
Commit a29b422

fix sparse gradient clip
1 parent b6f61fa commit a29b422


2 files changed: +35 −17 lines changed


paddle/fluid/operators/clip_op.h

Lines changed: 32 additions & 11 deletions
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/selected_rows_functor.h"
 #include "paddle/fluid/platform/transform.h"
 
 namespace paddle {
@@ -61,14 +62,32 @@ class ClipKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto max = context.Attr<T>("max");
     auto min = context.Attr<T>("min");
-    auto* x = context.Input<Tensor>("X");
-    auto* out = context.Output<Tensor>("Out");
-    T* out_data = out->mutable_data<T>(context.GetPlace());
-    const T* x_data = x->data<T>();
-    int64_t numel = x->numel();
-    Transform<DeviceContext> trans;
-    trans(context.template device_context<DeviceContext>(), x_data,
-          x_data + numel, out_data, ClipFunctor<T>(min, max));
+    auto* x_var = context.InputVar("X");
+    if (x_var->IsType<framework::LoDTensor>()) {
+      auto* x = context.Input<framework::LoDTensor>("X");
+      auto* out = context.Output<framework::LoDTensor>("Out");
+      T* out_data = out->mutable_data<T>(context.GetPlace());
+      const T* x_data = x->data<T>();
+      int64_t numel = x->numel();
+      Transform<DeviceContext> trans;
+      trans(context.template device_context<DeviceContext>(), x_data,
+            x_data + numel, out_data, ClipFunctor<T>(min, max));
+    } else if (x_var->IsType<framework::SelectedRows>()) {
+      auto* x = context.Input<framework::SelectedRows>("X");
+      auto* out = context.Output<framework::SelectedRows>("Out");
+      PADDLE_ENFORCE_NE(x, out,
+                        "Inplace clip is not allowed when x is SelectedRows");
+      math::scatter::MergeAdd<DeviceContext, T> merge_func;
+      merge_func(context.template device_context<DeviceContext>(), *x, out);
+      auto* out_tensor = out->mutable_value();
+      auto* out_data = out_tensor->data<T>();
+      int64_t numel = out_tensor->numel();
+      Transform<DeviceContext> trans;
+      trans(context.template device_context<DeviceContext>(), out_data,
+            out_data + numel, out_data, ClipFunctor<T>(min, max));
+    } else {
+      PADDLE_THROW("ClipOp only supports LoDTensor and SelectedRows");
+    }
   }
 };
 
@@ -78,10 +97,12 @@ class ClipGradKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto max = context.Attr<T>("max");
     auto min = context.Attr<T>("min");
-    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* d_out =
+        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto* d_x =
+        context.Output<framework::LoDTensor>(framework::GradVarName("X"));
     if (d_x != nullptr) {
-      auto* x = context.Input<Tensor>("X");
+      auto* x = context.Input<framework::LoDTensor>("X");
       int64_t numel = d_out->numel();
       auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
       const T* d_out_data = d_out->data<T>();
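
The new SelectedRows branch first merges duplicate row indices with math::scatter::MergeAdd and only then runs ClipFunctor over the merged values. A minimal host-side sketch of why that ordering matters, using toy types rather than Paddle's SelectedRows/MergeAdd API (the row ids and values here are made up for illustration):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <map>
#include <vector>

// Toy stand-in for a SelectedRows gradient: duplicate row ids carry partial
// gradients that must be summed (MergeAdd) before the clip is applied.
struct SparseGrad {
  std::vector<int64_t> rows;
  std::vector<float> values;  // one value per row id in this toy example
};

float Clip(float v, float lo, float hi) { return std::max(lo, std::min(hi, v)); }

int main() {
  SparseGrad g{{3, 7, 3}, {0.8f, 0.2f, 0.9f}};  // row 3 appears twice
  const float min = -1.0f, max = 1.0f;

  // Order used by the kernel: merge duplicates first, then clip the sums.
  std::map<int64_t, float> merged;
  for (size_t i = 0; i < g.rows.size(); ++i) merged[g.rows[i]] += g.values[i];
  for (auto& kv : merged) kv.second = Clip(kv.second, min, max);
  printf("merge-then-clip, row 3: %.2f\n", merged[3]);  // 1.00

  // Reversed order: clipping each partial first and then summing
  // leaves the merged row outside [min, max].
  std::map<int64_t, float> clipped_first;
  for (size_t i = 0; i < g.rows.size(); ++i)
    clipped_first[g.rows[i]] += Clip(g.values[i], min, max);
  printf("clip-then-merge, row 3: %.2f\n", clipped_first[3]);  // 1.70
  return 0;
}

Clipping each partial gradient before merging can leave the summed row above max (1.70 vs. 1.00 here), which is presumably why the kernel merges first and why in-place clipping is disallowed when X is SelectedRows.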

paddle/fluid/operators/math/selected_rows_functor.cu

Lines changed: 3 additions & 6 deletions
@@ -236,7 +236,7 @@ template <typename T, int block_size>
 __global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
                                T* out, const int64_t* out_rows,
                                size_t out_rows_size, int64_t row_numel) {
-  const int ty = blockIdx.y;
+  const int ty = blockIdx.x;
   int tid = threadIdx.x;
   __shared__ size_t out_idx;
 
@@ -291,12 +291,9 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
 
     const int block_size = 256;
     dim3 threads(block_size, 1);
-    dim3 grid1(1, input_rows.size());
+    dim3 grid1(input_rows.size(), 1);
 
-    MergeAddKernel<
-        T, 256><<<grid1, threads, 0,
-                  reinterpret_cast<const platform::CUDADeviceContext&>(context)
-                      .stream()>>>(
+    MergeAddKernel<T, 256><<<grid1, threads, 0, context.stream()>>>(
        input_data, input_rows.CUDAData(context.GetPlace()), out_data,
        out.mutable_rows()->CUDAMutableData(context.GetPlace()),
        out.rows().size(), input_width);
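
The CUDA-side fix reads the per-row index from blockIdx.x instead of blockIdx.y and swaps the launch grid to dim3 grid1(input_rows.size(), 1) so kernel and launch agree, while the call itself is simplified to use context.stream() directly. A self-contained sketch of the same one-block-per-row launch pattern, with a hypothetical RowScaleKernel standing in for the real MergeAddKernel:

#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

// One block per row: the row id comes from blockIdx.x (as in the fixed
// kernel), and threads stride over the row's columns.
__global__ void RowScaleKernel(const float* in, float* out, int64_t row_numel) {
  const int row = blockIdx.x;  // was blockIdx.y before the fix
  for (int64_t col = threadIdx.x; col < row_numel; col += blockDim.x) {
    out[row * row_numel + col] = 2.0f * in[row * row_numel + col];
  }
}

int main() {
  const int num_rows = 4, row_numel = 8;
  std::vector<float> h(num_rows * row_numel, 1.0f);
  float *d_in, *d_out;
  cudaMalloc(reinterpret_cast<void**>(&d_in), h.size() * sizeof(float));
  cudaMalloc(reinterpret_cast<void**>(&d_out), h.size() * sizeof(float));
  cudaMemcpy(d_in, h.data(), h.size() * sizeof(float), cudaMemcpyHostToDevice);

  // The grid layout must match the index the kernel reads: rows go in
  // grid.x, mirroring the change from dim3 grid1(1, rows) to grid1(rows, 1).
  dim3 threads(256, 1);
  dim3 grid(num_rows, 1);
  RowScaleKernel<<<grid, threads>>>(d_in, d_out, row_numel);

  cudaMemcpy(h.data(), d_out, h.size() * sizeof(float), cudaMemcpyDeviceToHost);
  printf("out[0]=%.1f out[last]=%.1f\n", h.front(), h.back());  // 2.0 2.0
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}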
