
Commit 89fb196

Author: zhangkaihuo
[cherry-pick-2.2.1]Opt topk (#37325)
The current fused_attention_op does not support attn_mask=None as input; this PR adds support for it and adds the corresponding unit-test logic.
1 parent d31d597 commit 89fb196

File tree: 1 file changed (+6, -3 lines)

paddle/fluid/operators/top_k_v2_op.cu

Lines changed: 6 additions & 3 deletions
@@ -83,7 +83,9 @@ class TopkV2OpCUDAKernel : public framework::OpKernel<T> {
 
     if (k > input_width) k = input_width;
 
-    if ((input_width <= 1024 || k >= 128 || k == input_width)) {
+    // The conclusion is drawn from the data through multiple sets of
+    // statistics
+    if (input_width >= 128 && k >= input_width * 0.75) {
       if (SortTopk<T>(dev_ctx, input, input_width, input_height, k, output,
                       indices, largest)) {
         // Successed, return.
@@ -159,8 +161,9 @@ class TopkV2OpCUDAKernel : public framework::OpKernel<T> {
 
     if (k > input_width) k = input_width;
 
-    if (((input_width <= 1024 && input_height <= 2048) || k >= 128 ||
-         k == input_width)) {
+    // The conclusion is drawn from the data through multiple sets of
+    // statistics
+    if (input_width >= 128 && k >= input_width * 0.75) {
       if (SortTopk<T>(dev_ctx, &trans_input, input_width, input_height, k,
                       &trans_out, &trans_ind, largest)) {
         // last step, tranpose back the indices and output
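For context, the change narrows the dispatch condition for the sort-based top-k path: SortTopk is now used only when a row is at least 128 elements wide and k covers at least 75% of the row. The snippet below is a minimal, self-contained C++ sketch, not part of the commit; the helper names use_sort_path_old and use_sort_path_new are hypothetical, and only the boolean expressions are taken from the diff above.

// Hypothetical sketch comparing the old and new dispatch conditions.
#include <cstdio>

// Old condition (the removed line in the non-transposed branch).
bool use_sort_path_old(int input_width, int k) {
  return input_width <= 1024 || k >= 128 || k == input_width;
}

// New condition added by the commit: take the sort-based path only when the
// row is wide enough and k is a large fraction (>= 75%) of it.
bool use_sort_path_new(int input_width, int k) {
  return input_width >= 128 && k >= input_width * 0.75;
}

int main() {
  // Sample (input_width, k) shapes to show where the dispatch decision flips.
  const int shapes[][2] = {{1000, 10}, {1000, 900}, {50000, 128}, {50000, 40000}};
  for (const auto& s : shapes) {
    std::printf("width=%-6d k=%-6d old=%d new=%d\n", s[0], s[1],
                use_sort_path_old(s[0], s[1]), use_sort_path_new(s[0], s[1]));
  }
  return 0;
}

Under the old condition, a shape like width=1000, k=10 already took the sort path; under the new one it falls back to the non-sort kernel, and the sort path is reserved for cases such as k=900 out of 1000, where k is close to input_width.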
