
Commit 008f40c

support sparse output for lookup table grad op (#5145)
* add sparse support for sum op
* typo fix
* fix gpu build error
* fix unittest error
* typo fix
* infer var type and shape in op_test
* follow comments
* fix build error
* bypass some unittests depend on NetOp
* support sparse output for lookup table grad op
* refine codes
* fix gpu build error
* fix lookup table grad gpu kernel
* fix ci
* fix ci
* fix ci
* fix bug in lookup_table_grad op
* fix bug in test_word2vec
* register double kernel for some operators
* set is_sparse=True in test_word2vec
* fix lookup table grad op CUDA kernel bug
* disable test_modified_huber_loss_op temporarily
* disable test_lstm_unit_op temporarily
1 parent 3ecad8a commit 008f40c

23 files changed: +218 additions, -114 deletions

paddle/operators/cross_entropy_op.cu

Lines changed: 4 additions & 4 deletions
@@ -21,7 +21,7 @@ namespace {

 template <typename T>
 __global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
-                                           const int* label, const int N,
+                                           const int64_t* label, const int N,
                                            const int D) {
   // TOOD(qingqing) define CUDA_1D_KERNEL_LOOP macro in a common file.
   // CUDA_1D_KERNEL_LOOP(i, N) {
@@ -77,8 +77,8 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
     T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
     const T* x_data = x->data<T>();

-    int batch_size = x->dims()[0];
-    int class_num = x->dims()[1];
+    int64_t batch_size = x->dims()[0];
+    int64_t class_num = x->dims()[1];

     int block = 512;
     int grid = (batch_size * class_num + block - 1) / block;
@@ -93,7 +93,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
     } else {
       math::SetConstant<platform::GPUPlace, T> functor;
       functor(ctx.device_context(), dx, 0);
-      auto* label_data = label->data<int>();
+      auto* label_data = label->data<int64_t>();
       grid = (batch_size + block - 1) / block;
       CrossEntropyGradientKernel<T><<<
           grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
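The common thread in this file (and in cross_entropy_op.h below) is widening label and index types from int to int64_t. The reason is easiest to see in the flattened index arithmetic: batch_size * class_num can exceed what 32 bits hold. A minimal standalone sketch with hypothetical sizes (plain C++, not Paddle code):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical sizes: a large batch over a large label space.
  int64_t batch_size = 100000;
  int64_t class_num = 50000;

  // The flattened element count used for indexing and grid sizing:
  int64_t total = batch_size * class_num;  // 5,000,000,000

  // A 32-bit index would wrap around here; 64-bit arithmetic, as in
  // the patched kernels, does not.
  std::printf("total=%lld, fits in int32: %s\n",
              static_cast<long long>(total),
              total <= INT32_MAX ? "yes" : "no");
  return 0;
}
```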

paddle/operators/cross_entropy_op.h

Lines changed: 7 additions & 7 deletions
@@ -54,28 +54,28 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
     Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     T* dx_data = dx->mutable_data<T>(ctx.GetPlace());

-    int class_num = x->dims()[1];
+    int64_t class_num = x->dims()[1];
     if (ctx.Attr<bool>("soft_label")) {
       auto x_mat = EigenMatrix<T>::From(*x);
       auto dy_mat = EigenMatrix<T>::From(*dy);
       auto lbl_mat = EigenMatrix<T>::From(*label);
       auto dx_mat = EigenMatrix<T>::From(*dx);

       dx_mat.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
-          -(lbl_mat * dy_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) /
-            x_mat);
+          -(lbl_mat *
+            dy_mat.broadcast(Eigen::DSizes<int64_t, 2>(1, class_num)) / x_mat);
     } else {
-      int batch_size = x->dims()[0];
+      int64_t batch_size = x->dims()[0];
       const T* dy_data = dy->data<T>();
       const T* x_data = x->data<T>();
-      const int* label_data = label->data<int>();
+      const int64_t* label_data = label->data<int64_t>();

       math::SetConstant<platform::CPUPlace, T> functor;
       functor(ctx.device_context(), dx, 0);

-      for (int i = 0; i < batch_size; ++i) {
+      for (int64_t i = 0; i < batch_size; ++i) {
         PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
-        int index = i * class_num + label_data[i];
+        int64_t index = i * class_num + label_data[i];
         dx_data[index] = -dy_data[i] / x_data[index];
       }
     }
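For reference, a self-contained sketch of the hard-label branch above (plain C++ with Paddle types replaced by std::vector; note that a strict bounds check needs &&, whereas the || in the diffed PADDLE_ASSERT is vacuously true):

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hard-label cross-entropy gradient:
//   dX[i, label[i]] = -dY[i] / X[i, label[i]], all other entries stay zero.
void CrossEntropyGradHardLabel(const std::vector<float>& x,
                               const std::vector<float>& dy,
                               const std::vector<int64_t>& label,
                               int64_t batch_size, int64_t class_num,
                               std::vector<float>& dx) {
  std::fill(dx.begin(), dx.end(), 0.0f);  // math::SetConstant(..., dx, 0)
  for (int64_t i = 0; i < batch_size; ++i) {
    assert(label[i] >= 0 && label[i] < class_num);  // '&&', not '||'
    int64_t index = i * class_num + label[i];       // 64-bit, cannot wrap
    dx[index] = -dy[i] / x[index];
  }
}
```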

paddle/operators/feed_op.cc

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ class FeedOp : public framework::OperatorBase {

     auto col = Attr<int>("col");

-    VLOG(3) << "Feed Var " << feed_var_name << "'s " << col << " column to var"
+    VLOG(3) << "Feed Var " << feed_var_name << "'s " << col << " column to var "
             << out_name;

     auto &feed_list = feed_var->Get<framework::FeedFetchList>();

paddle/operators/lookup_table_op.cc

Lines changed: 39 additions & 5 deletions
@@ -13,6 +13,7 @@
    limitations under the License. */

 #include "paddle/operators/lookup_table_op.h"
+#include "paddle/framework/var_type_inference.h"

 namespace paddle {
 namespace operators {
@@ -60,6 +61,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
              "Ids must be a column vector with rank = 2."
              "The 2nd dimension size must be 1");
     AddOutput("Out", "The lookup results, which have the same type with W.");
+    AddAttr<bool>("is_sparse", "Sparse update").SetDefault(false);
     AddComment(R"DOC(
 This operator is used to perform lookups on the parameter W,
 then concatenated into a dense tensor.
@@ -70,6 +72,15 @@ or not. And the output only shares the LoD with input `Ids`.
   }
 };

+class LookupTableOpGradDescMaker
+    : public framework::DefaultGradOpDescMaker<true> {
+  using ::paddle::framework::DefaultGradOpDescMaker<
+      true>::DefaultGradOpDescMaker;
+
+ protected:
+  virtual std::string GradOpType() const { return "lookup_table_grad"; }
+};
+
 class LookupTableOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -86,12 +97,35 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
   }
 };

+class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDescBind& op_desc,
+                  framework::BlockDescBind* block) const override {
+    auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
+    auto attr = op_desc.GetAttr("is_sparse");
+    bool is_sparse = boost::get<bool>(attr);
+    if (is_sparse) {
+      VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
+              << " is set to SelectedRows";
+      block->Var(out_var_name)->SetType(framework::VarDesc::SELECTED_ROWS);
+    } else {
+      VLOG(3) << "lookup_table_grad op " << framework::GradVarName("W")
+              << " is set to LoDTensor";
+      block->Var(out_var_name)->SetType(framework::VarDesc::LOD_TENSOR);
+    }
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP(lookup_table, ops::LookupTableOp, ops::LookupTableOpMaker,
-            lookup_table_grad, ops::LookupTableOpGrad);
-
-REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>);
-REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>);
+REGISTER_OPERATOR(lookup_table, ops::LookupTableOp,
+                  ops::LookupTableOpGradDescMaker, ops::LookupTableOpMaker);
+REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
+                  ops::LookupTableOpGradVarTypeInference);
+
+REGISTER_OP_CPU_KERNEL(lookup_table, ops::LookupTableKernel<float>,
+                       ops::LookupTableKernel<double>);
+REGISTER_OP_CPU_KERNEL(lookup_table_grad, ops::LookupTableGradKernel<float>,
+                       ops::LookupTableGradKernel<double>);
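LookupTableOpGradVarTypeInference is what makes the is_sparse attribute matter at program-description time: the same grad op writes W@GRAD either as a dense LoDTensor or as a SelectedRows variable that stores only the touched rows plus their indices. A hedged sketch of that layout and its dense equivalent (simplified stand-in types, not the Paddle API):

```cpp
#include <cstdint>
#include <vector>

// Simplified stand-in for framework::SelectedRows: only the rows that
// actually received gradient are stored, together with their row indices.
struct SelectedRowsSketch {
  std::vector<int64_t> rows;  // value row i corresponds to row rows[i] of W
  std::vector<float> value;   // rows.size() x width entries, row-major
  int64_t width = 0;          // D, the embedding width
};

// Dense equivalent, to show the semantics: dense[rows[i]] += value[i].
// A consumer such as an SGD optimizer can instead walk the rows directly
// and never materialize the full N x D gradient.
void ScatterToDense(const SelectedRowsSketch& grad,
                    std::vector<float>& dense_table) {
  for (size_t i = 0; i < grad.rows.size(); ++i) {
    const int64_t r = grad.rows[i];
    for (int64_t j = 0; j < grad.width; ++j) {
      dense_table[r * grad.width + j] += grad.value[i * grad.width + j];
    }
  }
}
```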

paddle/operators/lookup_table_op.cu

Lines changed: 67 additions & 33 deletions
@@ -1,11 +1,8 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
-
 http://www.apache.org/licenses/LICENSE-2.0
-
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -14,22 +11,21 @@

 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/lookup_table_op.h"
 #include "paddle/platform/assert.h"
 #include "paddle/platform/cuda_helper.h"

 namespace paddle {
 namespace operators {

-using Tensor = framework::Tensor;
-
 template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
-__global__ void LookupTable(T* output, const T* table, const int32_t* ids,
-                            const int N, const int K, const int D) {
+__global__ void LookupTable(T* output, const T* table, const int64_t* ids,
+                            const int64_t N, const int64_t K, const int64_t D) {
   int idx = threadIdx.x;
   int idy = blockIdx.x + threadIdx.y * GridDimX;

   while (idy < K) {
-    int id = ids[idy];
+    int64_t id = ids[idy];
     PADDLE_ASSERT(id >= 0);
     PADDLE_ASSERT(id < N);
     T* out = output + idy * D;
@@ -42,8 +38,9 @@ __global__ void LookupTable(T* output, const T* table, const int32_t* ids,
 }

 template <typename T, int BlockDimX, int BlockDimY, int GridDimX>
-__global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
-                                const int N, const int K, const int D) {
+__global__ void LookupTableGrad(T* table, const T* output, const int64_t* ids,
+                                const int64_t N, const int64_t K,
+                                const int64_t D) {
   int idx = threadIdx.x;
   int idy = blockIdx.x + threadIdx.y * GridDimX;

@@ -71,7 +68,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
     size_t N = table_t->dims()[0];
     size_t D = table_t->dims()[1];
     size_t K = ids_t->numel();
-    auto ids = ids_t->data<int32_t>();
+    auto ids = ids_t->data<int64_t>();
     auto table = table_t->data<T>();
     auto output = output_t->mutable_data<T>(context.GetPlace());

@@ -88,34 +85,71 @@ template <typename T>
 class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto ids_t = context.Input<Tensor>("Ids");
-    auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
-    auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
-
-    int N = d_table_t->dims()[0];
-    int D = d_table_t->dims()[1];
-    int K = ids_t->numel();
-    const int32_t* ids = ids_t->data<int32_t>();
-    const T* d_output = d_output_t->data<T>();
-    T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
-
-    auto t = framework::EigenVector<T>::Flatten(*d_table_t);
-    t.device(context.GetEigenDevice<platform::GPUPlace>()) =
-        t.constant(static_cast<T>(0));
-
-    dim3 threads(128, 8);
-    dim3 grids(8, 1);
-    LookupTableGrad<T, 128, 8, 8><<<
-        grids, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
+    bool is_sparse = context.Attr<bool>("is_sparse");
+    if (is_sparse) {
+      auto* ids = context.Input<Tensor>("Ids");
+      auto* table = context.Input<Tensor>("W");
+      auto* d_output = context.Input<Tensor>(framework::GradVarName("Out"));
+      auto* d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
+
+      auto* ids_data = ids->data<int64_t>();
+      auto ids_dim = ids->dims();
+
+      auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
+                        context.device_context())
+                        .stream();
+      // copy GPU memory to CPU pinned memory
+      framework::Vector<int64_t> new_rows;
+      new_rows.resize(ids_dim[0]);
+      auto gpu_place = boost::get<platform::GPUPlace>(context.GetPlace());
+
+      memory::Copy(platform::CPUPlace(), new_rows.data(), gpu_place, ids_data,
+                   ids_dim[0] * sizeof(int64_t), stream);
+
+      d_table->set_rows(new_rows);
+
+      auto* d_table_value = d_table->mutable_value();
+      d_table_value->Resize({ids_dim[0], table->dims()[1]});
+      d_table_value->mutable_data<T>(context.GetPlace());
+
+      auto* d_table_data = d_table_value->data<T>();
+      auto* d_output_data = d_output->data<T>();
+      PADDLE_ENFORCE_EQ(d_table_value->dims(), d_output->dims());
+      memory::Copy(gpu_place, d_table_data, gpu_place, d_output_data,
+                   d_output->numel(), stream);
+
+    } else {
+      auto ids_t = context.Input<Tensor>("Ids");
+      auto d_output_t = context.Input<Tensor>(framework::GradVarName("Out"));
+      auto d_table_t = context.Output<Tensor>(framework::GradVarName("W"));
+
+      int N = d_table_t->dims()[0];
+      int D = d_table_t->dims()[1];
+      int K = ids_t->numel();
+      const int64_t* ids = ids_t->data<int64_t>();
+      const T* d_output = d_output_t->data<T>();
+      T* d_table = d_table_t->mutable_data<T>(context.GetPlace());
+
+      auto t = framework::EigenVector<T>::Flatten(*d_table_t);
+      t.device(context.GetEigenDevice<platform::GPUPlace>()) =
+          t.constant(static_cast<T>(0));
+
+      dim3 threads(128, 8);
+      dim3 grids(8, 1);
+      LookupTableGrad<T, 128, 8,
+                      8><<<grids, threads, 0,
+                           reinterpret_cast<const platform::CUDADeviceContext&>(
                              context.device_context())
                              .stream()>>>(d_table, d_output, ids, N, K, D);
+    }
   }
 };

 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
-REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>);
-REGISTER_OP_GPU_KERNEL(lookup_table_grad,
-                       ops::LookupTableGradCUDAKernel<float>);
+REGISTER_OP_GPU_KERNEL(lookup_table, ops::LookupTableCUDAKernel<float>,
+                       ops::LookupTableCUDAKernel<double>);
+REGISTER_OP_GPU_KERNEL(lookup_table_grad, ops::LookupTableGradCUDAKernel<float>,
+                       ops::LookupTableGradCUDAKernel<double>);
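Note that the sparse branch launches no kernel: the looked-up ids become the rows of the SelectedRows output and d_output is copied verbatim as its value, so the work is proportional to K x D rather than zero-filling and scattering into the full N x D table. A host-side analogue of that branch (hypothetical helper, not Paddle code; duplicate ids stay as repeated rows and are reduced downstream, e.g. by the optimizer):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

struct SparseGradSketch {
  std::vector<int64_t> rows;  // one entry per looked-up id, duplicates kept
  std::vector<float> value;   // rows.size() x D gradient rows, row-major
};

// Mirrors the is_sparse branch above: rows <- ids (set_rows), and
// value <- d_output (the memory::Copy into mutable_value()).
SparseGradSketch LookupTableGradSparse(const std::vector<int64_t>& ids,
                                       const std::vector<float>& d_output,
                                       int64_t D) {
  assert(d_output.size() == ids.size() * static_cast<size_t>(D));
  SparseGradSketch grad;
  grad.rows = ids;
  grad.value = d_output;
  return grad;
}
```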
