
Commit c7379a7

qingqing01 authored and wanghaoshuang committed
Fix top_k op (#14034)
1. Fix the CUDA kernel when the input height is larger than 2048. 2. Support input with more than 2 dimensions. 3. Fix the unit test when k is larger than 1. 4. Enhance unit testing. test=develop
1 parent d3e5255 commit c7379a7

File tree: 3 files changed (+70 -31 lines)


paddle/fluid/operators/top_k_op.cu

Lines changed: 16 additions & 16 deletions
@@ -262,31 +262,31 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
                              const T* src, int lds, int dim, int k,
                              int grid_dim, int num) {
   __shared__ Pair<T> sh_topk[BlockSize];
-  __shared__ int maxid[BlockSize / 2];
   const int tid = threadIdx.x;
   const int warp = threadIdx.x / 32;
 
   const int bid = blockIdx.x;
   for (int i = bid; i < num; i += grid_dim) {
-    output += i * output_stride;
-    indices += i * k;
-
+    int top_num = k;
+    __shared__ int maxid[BlockSize / 2];
+    T* out = output + i * output_stride;
+    int64_t* inds = indices + i * k;
     Pair<T> topk[MaxLength];
     int beam = MaxLength;
     Pair<T> max;
     bool is_empty = false;
     bool firststep = true;
 
-    for (int k = 0; k < MaxLength; k++) {
-      topk[k].set(-INFINITY, -1);
+    for (int j = 0; j < MaxLength; j++) {
+      topk[j].set(-INFINITY, -1);
     }
-    while (k) {
+    while (top_num) {
       ThreadGetTopK<T, MaxLength, BlockSize>(
           topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid);
 
       sh_topk[tid] = topk[0];
-      BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
-                                           &indices, &beam, &k, tid, warp);
+      BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds,
+                                           &beam, &top_num, tid, warp);
     }
   }
 }
@@ -327,13 +327,15 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
     size_t k = static_cast<int>(ctx.Attr<int>("k"));
 
     const T* input_data = input->data<T>();
-
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
     // FIXME(typhoonzero): data is always converted to type T?
     int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
 
-    size_t input_height = input->dims()[0];
-    size_t input_width = input->dims()[1];
+    framework::DDim inputdims = input->dims();
+    const size_t input_height = framework::product(
+        framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
+    const size_t input_width = inputdims[inputdims.size() - 1];
+
     if (k > input_width) k = input_width;
 
     // NOTE: pass lds and dim same to input width.
@@ -342,14 +344,12 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
     const int kMaxHeight = 2048;
     int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
     auto& dev_ctx = ctx.cuda_device_context();
-
     switch (GetDesiredBlockDim(input_width)) {
       FIXED_BLOCK_DIM(
           KeMatrixTopK<T, 5,
                        kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
-              output_data, output->dims()[1], indices_data, input_data,
-              input_width, input_width, static_cast<int>(k), gridx,
-              input_height));
+              output_data, k, indices_data, input_data, input_width,
+              input_width, static_cast<int>(k), gridx, input_height));
       default:
         PADDLE_THROW("Error");
     }
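Why the first hunk matters: gridx is capped at kMaxHeight = 2048, so an input with more rows than that makes each block revisit rows in a grid-stride loop. The old code advanced the shared output/indices pointers and decremented the loop counter k in place, so every iteration after a block's first started from corrupted state; the fix recomputes per-row locals (out, inds, top_num) on each iteration. Below is a minimal Python sketch of that loop structure (illustrative only, not Paddle code; topk_grid_stride is a hypothetical name):

import numpy as np

def topk_grid_stride(src, k, grid_dim):
    """Top-k per row, mimicking the kernel's grid-stride row assignment."""
    num, _ = src.shape
    out = np.empty((num, k))
    inds = np.empty((num, k), dtype=np.int64)
    for bid in range(grid_dim):              # one "block" per bid
        for i in range(bid, num, grid_dim):  # grid-stride loop over rows
            # Fresh per-row destinations and a fresh count, mirroring the
            # fixed kernel's `out`, `inds`, and `top_num`; the old kernel
            # mutated the shared `output`, `indices`, and `k` instead.
            order = np.argsort(-src[i])[:k]  # descending order
            inds[i] = order
            out[i] = src[i][order]
    return out, inds

The m = 2056 cases in the new tests push the row count just past the 2048-block cap, so the outer loop runs more than once per block, which is exactly when the old pointer mutation broke.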

paddle/fluid/operators/top_k_op.h

Lines changed: 1 addition & 4 deletions
@@ -34,7 +34,6 @@ class TopkKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     // Get the top k elements of each row of input tensor
-    // FIXME: only deal with matrix(2d tensor).
     auto* input = ctx.Input<Tensor>("X");
     auto* output = ctx.Output<Tensor>("Out");
     auto* indices = ctx.Output<Tensor>("Indices");
@@ -44,16 +43,14 @@ class TopkKernel : public framework::OpKernel<T> {
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
     int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
 
-    auto eg_input = EigenMatrix<T>::From(*input);
-
     // reshape input to a flattern matrix(like flat_inner_dims)
     framework::DDim inputdims = input->dims();
     const size_t row = framework::product(
         framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
     const size_t col = inputdims[inputdims.size() - 1];
     Eigen::DSizes<int, 2> flat2dims(row, col);
     // NOTE: eigen shape doesn't affect paddle tensor.
-    eg_input.reshape(flat2dims);
+    auto eg_input = EigenMatrix<T>::Reshape(*input, inputdims.size() - 1);
 
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
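The CPU-side change generalizes the kernel the same way: all leading dimensions are folded into rows and top-k runs over the last dimension, which is what the framework::product(framework::slice_ddim(...)) arithmetic and EigenMatrix<T>::Reshape set up. A NumPy analogue of that flatten-then-reduce view, offered as a sketch only (topk_nd is a hypothetical helper, not a Paddle API):

import numpy as np

def topk_nd(x, k):
    """Top-k over the last axis of an N-D array via a flattened 2-D view."""
    rows = int(np.prod(x.shape[:-1]))  # mirrors product(slice_ddim(dims, 0, n - 1))
    cols = x.shape[-1]                 # the reduced last dimension
    flat = x.reshape(rows, cols)
    idx = np.argsort(-flat, axis=1)[:, :k]       # descending-order indices
    val = np.take_along_axis(flat, idx, axis=1)  # gather the top-k values
    out_shape = x.shape[:-1] + (k,)
    return val.reshape(out_shape), idx.reshape(out_shape)

This is also why the 3-D unit test below can now feed 'X': input directly and simply reshape the row-wise reference result back to (32, 2, k).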

python/paddle/fluid/tests/unittests/test_top_k_op.py

Lines changed: 53 additions & 11 deletions
@@ -21,22 +21,27 @@
 
 class TestTopkOp(OpTest):
     def setUp(self):
+        self.set_args()
         self.op_type = "top_k"
-        k = 1
-        input = np.random.random((32, 84)).astype("float32")
-        output = np.ndarray((32, k))
-        indices = np.ndarray((32, k)).astype("int64")
+        k = self.top_k
+        input = np.random.random((self.row, k)).astype("float32")
+        output = np.ndarray((self.row, k))
+        indices = np.ndarray((self.row, k)).astype("int64")
 
         self.inputs = {'X': input}
         self.attrs = {'k': k}
 
-        for rowid in range(32):
+        for rowid in range(self.row):
             row = input[rowid]
-            output[rowid] = np.sort(row)[-k:]
-            indices[rowid] = row.argsort()[-k:]
+            output[rowid] = np.sort(row)[::-1][:k]
+            indices[rowid] = row.argsort()[::-1][:k]
 
         self.outputs = {'Out': output, 'Indices': indices}
 
+    def set_args(self):
+        self.row = 32
+        self.top_k = 1
+
     def test_check_output(self):
         self.check_output()
 
@@ -50,20 +55,57 @@ def setUp(self):
         output = np.ndarray((64, k))
         indices = np.ndarray((64, k)).astype("int64")
 
-        # FIXME: should use 'X': input for a 3d input
-        self.inputs = {'X': input_flat_2d}
+        self.inputs = {'X': input}
         self.attrs = {'k': k}
 
         for rowid in range(64):
             row = input_flat_2d[rowid]
-            output[rowid] = np.sort(row)[-k:]
-            indices[rowid] = row.argsort()[-k:]
+            output[rowid] = np.sort(row)[::-1][:k]
+            indices[rowid] = row.argsort()[::-1][:k]
+
+        self.outputs = {
+            'Out': output.reshape((32, 2, k)),
+            'Indices': indices.reshape((32, 2, k))
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestTopkOp2(OpTest):
+    def setUp(self):
+        self.op_type = "top_k"
+        k = 1
+        m = 2056
+        input = np.random.random((m, 84)).astype("float32")
+        output = np.ndarray((m, k))
+        indices = np.ndarray((m, k)).astype("int64")
+
+        self.inputs = {'X': input}
+        self.attrs = {'k': k}
+
+        for rowid in range(m):
+            row = input[rowid]
+            output[rowid] = -np.sort(-row)[:k]
+            indices[rowid] = (-row).argsort()[:k]
 
         self.outputs = {'Out': output, 'Indices': indices}
 
     def test_check_output(self):
         self.check_output()
 
 
+class TestTopkOp3(TestTopkOp):
+    def set_args(self):
+        self.row = 2056
+        self.top_k = 3
+
+
+class TestTopkOp4(TestTopkOp):
+    def set_args(self):
+        self.row = 40000
+        self.top_k = 1
+
+
 if __name__ == "__main__":
     unittest.main()
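The reference-computation change in these tests is the k > 1 fix from the commit message: np.sort(row)[-k:] yields the top-k values in ascending order, whereas the operator returns them in descending order, so the element-wise comparison only passed for k = 1. A quick illustration:

import numpy as np

row = np.array([0.3, 0.9, 0.5])
print(np.sort(row)[-2:])       # [0.5 0.9]  old reference (ascending)
print(np.sort(row)[::-1][:2])  # [0.9 0.5]  fixed reference (descending)

TestTopkOp2 expresses the same thing by negating: -np.sort(-row)[:k] likewise produces the top-k in descending order.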
