Skip to content

Commit 778b71f

Browse files
author
baiyf
authored
Optimize bipartite_match_op in large scale input (#11730)
* optimize bipartite_match_op in large scale input
1 parent c228977 commit 778b71f

File tree

2 files changed

+84
-31
lines changed

2 files changed

+84
-31
lines changed

paddle/fluid/operators/detection/bipartite_match_op.cc

Lines changed: 67 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -51,53 +51,89 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
5151
}
5252
};
5353

54+
template <class T>
55+
bool DistPairDescend(std::tuple<int, int, T> pair1,
56+
std::tuple<int, int, T> pair2) {
57+
return std::get<2>(pair1) > std::get<2>(pair2);
58+
}
59+
5460
template <typename T>
5561
class BipartiteMatchKernel : public framework::OpKernel<T> {
5662
public:
5763
// The match_indices must be initialized to -1 at first.
5864
// The match_dist must be initialized to 0 at first.
5965
void BipartiteMatch(const Tensor& dist, int* match_indices,
6066
T* match_dist) const {
61-
constexpr T kEPS = static_cast<T>(1e-6);
6267
PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
6368
int64_t row = dist.dims()[0];
6469
int64_t col = dist.dims()[1];
6570
auto* dist_data = dist.data<T>();
66-
std::vector<int> row_pool;
67-
for (int i = 0; i < row; ++i) {
68-
row_pool.push_back(i);
69-
}
70-
while (row_pool.size() > 0) {
71-
int max_idx = -1;
72-
int max_row_idx = -1;
73-
T max_dist = -1;
74-
for (int64_t j = 0; j < col; ++j) {
75-
if (match_indices[j] != -1) {
76-
continue;
71+
// Test result: When row==130 the speed of these two methods almost the same
72+
if (row >= 130) {
73+
std::vector<std::tuple<int, int, T>> match_pair;
74+
75+
for (int64_t i = 0; i < row; ++i) {
76+
for (int64_t j = 0; j < col; ++j) {
77+
match_pair.push_back(std::make_tuple(i, j, dist_data[i * col + j]));
7778
}
78-
for (size_t k = 0; k < row_pool.size(); ++k) {
79-
int m = row_pool[k];
80-
// distance is 0 between m-th row and j-th column
81-
if (dist_data[m * col + j] < kEPS) {
79+
}
80+
std::sort(match_pair.begin(), match_pair.end(), DistPairDescend<T>);
81+
std::vector<int> row_indices(row, -1);
82+
83+
int64_t idx = 0;
84+
for (int64_t k = 0; k < row * col; ++k) {
85+
int64_t i = std::get<0>(match_pair[k]);
86+
int64_t j = std::get<1>(match_pair[k]);
87+
T dist = std::get<2>(match_pair[k]);
88+
89+
if (idx >= row) {
90+
break;
91+
}
92+
if (match_indices[j] == -1 && row_indices[i] == -1 && dist > 0) {
93+
match_indices[j] = i;
94+
row_indices[i] = j;
95+
match_dist[j] = dist;
96+
idx += 1;
97+
}
98+
}
99+
} else {
100+
constexpr T kEPS = static_cast<T>(1e-6);
101+
std::vector<int> row_pool;
102+
for (int i = 0; i < row; ++i) {
103+
row_pool.push_back(i);
104+
}
105+
while (row_pool.size() > 0) {
106+
int max_idx = -1;
107+
int max_row_idx = -1;
108+
T max_dist = -1;
109+
for (int64_t j = 0; j < col; ++j) {
110+
if (match_indices[j] != -1) {
82111
continue;
83112
}
84-
if (dist_data[m * col + j] > max_dist) {
85-
max_idx = j;
86-
max_row_idx = m;
87-
max_dist = dist_data[m * col + j];
113+
for (size_t k = 0; k < row_pool.size(); ++k) {
114+
int m = row_pool[k];
115+
// distance is 0 between m-th row and j-th column
116+
if (dist_data[m * col + j] < kEPS) {
117+
continue;
118+
}
119+
if (dist_data[m * col + j] > max_dist) {
120+
max_idx = j;
121+
max_row_idx = m;
122+
max_dist = dist_data[m * col + j];
123+
}
88124
}
89125
}
90-
}
91-
if (max_idx == -1) {
92-
// Cannot find good match.
93-
break;
94-
} else {
95-
PADDLE_ENFORCE_EQ(match_indices[max_idx], -1);
96-
match_indices[max_idx] = max_row_idx;
97-
match_dist[max_idx] = max_dist;
98-
// Erase the row index.
99-
row_pool.erase(
100-
std::find(row_pool.begin(), row_pool.end(), max_row_idx));
126+
if (max_idx == -1) {
127+
// Cannot find good match.
128+
break;
129+
} else {
130+
PADDLE_ENFORCE_EQ(match_indices[max_idx], -1);
131+
match_indices[max_idx] = max_row_idx;
132+
match_dist[max_idx] = max_dist;
133+
// Erase the row index.
134+
row_pool.erase(
135+
std::find(row_pool.begin(), row_pool.end(), max_row_idx));
136+
}
101137
}
102138
}
103139
}

python/paddle/fluid/tests/unittests/test_bipartite_match_op.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,23 @@ def test_check_output(self):
114114
self.check_output()
115115

116116

117+
class TestBipartiteMatchOpWithoutLoDLargeScaleInput(OpTest):
118+
def setUp(self):
119+
self.op_type = 'bipartite_match'
120+
lod = [[300]]
121+
dist = np.random.random((300, 17)).astype('float32')
122+
match_indices, match_dist = batch_bipartite_match(dist, lod[0])
123+
124+
self.inputs = {'DistMat': dist}
125+
self.outputs = {
126+
'ColToRowMatchIndices': match_indices,
127+
'ColToRowMatchDist': match_dist,
128+
}
129+
130+
def test_check_output(self):
131+
self.check_output()
132+
133+
117134
class TestBipartiteMatchOpWithPerPredictionType(OpTest):
118135
def setUp(self):
119136
self.op_type = 'bipartite_match'

0 commit comments

Comments
 (0)