Skip to content

Commit 41e4f7e

Browse files
authored
Optimize Topk when height is large. (#13710)
1 parent 65ed45a commit 41e4f7e

File tree

1 file changed: +64 −27 lines changed

paddle/fluid/operators/top_k_op.cu

Lines changed: 64 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -256,36 +256,65 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
256256
 * 3. go to the second step, until one thread's topk value is null;
257257
 * 4. go to the first step, until the topk value is obtained.
258258
*/
259+
259260
template <typename T, int MaxLength, int BlockSize>
__global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
                             const T* src, int lds, int dim, int k,
                             int grid_dim, int num) {
  // Top-k over each row of a [num, dim] matrix, one thread block per row,
  // grid-stride over rows so a grid of `grid_dim` blocks can cover any `num`.
  // Expects a 1-D launch with blockDim.x == BlockSize.
  __shared__ Pair<T> sh_topk[BlockSize];
  __shared__ int maxid[BlockSize / 2];
  const int tid = threadIdx.x;
  const int warp = threadIdx.x / 32;

  const int bid = blockIdx.x;
  for (int i = bid; i < num; i += grid_dim) {
    // BUGFIX: use per-row local pointers instead of `output += i * ...`.
    // Mutating the parameters compounded the offset across grid-stride
    // iterations, and BlockReduce also advances these pointers in place.
    T* out = output + i * output_stride;
    int64_t* inds = indices + i * k;
    // BUGFIX: BlockReduce decrements the remaining-top-k counter down to 0
    // through `&top_num`; consuming the parameter `k` itself made every row
    // after the first one processed by this block produce no output.
    int top_num = k;

    Pair<T> topk[MaxLength];     // per-thread candidate list (registers)
    int beam = MaxLength;        // how many candidates need refilling
    Pair<T> max;
    bool is_empty = false;
    bool firststep = true;

    // Renamed loop index: the original `int k` shadowed the parameter `k`.
    for (int j = 0; j < MaxLength; j++) {
      topk[j].set(-INFINITY, -1);
    }
    while (top_num) {
      // Refill each thread's candidate list from this row of src.
      ThreadGetTopK<T, MaxLength, BlockSize>(topk, &beam, top_num,
                                             src + i * lds, &firststep,
                                             &is_empty, &max, dim, tid);

      sh_topk[tid] = topk[0];
      // Block-wide reduction; writes winners to out/inds and advances them,
      // and decrements top_num by the number of values emitted.
      BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds,
                                           &beam, &top_num, tid, warp);
    }
  }
}
294+
// Choose the smallest supported block size (a power of two, at least 32 and
// at most 256 threads) that is >= `dim`, so short rows don't launch idle
// threads while long rows still get a full 256-thread block.
inline static int GetDesiredBlockDim(int dim) {
  for (int block_dim = 32; block_dim < 256; block_dim <<= 1) {
    if (dim <= block_dim) {
      return block_dim;
    }
  }
  return 256;
}
288305

306+
// Expands to one switch case that binds the runtime block size to the
// compile-time constant `kBlockDim`, then executes the given statements
// (typically a kernel launch templated on kBlockDim).
#define FIXED_BLOCK_DIM_BASE(dim, ...)  \
  case (dim): {                         \
    constexpr auto kBlockDim = (dim);   \
    __VA_ARGS__;                        \
  } break

// Expands to cases for every block size GetDesiredBlockDim can return
// (256/128/64/32). `##__VA_ARGS__` is the GNU comma-swallowing extension,
// supported by nvcc's host compilers here.
#define FIXED_BLOCK_DIM(...)                 \
  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__);  \
  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);   \
  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
317+
289318
template <typename T>
290319
class TopkOpCUDAKernel : public framework::OpKernel<T> {
291320
public:
@@ -310,18 +339,26 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
310339
// NOTE: pass lds and dim same to input width.
311340
// NOTE: old matrix implementation of stride is different to eigen.
312341
// TODO(typhoonzero): refine this kernel.
313-
dim3 threads(256, 1);
314-
dim3 grid(input_height, 1);
315-
316-
KeMatrixTopK<T, 5, 256><<<
317-
grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
318-
ctx.device_context())
319-
.stream()>>>(
320-
output_data, output->dims()[1], indices_data, input_data, input_width,
321-
input_width, static_cast<int>(k));
342+
const int kMaxHeight = 2048;
343+
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
344+
auto& dev_ctx = ctx.cuda_device_context();
345+
346+
switch (GetDesiredBlockDim(input_width)) {
347+
FIXED_BLOCK_DIM(
348+
KeMatrixTopK<T, 5,
349+
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
350+
output_data, output->dims()[1], indices_data, input_data,
351+
input_width, input_width, static_cast<int>(k), gridx,
352+
input_height));
353+
default:
354+
PADDLE_THROW("Error");
355+
}
322356
}
323357
};
324358

359+
#undef FIXED_BLOCK_DIM_BASE
360+
#undef FIXED_BLOCK_DIM
361+
325362
} // namespace operators
326363
} // namespace paddle
327364

0 commit comments

Comments
 (0)