Skip to content

Commit 2b19a68

Browse files
authored
Merge pull request #7695 from qingqing01/bipartite_match_op
Add bipartite matching operator and unit testing.
2 parents 479c861 + e44dedf commit 2b19a68

File tree

2 files changed

+290
-0
lines changed

2 files changed

+290
-0
lines changed
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/framework/op_registry.h"
16+
#include "paddle/operators/math/math_function.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
using Tensor = framework::Tensor;
22+
using LoDTensor = framework::LoDTensor;
23+
24+
constexpr char kEPS = 1e-6;
25+
26+
class BipartiteMatchOp : public framework::OperatorWithKernel {
27+
public:
28+
using framework::OperatorWithKernel::OperatorWithKernel;
29+
30+
void InferShape(framework::InferShapeContext* ctx) const override {
31+
PADDLE_ENFORCE(ctx->HasInput("DistMat"),
32+
"Input(DistMat) of BipartiteMatch should not be null.");
33+
34+
auto dims = ctx->GetInputDim("DistMat");
35+
PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DistMat) must be 2.");
36+
37+
ctx->SetOutputDim("ColToRowMatchIndices", dims);
38+
ctx->SetOutputDim("ColToRowMatchDis", dims);
39+
}
40+
};
41+
42+
template <typename T>
43+
class BipartiteMatchKernel : public framework::OpKernel<T> {
44+
public:
45+
// The match_indices must be initialized to -1 at first.
46+
// The match_dist must be initialized to 0 at first.
47+
void BipartiteMatch(const Tensor& dist, int* match_indices,
48+
T* match_dist) const {
49+
PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
50+
int64_t row = dist.dims()[0];
51+
int64_t col = dist.dims()[1];
52+
auto* dist_data = dist.data<T>();
53+
std::vector<int> row_pool;
54+
for (int i = 0; i < row; ++i) {
55+
row_pool.push_back(i);
56+
}
57+
while (row_pool.size() > 0) {
58+
int max_idx = -1;
59+
int max_row_idx = -1;
60+
T max_dist = -1;
61+
for (int64_t j = 0; j < col; ++j) {
62+
if (match_indices[j] != -1) {
63+
continue;
64+
}
65+
for (size_t k = 0; k < row_pool.size(); ++k) {
66+
int m = row_pool[k];
67+
// distance is 0 between m-th row and j-th column
68+
if (dist_data[m * col + j] < kEPS) {
69+
continue;
70+
}
71+
if (dist_data[m * col + j] > max_dist) {
72+
max_idx = j;
73+
max_row_idx = m;
74+
max_dist = dist_data[m * col + j];
75+
}
76+
}
77+
}
78+
if (max_idx == -1) {
79+
// Cannot find good match.
80+
break;
81+
} else {
82+
PADDLE_ENFORCE_EQ(match_indices[max_idx], -1);
83+
match_indices[max_idx] = max_row_idx;
84+
match_dist[max_idx] = max_dist;
85+
// Erase the row index.
86+
row_pool.erase(
87+
std::find(row_pool.begin(), row_pool.end(), max_row_idx));
88+
}
89+
}
90+
}
91+
92+
void Compute(const framework::ExecutionContext& context) const override {
93+
auto* dist_mat = context.Input<LoDTensor>("DistMat");
94+
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
95+
auto* match_dist = context.Output<Tensor>("ColToRowMatchDis");
96+
97+
auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
98+
99+
auto col = dist_mat->dims()[1];
100+
101+
int64_t n = dist_mat->lod().size() == 0UL
102+
? 1
103+
: static_cast<int64_t>(dist_mat->lod().back().size() - 1);
104+
if (dist_mat->lod().size()) {
105+
PADDLE_ENFORCE_EQ(dist_mat->lod().size(), 1UL,
106+
"Only support 1 level of LoD.");
107+
}
108+
match_indices->mutable_data<int>({n, col}, context.GetPlace());
109+
match_dist->mutable_data<T>({n, col}, context.GetPlace());
110+
111+
math::SetConstant<platform::CPUDeviceContext, int> iset;
112+
iset(dev_ctx, match_indices, static_cast<int>(-1));
113+
math::SetConstant<platform::CPUDeviceContext, T> tset;
114+
tset(dev_ctx, match_dist, static_cast<T>(0));
115+
116+
int* indices = match_indices->data<int>();
117+
T* dist = match_dist->data<T>();
118+
if (n == 1) {
119+
BipartiteMatch(*dist_mat, indices, dist);
120+
} else {
121+
auto lod = dist_mat->lod().back();
122+
for (size_t i = 0; i < lod.size() - 1; ++i) {
123+
Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
124+
BipartiteMatch(one_ins, indices + i * col, dist + i * col);
125+
}
126+
}
127+
}
128+
};
129+
130+
class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
131+
public:
132+
BipartiteMatchOpMaker(OpProto* proto, OpAttrChecker* op_checker)
133+
: OpProtoAndCheckerMaker(proto, op_checker) {
134+
AddInput(
135+
"DistMat",
136+
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
137+
"[K, M]. It is pair-wise distance matrix between the entities "
138+
"represented by each row and each column. For example, assumed one "
139+
"entity is A with shape [K], another entity is B with shape [M]. The "
140+
"DistMat[i][j] is the distance between A[i] and B[j]. The bigger "
141+
"the distance is, the better macthing the pairs are. Please note, "
142+
"This tensor can contain LoD information to represent a batch of "
143+
"inputs. One instance of this batch can contain different numbers of "
144+
"entities.");
145+
AddOutput("ColToRowMatchIndices",
146+
"(Tensor) A 2-D Tensor with shape [N, M] in int type. "
147+
"N is the batch size. If ColToRowMatchIndices[i][j] is -1, it "
148+
"means B[j] does not match any entity in i-th instance. "
149+
"Otherwise, it means B[j] is matched to row "
150+
"ColToRowMatchIndices[i][j] in i-th instance. The row number of "
151+
"i-th instance is saved in ColToRowMatchIndices[i][j].");
152+
AddOutput("ColToRowMatchDis",
153+
"(Tensor) A 2-D Tensor with shape [N, M] in float type. "
154+
"N is batch size. If ColToRowMatchIndices[i][j] is -1, "
155+
"ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed "
156+
"ColToRowMatchIndices[i][j] = d, and the row offsets of each "
157+
"instance are called LoD. Then "
158+
"ColToRowMatchDis[i][j] = DistMat[d+LoD[i]][j]");
159+
AddComment(R"DOC(
160+
This operator is a greedy bipartite matching algorithm, which is used to
161+
obtain the matching with the maximum distance based on the input
162+
distance matrix. For input 2D matrix, the bipartite matching algorithm can
163+
find the matched column for each row, also can find the matched row for
164+
each column. And this operator only calculate matched indices from column
165+
to row. For each instance, the number of matched indices is the number of
166+
of columns of the input ditance matrix.
167+
168+
There are two outputs to save matched indices and distance.
169+
A simple description, this algothrim matched the best (maximum distance)
170+
row entity to the column entity and the matched indices are not duplicated
171+
in each row of ColToRowMatchIndices. If the column entity is not matched
172+
any row entity, set -1 in ColToRowMatchIndices.
173+
174+
Please note that the input DistMat can be LoDTensor (with LoD) or Tensor.
175+
If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size.
176+
If Tensor, the height of ColToRowMatchIndices is 1.
177+
178+
)DOC");
179+
}
180+
};
181+
182+
} // namespace operators
183+
} // namespace paddle
184+
185+
namespace ops = paddle::operators;
186+
REGISTER_OPERATOR(bipartite_match, ops::BipartiteMatchOp,
187+
ops::BipartiteMatchOpMaker,
188+
paddle::framework::EmptyGradOpMaker);
189+
REGISTER_OP_CPU_KERNEL(bipartite_match, ops::BipartiteMatchKernel<float>,
190+
ops::BipartiteMatchKernel<double>);
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
2+
#
3+
#Licensed under the Apache License, Version 2.0 (the "License");
4+
#you may not use this file except in compliance with the License.
5+
#You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
#Unless required by applicable law or agreed to in writing, software
10+
#distributed under the License is distributed on an "AS IS" BASIS,
11+
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
#See the License for the specific language governing permissions and
13+
#limitations under the License.
14+
import unittest
15+
import numpy as np
16+
from op_test import OpTest
17+
18+
19+
def bipartite_match(distance, match_indices, match_dis):
20+
"""Bipartite Matching algorithm.
21+
Arg:
22+
distance (numpy.array) : The distance of two entries with shape [M, N].
23+
match_indices (numpy.array): the matched indices from column to row
24+
with shape [1, N], it must be initialized to -1.
25+
match_dis (numpy.array): The matched distance from column to row
26+
with shape [1, N], it must be initialized to 0.
27+
"""
28+
match_pair = []
29+
row, col = distance.shape
30+
for i in range(row):
31+
for j in range(col):
32+
match_pair.append((i, j, distance[i][j]))
33+
34+
match_sorted = sorted(match_pair, key=lambda tup: tup[2], reverse=True)
35+
36+
row_indices = -1 * np.ones((row, ), dtype=np.int)
37+
38+
idx = 0
39+
for i, j, dis in match_sorted:
40+
if idx >= row:
41+
break
42+
if match_indices[j] == -1 and row_indices[i] == -1 and dis > 0:
43+
match_indices[j] = i
44+
row_indices[i] = j
45+
match_dis[j] = dis
46+
idx += 1
47+
48+
49+
def batch_bipartite_match(distance, lod):
50+
"""Bipartite Matching algorithm for batch input.
51+
Arg:
52+
distance (numpy.array) : The distance of two entries with shape [M, N].
53+
lod (list of int): The offsets of each input in this batch.
54+
"""
55+
n = len(lod) - 1
56+
m = distance.shape[1]
57+
match_indices = -1 * np.ones((n, m), dtype=np.int)
58+
match_dis = np.zeros((n, m), dtype=np.float32)
59+
for i in range(len(lod) - 1):
60+
bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
61+
match_dis[i, :])
62+
return match_indices, match_dis
63+
64+
65+
class TestBipartiteMatchOpForWithLoD(OpTest):
66+
def setUp(self):
67+
self.op_type = 'bipartite_match'
68+
lod = [[0, 5, 11, 23]]
69+
dis = np.random.random((23, 217)).astype('float32')
70+
match_indices, match_dis = batch_bipartite_match(dis, lod[0])
71+
72+
self.inputs = {'DistMat': (dis, lod)}
73+
self.outputs = {
74+
'ColToRowMatchIndices': (match_indices),
75+
'ColToRowMatchDis': (match_dis),
76+
}
77+
78+
def test_check_output(self):
79+
self.check_output()
80+
81+
82+
class TestBipartiteMatchOpWithoutLoD(OpTest):
83+
def setUp(self):
84+
self.op_type = 'bipartite_match'
85+
lod = [[0, 8]]
86+
dis = np.random.random((8, 17)).astype('float32')
87+
match_indices, match_dis = batch_bipartite_match(dis, lod[0])
88+
89+
self.inputs = {'DistMat': dis}
90+
self.outputs = {
91+
'ColToRowMatchIndices': (match_indices),
92+
'ColToRowMatchDis': (match_dis),
93+
}
94+
95+
def test_check_output(self):
96+
self.check_output()
97+
98+
99+
if __name__ == '__main__':
100+
unittest.main()

0 commit comments

Comments
 (0)