Skip to content

Commit 4948f7b

Browse files
authored
Enhance bipartite_match_op to support argmax matching after bipartite matching. (#8580)
* Enhance bipartite_match_op to support argmax matching after bipartite matching. * Fix typo error.
1 parent dce0383 commit 4948f7b

File tree

3 files changed

+112
-8
lines changed

3 files changed

+112
-8
lines changed

paddle/fluid/operators/bipartite_match_op.cc

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,38 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
9494
}
9595
}
9696

97+
void ArgMaxMatch(const Tensor& dist, int* match_indices, T* match_dist,
98+
T overlap_threshold) const {
99+
constexpr T kEPS = static_cast<T>(1e-6);
100+
int64_t row = dist.dims()[0];
101+
int64_t col = dist.dims()[1];
102+
auto* dist_data = dist.data<T>();
103+
for (int64_t j = 0; j < col; ++j) {
104+
if (match_indices[j] != -1) {
105+
// the j-th column has been matched to one entity.
106+
continue;
107+
}
108+
int max_row_idx = -1;
109+
T max_dist = -1;
110+
for (int i = 0; i < row; ++i) {
111+
T dist = dist_data[i * col + j];
112+
if (dist < kEPS) {
113+
// distance is 0 between m-th row and j-th column
114+
continue;
115+
}
116+
if (dist >= overlap_threshold && dist > max_dist) {
117+
max_row_idx = i;
118+
max_dist = dist;
119+
}
120+
}
121+
if (max_row_idx != -1) {
122+
PADDLE_ENFORCE_EQ(match_indices[j], -1);
123+
match_indices[j] = max_row_idx;
124+
match_dist[j] = max_dist;
125+
}
126+
}
127+
}
128+
97129
void Compute(const framework::ExecutionContext& context) const override {
98130
auto* dist_mat = context.Input<LoDTensor>("DistMat");
99131
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
@@ -120,13 +152,21 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
120152

121153
int* indices = match_indices->data<int>();
122154
T* dist = match_dist->data<T>();
155+
auto type = context.Attr<std::string>("match_type");
156+
auto threshold = context.Attr<float>("dist_threshold");
123157
if (n == 1) {
124158
BipartiteMatch(*dist_mat, indices, dist);
159+
if (type == "per_prediction") {
160+
ArgMaxMatch(*dist_mat, indices, dist, threshold);
161+
}
125162
} else {
126163
auto lod = dist_mat->lod().back();
127164
for (size_t i = 0; i < lod.size() - 1; ++i) {
128165
Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
129166
BipartiteMatch(one_ins, indices + i * col, dist + i * col);
167+
if (type == "per_prediction") {
168+
ArgMaxMatch(one_ins, indices + i * col, dist + i * col, threshold);
169+
}
130170
}
131171
}
132172
}
@@ -147,6 +187,19 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
147187
"This tensor can contain LoD information to represent a batch of "
148188
"inputs. One instance of this batch can contain different numbers of "
149189
"entities.");
190+
AddAttr<std::string>(
191+
"match_type",
192+
"(string, defalut: per_prediction) "
193+
"The type of matching method, should be 'bipartite' or "
194+
"'per_prediction', 'bipartite' by defalut.")
195+
.SetDefault("bipartite")
196+
.InEnum({"bipartite", "per_prediction"});
197+
AddAttr<float>(
198+
"dist_threshold",
199+
"(float, defalut: 0.5) "
200+
"If `match_type` is 'per_prediction', this threshold is to determine "
201+
"the extra matching bboxes based on the maximum distance.")
202+
.SetDefault(0.5);
150203
AddOutput("ColToRowMatchIndices",
151204
"(Tensor) A 2-D Tensor with shape [N, M] in int type. "
152205
"N is the batch size. If ColToRowMatchIndices[i][j] is -1, it "
@@ -168,10 +221,10 @@ distance matrix. For input 2D matrix, the bipartite matching algorithm can
168221
find the matched column for each row, also can find the matched row for
169222
each column. And this operator only calculate matched indices from column
170223
to row. For each instance, the number of matched indices is the number of
171-
of columns of the input ditance matrix.
224+
of columns of the input distance matrix.
172225
173226
There are two outputs to save matched indices and distance.
174-
A simple description, this algothrim matched the best (maximum distance)
227+
A simple description, this algorithm matched the best (maximum distance)
175228
row entity to the column entity and the matched indices are not duplicated
176229
in each row of ColToRowMatchIndices. If the column entity is not matched
177230
any row entity, set -1 in ColToRowMatchIndices.

python/paddle/fluid/layers/detection.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,10 @@ class number, M is number of bounding boxes. For each category
132132
return nmsed_outs
133133

134134

135-
def bipartite_match(dist_matrix, name=None):
135+
def bipartite_match(dist_matrix,
136+
match_type=None,
137+
dist_threshold=None,
138+
name=None):
136139
"""
137140
**Bipartite matchint operator**
138141
@@ -164,6 +167,11 @@ def bipartite_match(dist_matrix, name=None):
164167
This tensor can contain LoD information to represent a batch of
165168
inputs. One instance of this batch can contain different numbers of
166169
entities.
170+
match_type(string|None): The type of matching method, should be
171+
'bipartite' or 'per_prediction', 'bipartite' by defalut.
172+
dist_threshold(float|None): If `match_type` is 'per_prediction',
173+
this threshold is to determine the extra matching bboxes based
174+
on the maximum distance, 0.5 by defalut.
167175
Returns:
168176
match_indices(Variable): A 2-D Tensor with shape [N, M] in int type.
169177
N is the batch size. If match_indices[i][j] is -1, it
@@ -183,6 +191,10 @@ def bipartite_match(dist_matrix, name=None):
183191
helper.append_op(
184192
type='bipartite_match',
185193
inputs={'DistMat': dist_matrix},
194+
attrs={
195+
'match_type': match_type,
196+
'dist_threshold': dist_threshold,
197+
},
186198
outputs={
187199
'ColToRowMatchIndices': match_indices,
188200
'ColToRowMatchDist': match_distance
@@ -333,7 +345,7 @@ def ssd_loss(location,
333345
loc_loss_weight (float): Weight for localization loss, 1.0 by default.
334346
conf_loss_weight (float): Weight for confidence loss, 1.0 by default.
335347
match_type (str): The type of matching method during training, should
336-
be 'bipartite' or 'per_prediction'.
348+
be 'bipartite' or 'per_prediction', 'per_prediction' by defalut.
337349
mining_type (str): The hard example mining type, should be 'hard_example'
338350
or 'max_negative', now only support `max_negative`.
339351
@@ -381,7 +393,8 @@ def __reshape_to_2d(var):
381393
# 1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
382394
iou = iou_similarity(x=gt_box, y=prior_box)
383395
# 1.2 Compute matched boundding box by bipartite matching algorithm.
384-
matched_indices, matched_dist = bipartite_match(iou)
396+
matched_indices, matched_dist = bipartite_match(iou, match_type,
397+
overlap_threshold)
385398

386399
# 2. Compute confidence for mining hard examples
387400
# 2.1. Get the target label based on matched indices

python/paddle/fluid/tests/unittests/test_bipartite_match_op.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,20 @@ def bipartite_match(distance, match_indices, match_dist):
4646
idx += 1
4747

4848

49-
def batch_bipartite_match(distance, lod):
49+
def argmax_match(distance, match_indices, match_dist, threshold):
50+
r, c = distance.shape
51+
for j in xrange(c):
52+
if match_indices[j] != -1:
53+
continue
54+
col_dist = distance[:, j]
55+
indices = np.argwhere(col_dist >= threshold).flatten()
56+
if len(indices) < 1:
57+
continue
58+
match_indices[j] = indices[np.argmax(col_dist[indices])]
59+
match_dist[j] = col_dist[match_indices[j]]
60+
61+
62+
def batch_bipartite_match(distance, lod, match_type=None, dist_threshold=None):
5063
"""Bipartite Matching algorithm for batch input.
5164
Arg:
5265
distance (numpy.array) : The distance of two entries with shape [M, N].
@@ -59,6 +72,9 @@ def batch_bipartite_match(distance, lod):
5972
for i in range(len(lod) - 1):
6073
bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
6174
match_dist[i, :])
75+
if match_type == 'per_prediction':
76+
argmax_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
77+
match_dist[i, :], dist_threshold)
6278
return match_indices, match_dist
6379

6480

@@ -71,8 +87,8 @@ def setUp(self):
7187

7288
self.inputs = {'DistMat': (dist, lod)}
7389
self.outputs = {
74-
'ColToRowMatchIndices': (match_indices),
75-
'ColToRowMatchDist': (match_dist),
90+
'ColToRowMatchIndices': match_indices,
91+
'ColToRowMatchDist': match_dist,
7692
}
7793

7894
def test_check_output(self):
@@ -96,5 +112,27 @@ def test_check_output(self):
96112
self.check_output()
97113

98114

115+
class TestBipartiteMatchOpWithPerPredictionType(OpTest):
116+
def setUp(self):
117+
self.op_type = 'bipartite_match'
118+
lod = [[0, 5, 11, 23]]
119+
dist = np.random.random((23, 237)).astype('float32')
120+
match_indices, match_dist = batch_bipartite_match(dist, lod[0],
121+
'per_prediction', 0.5)
122+
123+
self.inputs = {'DistMat': (dist, lod)}
124+
self.outputs = {
125+
'ColToRowMatchIndices': match_indices,
126+
'ColToRowMatchDist': match_dist,
127+
}
128+
self.attrs = {
129+
'match_type': 'per_prediction',
130+
'dist_threshold': 0.5,
131+
}
132+
133+
def test_check_output(self):
134+
self.check_output()
135+
136+
99137
if __name__ == '__main__':
100138
unittest.main()

0 commit comments

Comments
 (0)