@@ -29,6 +29,10 @@ template <typename T, int MajorType = Eigen::RowMajor,
29
29
typename IndexType = Eigen::DenseIndex>
30
30
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
31
31
32
+ template <typename T, int MajorType = Eigen::RowMajor,
33
+ typename IndexType = Eigen::DenseIndex>
34
+ using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
35
+
32
36
template <typename DeviceContext, typename T>
33
37
class TopkKernel : public framework ::OpKernel<T> {
34
38
public:
@@ -57,17 +61,24 @@ class TopkKernel : public framework::OpKernel<T> {
57
61
framework::slice_ddim (inputdims, 0 , inputdims.size () - 1 ));
58
62
const size_t col = inputdims[inputdims.size () - 1 ];
59
63
Eigen::DSizes<int , 2 > flat2dims (row, col);
60
- // NOTE: eigen shape doesn't affect paddle tensor.
61
- auto eg_input = EigenMatrix<T>::Reshape (*input, inputdims.size () - 1 );
62
-
64
+ // NOTE: eigen shape doesn't affect paddle tensor.
63
65
#ifdef PADDLE_WITH_MKLML
64
66
#pragma omp parallel for
65
67
#endif
66
68
for (size_t i = 0 ; i < row; i++) {
67
69
std::vector<std::pair<T, size_t >> vec;
68
70
vec.reserve (col);
69
- for (size_t j = 0 ; j < col; j++) {
70
- vec.push_back (std::pair<T, size_t >(eg_input (i, j), j));
71
+ // 1D vector
72
+ if (inputdims.size () == 1 ) {
73
+ auto eg_input = EigenVector<T>::Flatten (*input);
74
+ for (size_t j = 0 ; j < col; j++) {
75
+ vec.push_back (std::pair<T, size_t >(eg_input (j), j));
76
+ }
77
+ } else {
78
+ auto eg_input = EigenMatrix<T>::Reshape (*input, inputdims.size () - 1 );
79
+ for (size_t j = 0 ; j < col; j++) {
80
+ vec.push_back (std::pair<T, size_t >(eg_input (i, j), j));
81
+ }
71
82
}
72
83
73
84
std::partial_sort (
0 commit comments