Skip to content

Commit 856536b

Browse files
cherry-pick Fix topk cannot handle 1D vector bug (#18466)
Add path to handle 1D vector
1 parent e616c3d commit 856536b

File tree

2 files changed

+38
-5
lines changed

2 files changed

+38
-5
lines changed

paddle/fluid/operators/top_k_op.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ template <typename T, int MajorType = Eigen::RowMajor,
2929
typename IndexType = Eigen::DenseIndex>
3030
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
3131

32+
template <typename T, int MajorType = Eigen::RowMajor,
33+
typename IndexType = Eigen::DenseIndex>
34+
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
35+
3236
template <typename DeviceContext, typename T>
3337
class TopkKernel : public framework::OpKernel<T> {
3438
public:
@@ -57,17 +61,24 @@ class TopkKernel : public framework::OpKernel<T> {
5761
framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
5862
const size_t col = inputdims[inputdims.size() - 1];
5963
Eigen::DSizes<int, 2> flat2dims(row, col);
60-
// NOTE: eigen shape doesn't affect paddle tensor.
61-
auto eg_input = EigenMatrix<T>::Reshape(*input, inputdims.size() - 1);
62-
64+
// NOTE: eigen shape doesn't affect paddle tensor.
6365
#ifdef PADDLE_WITH_MKLML
6466
#pragma omp parallel for
6567
#endif
6668
for (size_t i = 0; i < row; i++) {
6769
std::vector<std::pair<T, size_t>> vec;
6870
vec.reserve(col);
69-
for (size_t j = 0; j < col; j++) {
70-
vec.push_back(std::pair<T, size_t>(eg_input(i, j), j));
71+
// 1D vector
72+
if (inputdims.size() == 1) {
73+
auto eg_input = EigenVector<T>::Flatten(*input);
74+
for (size_t j = 0; j < col; j++) {
75+
vec.push_back(std::pair<T, size_t>(eg_input(j), j));
76+
}
77+
} else {
78+
auto eg_input = EigenMatrix<T>::Reshape(*input, inputdims.size() - 1);
79+
for (size_t j = 0; j < col; j++) {
80+
vec.push_back(std::pair<T, size_t>(eg_input(i, j), j));
81+
}
7182
}
7283

7384
std::partial_sort(

python/paddle/fluid/tests/unittests/test_top_k_op.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,28 @@ def test_check_output(self):
8787
self.check_output()
8888

8989

90+
class TestTopkOp1(OpTest):
91+
def setUp(self):
92+
self.op_type = "top_k"
93+
k = 2
94+
m = 2056
95+
input = np.random.random(m).astype("float32")
96+
output = np.ndarray(k)
97+
indices = np.ndarray(k).astype("int64")
98+
99+
self.inputs = {'X': input}
100+
self.attrs = {'k': k}
101+
102+
row = input
103+
output = -np.sort(-row)[:k]
104+
indices = (-row).argsort()[:k]
105+
106+
self.outputs = {'Out': output, 'Indices': indices}
107+
108+
def test_check_output(self):
109+
self.check_output()
110+
111+
90112
class TestTopkOp2(OpTest):
91113
def setUp(self):
92114
self.op_type = "top_k"

0 commit comments

Comments
 (0)