Skip to content

Commit bf33b19

Browse files
committed
Add bipartite matching operator and unit testing.
1 parent 38c6105 commit bf33b19

File tree

2 files changed

+278
-0
lines changed

2 files changed

+278
-0
lines changed
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/framework/op_registry.h"
16+
#include "paddle/operators/math/math_function.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
// Shorthand aliases for the framework tensor types used by the operator and
// kernel defined in this file.
using Tensor = framework::Tensor;
22+
using LoDTensor = framework::LoDTensor;
23+
24+
class BipartiteMatchOp : public framework::OperatorWithKernel {
25+
public:
26+
using framework::OperatorWithKernel::OperatorWithKernel;
27+
28+
void InferShape(framework::InferShapeContext* ctx) const override {
29+
PADDLE_ENFORCE(ctx->HasInput("DisMat"),
30+
"Input(DisMat) of BipartiteMatch should not be null.");
31+
32+
auto dims = ctx->GetInputDim("DisMat");
33+
PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DisMat) must be 2.");
34+
35+
ctx->SetOutputDim("ColToRowMatchIndices", dims);
36+
ctx->SetOutputDim("ColToRowMatchDis", dims);
37+
}
38+
};
39+
40+
template <typename T>
41+
class BipartiteMatchKernel : public framework::OpKernel<T> {
42+
public:
43+
// The match_indices must be initialized to -1 at first.
44+
// The match_dis must be initialized to 0 at first.
45+
void BipartiteMatch(const Tensor& dis, int* match_indices,
46+
T* match_dis) const {
47+
int64_t row = dis.dims()[0];
48+
int64_t col = dis.dims()[1];
49+
auto* dis_data = dis.data<T>();
50+
std::vector<int> row_pool;
51+
for (int i = 0; i < row; ++i) {
52+
row_pool.push_back(i);
53+
}
54+
while (row_pool.size() > 0) {
55+
int max_idx = -1;
56+
int max_row_idx = -1;
57+
T max_dis = -1;
58+
for (int64_t j = 0; j < col; ++j) {
59+
if (match_indices[j] != -1) {
60+
continue;
61+
}
62+
for (int k = 0; k < row_pool.size(); ++k) {
63+
int m = row_pool[k];
64+
// distance is 0 between m-th row and j-th column
65+
if (dis_data[m * col + j] < 1e-6) {
66+
continue;
67+
}
68+
if (dis_data[m * col + j] > max_dis) {
69+
max_idx = j;
70+
max_row_idx = m;
71+
max_dis = dis_data[m * col + j];
72+
}
73+
}
74+
}
75+
if (max_idx == -1) {
76+
// Cannot find good match.
77+
break;
78+
} else {
79+
PADDLE_ENFORCE_EQ(match_indices[max_idx], -1);
80+
match_indices[max_idx] = max_row_idx;
81+
match_dis[max_idx] = max_dis;
82+
// Erase the row index.
83+
row_pool.erase(
84+
std::find(row_pool.begin(), row_pool.end(), max_row_idx));
85+
}
86+
}
87+
}
88+
89+
void Compute(const framework::ExecutionContext& context) const override {
90+
auto* dis_mat = context.Input<LoDTensor>("DisMat");
91+
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
92+
auto* match_dis = context.Output<Tensor>("ColToRowMatchDis");
93+
94+
auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
95+
96+
auto col = dis_mat->dims()[1];
97+
98+
int64_t n = dis_mat->lod().size() == 0
99+
? 1
100+
: static_cast<int64_t>(dis_mat->lod().back().size() - 1);
101+
match_indices->mutable_data<int>({n, col}, context.GetPlace());
102+
match_dis->mutable_data<T>({n, col}, context.GetPlace());
103+
104+
math::SetConstant<platform::CPUDeviceContext, int> iset;
105+
iset(dev_ctx, match_indices, static_cast<int>(-1));
106+
math::SetConstant<platform::CPUDeviceContext, T> tset;
107+
tset(dev_ctx, match_dis, static_cast<T>(0));
108+
109+
int* indices = match_indices->data<int>();
110+
T* dis = match_dis->data<T>();
111+
if (n == 1) {
112+
BipartiteMatch(*dis_mat, indices, dis);
113+
} else {
114+
auto lod = dis_mat->lod().back();
115+
for (size_t i = 0; i < lod.size() - 1; ++i) {
116+
Tensor one_ins = dis_mat->Slice(lod[i], lod[i + 1]);
117+
BipartiteMatch(one_ins, indices + i * col, dis + i * col);
118+
}
119+
}
120+
}
121+
};
122+
123+
// Declares the inputs, outputs and user-facing documentation of the
// bipartite_match operator.
class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers Input(DisMat), Output(ColToRowMatchIndices) and
  // Output(ColToRowMatchDis) together with their descriptions.
  //
  // The descriptions name only outputs that exist (ColToRowMatchIndices) and
  // state that unmatched distances are 0, which is the value the CPU kernel
  // initializes ColToRowMatchDis with.
  BipartiteMatchOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "DisMat",
        "(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
        "[K, M]. It is pair-wise distance matrix between the entities "
        "represented by each row and each column. For example, assumed one "
        "entity is A with shape [K], another entity is B with shape [M]. The "
        "DisMat[i][j] is the distance between A[i] and B[j]. The bigger "
        "the distance is, the more similar the pairs are. Please note, "
        "This tensor can contain LoD information to represent a batch of "
        "inputs. One instance of this batch can contain different numbers of "
        "entities.");
    AddOutput("ColToRowMatchIndices",
              "(Tensor) A 2-D Tensor with shape [N, M] in int type. "
              "N is the batch size. If ColToRowMatchIndices[i][j] is -1, it "
              "means B[j] does not match any entity in i-th instance. "
              "Otherwise, it means B[j] is matched to row "
              "ColToRowMatchIndices[i][j] in i-th instance. The row number of "
              "i-th instance is saved in ColToRowMatchIndices[i][j].");
    AddOutput("ColToRowMatchDis",
              "(Tensor) A 2-D Tensor with shape [N, M] in float type. "
              "N is batch size. If ColToRowMatchIndices[i][j] is -1, "
              "ColToRowMatchDis[i][j] is 0. Otherwise, assumed "
              "ColToRowMatchIndices[i][j] = d, and the row offsets of each "
              "instance are called LoD. Then "
              "ColToRowMatchDis[i][j] = DisMat[d+LoD[i]][j]");
    AddComment(R"DOC(
This operator is a greedy bipartite matching algorithm, which is used to
obtain the matching with the (greedy) maximum distance based on the input
distance matrix. There are two outputs to save matched indices and distance.
And this operator only calculates matched indices from column to row.
A simple description, this algorithm matched the best (maximum distance)
row entity to the column entity and the matched indices are not duplicated
in each row of ColToRowMatchIndices. If the column entity is not matched
any row entity, set -1 in ColToRowMatchIndices.

Please note that the input DisMat can be LoDTensor (with LoD) or Tensor.
If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size.
If Tensor, the height of ColToRowMatchIndices is 1.

)DOC");
  }
};
169+
170+
} // namespace operators
171+
} // namespace paddle
172+
173+
// Registration of the bipartite_match operator: its definition class, its
// proto maker and EmptyGradOpMaker (presumably meaning the operator has no
// gradient -- confirm against the framework), followed by the CPU kernels
// for float and double.
namespace ops = paddle::operators;
174+
REGISTER_OPERATOR(bipartite_match, ops::BipartiteMatchOp,
175+
ops::BipartiteMatchOpMaker,
176+
paddle::framework::EmptyGradOpMaker);
177+
REGISTER_OP_CPU_KERNEL(bipartite_match, ops::BipartiteMatchKernel<float>,
178+
ops::BipartiteMatchKernel<double>);
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
2+
#
3+
#Licensed under the Apache License, Version 2.0 (the "License");
4+
#you may not use this file except in compliance with the License.
5+
#You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
#Unless required by applicable law or agreed to in writing, software
10+
#distributed under the License is distributed on an "AS IS" BASIS,
11+
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
#See the License for the specific language governing permissions and
13+
#limitations under the License.
14+
import unittest
15+
import numpy as np
16+
from op_test import OpTest
17+
18+
19+
def bipartite_match(distance, match_indices, match_dis):
    """Greedy bipartite matching for a single instance (reference version).

    All (row, column) pairs are visited from the largest distance to the
    smallest. A pair is accepted when neither its row nor its column has been
    matched yet and its distance is positive. The search stops once as many
    pairs as there are rows have been accepted.

    Args:
        distance (numpy.array): The pairwise distances with shape [M, N].
        match_indices (numpy.array): 1-D array of length N, updated in place.
            It must be initialized to -1; entry j receives the row matched
            to column j.
        match_dis (numpy.array): 1-D array of length N, updated in place.
            It must be initialized to 0; entry j receives the distance of
            the matched pair.
    """
    row, col = distance.shape
    match_pair = [(i, j, distance[i][j])
                  for i in range(row) for j in range(col)]

    # sorted() is stable, so equal distances keep their row-major order.
    match_sorted = sorted(match_pair, key=lambda tup: tup[2], reverse=True)

    # The np.int alias was removed in NumPy 1.24; builtin int is the dtype
    # that alias stood for.
    row_indices = -1 * np.ones((row, ), dtype=int)

    idx = 0
    for i, j, dis in match_sorted:
        if idx >= row:
            # Every row has been matched.
            break
        if match_indices[j] == -1 and row_indices[i] == -1 and dis > 0:
            match_indices[j] = i
            row_indices[i] = j
            match_dis[j] = dis
            idx += 1
47+
48+
49+
def batch_bipartite_match(distance, lod):
    """Greedy bipartite matching for a batch of instances.

    Rows lod[i]:lod[i + 1] of `distance` form the i-th instance; each
    instance is matched independently with bipartite_match.

    Args:
        distance (numpy.array): The stacked pairwise distances of all
            instances, with shape [M, N].
        lod (list of int): The row offsets of each instance in this batch.

    Returns:
        tuple: (match_indices, match_dis), both with shape
            [len(lod) - 1, N]. Unmatched columns hold -1 in match_indices
            and 0 in match_dis.
    """
    n = len(lod) - 1
    m = distance.shape[1]
    # The np.int alias was removed in NumPy 1.24; builtin int is the dtype
    # that alias stood for.
    match_indices = -1 * np.ones((n, m), dtype=int)
    match_dis = np.zeros((n, m), dtype=np.float32)
    for i in range(n):
        # Row slices are views, so the per-instance call fills the batch
        # arrays in place.
        bipartite_match(distance[lod[i]:lod[i + 1], :], match_indices[i, :],
                        match_dis[i, :])
    return match_indices, match_dis
63+
64+
65+
class TestBipartiteMatchOpForWithLoD(OpTest):
    """Checks bipartite_match on a LoD input that holds three instances."""

    def setUp(self):
        self.op_type = 'bipartite_match'

        # Three instances with 5, 6 and 12 rows, stacked into one matrix.
        offsets = [0, 5, 11, 23]
        distance = np.random.random((23, 217)).astype('float32')
        expected_indices, expected_dis = batch_bipartite_match(distance,
                                                               offsets)

        self.inputs = {'DisMat': (distance, [offsets])}
        self.outputs = {
            'ColToRowMatchIndices': expected_indices,
            'ColToRowMatchDis': expected_dis,
        }

    def test_check_output(self):
        self.check_output()
80+
81+
82+
class TestBipartiteMatchOpWithoutLoD(OpTest):
    """Checks bipartite_match on a plain tensor input without LoD."""

    def setUp(self):
        self.op_type = 'bipartite_match'

        # The whole matrix is a single instance; the offsets are only needed
        # to compute the expected result.
        offsets = [0, 8]
        distance = np.random.random((8, 17)).astype('float32')
        expected_indices, expected_dis = batch_bipartite_match(distance,
                                                               offsets)

        self.inputs = {'DisMat': distance}
        self.outputs = {
            'ColToRowMatchIndices': expected_indices,
            'ColToRowMatchDis': expected_dis,
        }

    def test_check_output(self):
        self.check_output()
97+
98+
99+
# Run the test cases above when this file is executed as a script.
if __name__ == '__main__':
100+
unittest.main()

0 commit comments

Comments
 (0)