@@ -51,53 +51,89 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
51
51
}
52
52
};
53
53
54
+ template <class T >
55
+ bool DistPairDescend (std::tuple<int , int , T> pair1,
56
+ std::tuple<int , int , T> pair2) {
57
+ return std::get<2 >(pair1) > std::get<2 >(pair2);
58
+ }
59
+
54
60
template <typename T>
55
61
class BipartiteMatchKernel : public framework ::OpKernel<T> {
56
62
public:
57
63
// The match_indices must be initialized to -1 at first.
58
64
// The match_dist must be initialized to 0 at first.
59
65
void BipartiteMatch (const Tensor& dist, int * match_indices,
60
66
T* match_dist) const {
61
- constexpr T kEPS = static_cast <T>(1e-6 );
62
67
PADDLE_ENFORCE_EQ (dist.dims ().size (), 2 , " The rank of dist must be 2." );
63
68
int64_t row = dist.dims ()[0 ];
64
69
int64_t col = dist.dims ()[1 ];
65
70
auto * dist_data = dist.data <T>();
66
- std::vector<int > row_pool;
67
- for (int i = 0 ; i < row; ++i) {
68
- row_pool.push_back (i);
69
- }
70
- while (row_pool.size () > 0 ) {
71
- int max_idx = -1 ;
72
- int max_row_idx = -1 ;
73
- T max_dist = -1 ;
74
- for (int64_t j = 0 ; j < col; ++j) {
75
- if (match_indices[j] != -1 ) {
76
- continue ;
71
+ // Test result: When row==130 the speed of these two methods almost the same
72
+ if (row >= 130 ) {
73
+ std::vector<std::tuple<int , int , T>> match_pair;
74
+
75
+ for (int64_t i = 0 ; i < row; ++i) {
76
+ for (int64_t j = 0 ; j < col; ++j) {
77
+ match_pair.push_back (std::make_tuple (i, j, dist_data[i * col + j]));
77
78
}
78
- for (size_t k = 0 ; k < row_pool.size (); ++k) {
79
- int m = row_pool[k];
80
- // distance is 0 between m-th row and j-th column
81
- if (dist_data[m * col + j] < kEPS ) {
79
+ }
80
+ std::sort (match_pair.begin (), match_pair.end (), DistPairDescend<T>);
81
+ std::vector<int > row_indices (row, -1 );
82
+
83
+ int64_t idx = 0 ;
84
+ for (int64_t k = 0 ; k < row * col; ++k) {
85
+ int64_t i = std::get<0 >(match_pair[k]);
86
+ int64_t j = std::get<1 >(match_pair[k]);
87
+ T dist = std::get<2 >(match_pair[k]);
88
+
89
+ if (idx >= row) {
90
+ break ;
91
+ }
92
+ if (match_indices[j] == -1 && row_indices[i] == -1 && dist > 0 ) {
93
+ match_indices[j] = i;
94
+ row_indices[i] = j;
95
+ match_dist[j] = dist;
96
+ idx += 1 ;
97
+ }
98
+ }
99
+ } else {
100
+ constexpr T kEPS = static_cast <T>(1e-6 );
101
+ std::vector<int > row_pool;
102
+ for (int i = 0 ; i < row; ++i) {
103
+ row_pool.push_back (i);
104
+ }
105
+ while (row_pool.size () > 0 ) {
106
+ int max_idx = -1 ;
107
+ int max_row_idx = -1 ;
108
+ T max_dist = -1 ;
109
+ for (int64_t j = 0 ; j < col; ++j) {
110
+ if (match_indices[j] != -1 ) {
82
111
continue ;
83
112
}
84
- if (dist_data[m * col + j] > max_dist) {
85
- max_idx = j;
86
- max_row_idx = m;
87
- max_dist = dist_data[m * col + j];
113
+ for (size_t k = 0 ; k < row_pool.size (); ++k) {
114
+ int m = row_pool[k];
115
+ // distance is 0 between m-th row and j-th column
116
+ if (dist_data[m * col + j] < kEPS ) {
117
+ continue ;
118
+ }
119
+ if (dist_data[m * col + j] > max_dist) {
120
+ max_idx = j;
121
+ max_row_idx = m;
122
+ max_dist = dist_data[m * col + j];
123
+ }
88
124
}
89
125
}
90
- }
91
- if (max_idx == - 1 ) {
92
- // Cannot find good match.
93
- break ;
94
- } else {
95
- PADDLE_ENFORCE_EQ ( match_indices[max_idx], - 1 ) ;
96
- match_indices [max_idx] = max_row_idx ;
97
- match_dist[max_idx] = max_dist;
98
- // Erase the row index.
99
- row_pool.erase (
100
- std::find (row_pool. begin (), row_pool. end (), max_row_idx));
126
+ if (max_idx == - 1 ) {
127
+ // Cannot find good match.
128
+ break ;
129
+ } else {
130
+ PADDLE_ENFORCE_EQ (match_indices[max_idx], - 1 );
131
+ match_indices[max_idx] = max_row_idx ;
132
+ match_dist [max_idx] = max_dist ;
133
+ // Erase the row index.
134
+ row_pool. erase (
135
+ std::find ( row_pool.begin (), row_pool. end (), max_row_idx));
136
+ }
101
137
}
102
138
}
103
139
}
0 commit comments