Skip to content

Commit 2b8fd70

Browse files
authored
fix bug of top_k npu op (#36175)
1 parent c79de72 commit 2b8fd70

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

paddle/fluid/operators/top_k_op_npu.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ class TopkNPUKernel : public framework::OpKernel<T> {
5151
indices->mutable_data<int64_t>(ctx.GetPlace());
5252

5353
// prepare assit
54-
auto dim = input->dims().size();
54+
auto size = input->dims().size();
55+
// dim is the last dimension of input
56+
auto dim = input->dims()[size - 1];
5557
framework::Tensor assist_seq_tensor;
5658
assist_seq_tensor.Resize({2 * dim});
5759
assist_seq_tensor.mutable_data<T>(ctx.GetPlace());

python/paddle/fluid/tests/unittests/npu/test_top_k_op_npu.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import paddle
2323
import paddle.fluid as fluid
2424
from paddle.fluid import core
25+
from test_top_k_v2_op_npu import numpy_topk
2526

2627
paddle.enable_static()
2728
SEED = 2021
@@ -87,5 +88,40 @@ def test_check_output(self):
8788
self.check_output_with_place(self.place)
8889

8990

91+
class TestTopkV3(OpTest):
92+
def setUp(self):
93+
self.set_npu()
94+
self.place = paddle.NPUPlace(0)
95+
self.op_type = "top_k"
96+
97+
self.init_dtype()
98+
self.set_input_data()
99+
self.set_attrs()
100+
output, indices = numpy_topk(
101+
self.input_data, axis=self.axis, k=self.k, largest=True)
102+
103+
self.inputs = {'X': self.input_data}
104+
self.attrs = {'k': self.k, 'axis': self.axis}
105+
self.outputs = {'Out': output, 'Indices': indices}
106+
107+
def set_npu(self):
108+
self.__class__.use_npu = True
109+
self.__class__.no_need_check_grad = True
110+
111+
def init_dtype(self):
112+
self.dtype = np.float16
113+
114+
def test_check_output(self):
115+
self.check_output_with_place(self.place)
116+
117+
def set_attrs(self):
118+
self.k = 3
119+
self.axis = 1
120+
121+
def set_input_data(self):
122+
self.input_data = np.random.choice(
123+
10000, size=(10, 20), replace=False).astype(self.dtype)
124+
125+
90126
if __name__ == '__main__':
91127
unittest.main()

0 commit comments

Comments
 (0)