Skip to content

Commit 0790868

Browse files
committed
Update some comments and add more check.
1 parent c2edcde commit 0790868

File tree

1 file changed

+42
-30
lines changed

1 file changed

+42
-30
lines changed

paddle/operators/bipartite_match_op.cc

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ namespace operators {
2121
using Tensor = framework::Tensor;
2222
using LoDTensor = framework::LoDTensor;
2323

24+
constexpr char kEPS = 1e-6;
25+
2426
class BipartiteMatchOp : public framework::OperatorWithKernel {
2527
public:
2628
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -41,34 +43,35 @@ template <typename T>
4143
class BipartiteMatchKernel : public framework::OpKernel<T> {
4244
public:
4345
// The match_indices must be initialized to -1 at first.
44-
// The match_dis must be initialized to 0 at first.
45-
void BipartiteMatch(const Tensor& dis, int* match_indices,
46-
T* match_dis) const {
47-
int64_t row = dis.dims()[0];
48-
int64_t col = dis.dims()[1];
49-
auto* dis_data = dis.data<T>();
46+
// The match_dist must be initialized to 0 at first.
47+
void BipartiteMatch(const Tensor& dist, int* match_indices,
48+
T* match_dist) const {
49+
PADDLE_ENFORCE_EQ(dist.dims().size(), 2, "The rank of dist must be 2.");
50+
int64_t row = dist.dims()[0];
51+
int64_t col = dist.dims()[1];
52+
auto* dist_data = dist.data<T>();
5053
std::vector<int> row_pool;
5154
for (int i = 0; i < row; ++i) {
5255
row_pool.push_back(i);
5356
}
5457
while (row_pool.size() > 0) {
5558
int max_idx = -1;
5659
int max_row_idx = -1;
57-
T max_dis = -1;
60+
T max_dist = -1;
5861
for (int64_t j = 0; j < col; ++j) {
5962
if (match_indices[j] != -1) {
6063
continue;
6164
}
6265
for (int k = 0; k < row_pool.size(); ++k) {
6366
int m = row_pool[k];
6467
// distance is 0 between m-th row and j-th column
65-
if (dis_data[m * col + j] < 1e-6) {
68+
if (dist_data[m * col + j] < kEPS) {
6669
continue;
6770
}
68-
if (dis_data[m * col + j] > max_dis) {
71+
if (dist_data[m * col + j] > max_dist) {
6972
max_idx = j;
7073
max_row_idx = m;
71-
max_dis = dis_data[m * col + j];
74+
max_dist = dist_data[m * col + j];
7275
}
7376
}
7477
}
@@ -78,7 +81,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
7881
} else {
7982
PADDLE_ENFORCE_EQ(match_indices[max_idx], -1);
8083
match_indices[max_idx] = max_row_idx;
81-
match_dis[max_idx] = max_dis;
84+
match_dist[max_idx] = max_dist;
8285
// Erase the row index.
8386
row_pool.erase(
8487
std::find(row_pool.begin(), row_pool.end(), max_row_idx));
@@ -87,34 +90,38 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
8790
}
8891

8992
void Compute(const framework::ExecutionContext& context) const override {
90-
auto* dis_mat = context.Input<LoDTensor>("DisMat");
93+
auto* dist_mat = context.Input<LoDTensor>("DisMat");
9194
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
92-
auto* match_dis = context.Output<Tensor>("ColToRowMatchDis");
95+
auto* match_dist = context.Output<Tensor>("ColToRowMatchDis");
9396

9497
auto& dev_ctx = context.device_context<platform::CPUDeviceContext>();
9598

96-
auto col = dis_mat->dims()[1];
99+
auto col = dist_mat->dims()[1];
97100

98-
int64_t n = dis_mat->lod().size() == 0
101+
int64_t n = dist_mat->lod().size() == 0UL
99102
? 1
100-
: static_cast<int64_t>(dis_mat->lod().back().size() - 1);
103+
: static_cast<int64_t>(dist_mat->lod().back().size() - 1);
104+
if (dist_mat->lod().size()) {
105+
PADDLE_ENFORCE_EQ(dist_mat->lod().size(), 1UL,
106+
"Only support 1 level of LoD.");
107+
}
101108
match_indices->mutable_data<int>({n, col}, context.GetPlace());
102-
match_dis->mutable_data<T>({n, col}, context.GetPlace());
109+
match_dist->mutable_data<T>({n, col}, context.GetPlace());
103110

104111
math::SetConstant<platform::CPUDeviceContext, int> iset;
105112
iset(dev_ctx, match_indices, static_cast<int>(-1));
106113
math::SetConstant<platform::CPUDeviceContext, T> tset;
107-
tset(dev_ctx, match_dis, static_cast<T>(0));
114+
tset(dev_ctx, match_dist, static_cast<T>(0));
108115

109116
int* indices = match_indices->data<int>();
110-
T* dis = match_dis->data<T>();
117+
T* dist = match_dist->data<T>();
111118
if (n == 1) {
112-
BipartiteMatch(*dis_mat, indices, dis);
119+
BipartiteMatch(*dist_mat, indices, dist);
113120
} else {
114-
auto lod = dis_mat->lod().back();
121+
auto lod = dist_mat->lod().back();
115122
for (size_t i = 0; i < lod.size() - 1; ++i) {
116-
Tensor one_ins = dis_mat->Slice(lod[i], lod[i + 1]);
117-
BipartiteMatch(one_ins, indices + i * col, dis + i * col);
123+
Tensor one_ins = dist_mat->Slice(lod[i], lod[i + 1]);
124+
BipartiteMatch(one_ins, indices + i * col, dist + i * col);
118125
}
119126
}
120127
}
@@ -131,7 +138,7 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
131138
"represented by each row and each column. For example, assumed one "
132139
"entity is A with shape [K], another entity is B with shape [M]. The "
133140
"DisMat[i][j] is the distance between A[i] and B[j]. The bigger "
134-
"the distance is, the more similar the pairs are. Please note, "
141+
"the distance is, the better macthing the pairs are. Please note, "
135142
"This tensor can contain LoD information to represent a batch of "
136143
"inputs. One instance of this batch can contain different numbers of "
137144
"entities.");
@@ -140,20 +147,25 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
140147
"N is the batch size. If ColToRowMatchIndices[i][j] is -1, it "
141148
"means B[j] does not match any entity in i-th instance. "
142149
"Otherwise, it means B[j] is matched to row "
143-
"RowToColMatchIndices[i][j] in i-th instance. The row number of "
144-
"i-th instance is saved in RowToColMatchIndices[i][j].");
150+
"ColToRowMatchIndices[i][j] in i-th instance. The row number of "
151+
"i-th instance is saved in ColToRowMatchIndices[i][j].");
145152
AddOutput("ColToRowMatchDis",
146153
"(Tensor) A 2-D Tensor with shape [N, M] in float type. "
147154
"N is batch size. If ColToRowMatchIndices[i][j] is -1, "
148155
"ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed "
149-
"RowToColMatchIndices[i][j] = d, and the row offsets of each "
156+
"ColToRowMatchIndices[i][j] = d, and the row offsets of each "
150157
"instance are called LoD. Then "
151158
"ColToRowMatchDis[i][j] = DisMat[d+LoD[i]][j]");
152159
AddComment(R"DOC(
153160
This operator is a greedy bipartite matching algorithm, which is used to
154-
obtain the matching with the (greedy) maximum distance based on the input
155-
distance matrix. There are two outputs to save matched indices and distance.
156-
And this operator only calculate matched indices from column to row.
161+
obtain the matching with the maximum distance based on the input
162+
distance matrix. For input 2D matrix, the bipartite matching algorithm can
163+
find the matched column for each row, also can find the matched row for
164+
each column. And this operator only calculate matched indices from column
165+
to row. For each instance, the number of matched indices is the number of
166+
of columns of the input ditance matrix.
167+
168+
There are two outputs to save matched indices and distance.
157169
A simple description, this algothrim matched the best (maximum distance)
158170
row entity to the column entity and the matched indices are not duplicated
159171
in each row of ColToRowMatchIndices. If the column entity is not matched

0 commit comments

Comments
 (0)