@@ -21,6 +21,8 @@ namespace operators {
21
21
using Tensor = framework::Tensor;
22
22
using LoDTensor = framework::LoDTensor;
23
23
24
+ constexpr char kEPS = 1e-6 ;
25
+
24
26
class BipartiteMatchOp : public framework ::OperatorWithKernel {
25
27
public:
26
28
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -41,34 +43,35 @@ template <typename T>
41
43
class BipartiteMatchKernel : public framework ::OpKernel<T> {
42
44
public:
43
45
// The match_indices must be initialized to -1 at first.
44
- // The match_dis must be initialized to 0 at first.
45
- void BipartiteMatch (const Tensor& dis, int * match_indices,
46
- T* match_dis) const {
47
- int64_t row = dis.dims ()[0 ];
48
- int64_t col = dis.dims ()[1 ];
49
- auto * dis_data = dis.data <T>();
46
+ // The match_dist must be initialized to 0 at first.
47
+ void BipartiteMatch (const Tensor& dist, int * match_indices,
48
+ T* match_dist) const {
49
+ PADDLE_ENFORCE_EQ (dist.dims ().size (), 2 , " The rank of dist must be 2." );
50
+ int64_t row = dist.dims ()[0 ];
51
+ int64_t col = dist.dims ()[1 ];
52
+ auto * dist_data = dist.data <T>();
50
53
std::vector<int > row_pool;
51
54
for (int i = 0 ; i < row; ++i) {
52
55
row_pool.push_back (i);
53
56
}
54
57
while (row_pool.size () > 0 ) {
55
58
int max_idx = -1 ;
56
59
int max_row_idx = -1 ;
57
- T max_dis = -1 ;
60
+ T max_dist = -1 ;
58
61
for (int64_t j = 0 ; j < col; ++j) {
59
62
if (match_indices[j] != -1 ) {
60
63
continue ;
61
64
}
62
65
for (int k = 0 ; k < row_pool.size (); ++k) {
63
66
int m = row_pool[k];
64
67
// distance is 0 between m-th row and j-th column
65
- if (dis_data [m * col + j] < 1e-6 ) {
68
+ if (dist_data [m * col + j] < kEPS ) {
66
69
continue ;
67
70
}
68
- if (dis_data [m * col + j] > max_dis ) {
71
+ if (dist_data [m * col + j] > max_dist ) {
69
72
max_idx = j;
70
73
max_row_idx = m;
71
- max_dis = dis_data [m * col + j];
74
+ max_dist = dist_data [m * col + j];
72
75
}
73
76
}
74
77
}
@@ -78,7 +81,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
78
81
} else {
79
82
PADDLE_ENFORCE_EQ (match_indices[max_idx], -1 );
80
83
match_indices[max_idx] = max_row_idx;
81
- match_dis [max_idx] = max_dis ;
84
+ match_dist [max_idx] = max_dist ;
82
85
// Erase the row index.
83
86
row_pool.erase (
84
87
std::find (row_pool.begin (), row_pool.end (), max_row_idx));
@@ -87,34 +90,38 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
87
90
}
88
91
89
92
void Compute (const framework::ExecutionContext& context) const override {
90
- auto * dis_mat = context.Input <LoDTensor>(" DisMat" );
93
+ auto * dist_mat = context.Input <LoDTensor>(" DisMat" );
91
94
auto * match_indices = context.Output <Tensor>(" ColToRowMatchIndices" );
92
- auto * match_dis = context.Output <Tensor>(" ColToRowMatchDis" );
95
+ auto * match_dist = context.Output <Tensor>(" ColToRowMatchDis" );
93
96
94
97
auto & dev_ctx = context.device_context <platform::CPUDeviceContext>();
95
98
96
- auto col = dis_mat ->dims ()[1 ];
99
+ auto col = dist_mat ->dims ()[1 ];
97
100
98
- int64_t n = dis_mat ->lod ().size () == 0
101
+ int64_t n = dist_mat ->lod ().size () == 0UL
99
102
? 1
100
- : static_cast <int64_t >(dis_mat->lod ().back ().size () - 1 );
103
+ : static_cast <int64_t >(dist_mat->lod ().back ().size () - 1 );
104
+ if (dist_mat->lod ().size ()) {
105
+ PADDLE_ENFORCE_EQ (dist_mat->lod ().size (), 1UL ,
106
+ " Only support 1 level of LoD." );
107
+ }
101
108
match_indices->mutable_data <int >({n, col}, context.GetPlace ());
102
- match_dis ->mutable_data <T>({n, col}, context.GetPlace ());
109
+ match_dist ->mutable_data <T>({n, col}, context.GetPlace ());
103
110
104
111
math::SetConstant<platform::CPUDeviceContext, int > iset;
105
112
iset (dev_ctx, match_indices, static_cast <int >(-1 ));
106
113
math::SetConstant<platform::CPUDeviceContext, T> tset;
107
- tset (dev_ctx, match_dis , static_cast <T>(0 ));
114
+ tset (dev_ctx, match_dist , static_cast <T>(0 ));
108
115
109
116
int * indices = match_indices->data <int >();
110
- T* dis = match_dis ->data <T>();
117
+ T* dist = match_dist ->data <T>();
111
118
if (n == 1 ) {
112
- BipartiteMatch (*dis_mat , indices, dis );
119
+ BipartiteMatch (*dist_mat , indices, dist );
113
120
} else {
114
- auto lod = dis_mat ->lod ().back ();
121
+ auto lod = dist_mat ->lod ().back ();
115
122
for (size_t i = 0 ; i < lod.size () - 1 ; ++i) {
116
- Tensor one_ins = dis_mat ->Slice (lod[i], lod[i + 1 ]);
117
- BipartiteMatch (one_ins, indices + i * col, dis + i * col);
123
+ Tensor one_ins = dist_mat ->Slice (lod[i], lod[i + 1 ]);
124
+ BipartiteMatch (one_ins, indices + i * col, dist + i * col);
118
125
}
119
126
}
120
127
}
@@ -131,7 +138,7 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
131
138
" represented by each row and each column. For example, assumed one "
132
139
" entity is A with shape [K], another entity is B with shape [M]. The "
133
140
" DisMat[i][j] is the distance between A[i] and B[j]. The bigger "
134
- " the distance is, the more similar the pairs are. Please note, "
141
+ " the distance is, the better macthing the pairs are. Please note, "
135
142
" This tensor can contain LoD information to represent a batch of "
136
143
" inputs. One instance of this batch can contain different numbers of "
137
144
" entities." );
@@ -140,20 +147,25 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
140
147
" N is the batch size. If ColToRowMatchIndices[i][j] is -1, it "
141
148
" means B[j] does not match any entity in i-th instance. "
142
149
" Otherwise, it means B[j] is matched to row "
143
- " RowToColMatchIndices [i][j] in i-th instance. The row number of "
144
- " i-th instance is saved in RowToColMatchIndices [i][j]." );
150
+ " ColToRowMatchIndices [i][j] in i-th instance. The row number of "
151
+ " i-th instance is saved in ColToRowMatchIndices [i][j]." );
145
152
AddOutput (" ColToRowMatchDis" ,
146
153
" (Tensor) A 2-D Tensor with shape [N, M] in float type. "
147
154
" N is batch size. If ColToRowMatchIndices[i][j] is -1, "
148
155
" ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed "
149
- " RowToColMatchIndices [i][j] = d, and the row offsets of each "
156
+ " ColToRowMatchIndices [i][j] = d, and the row offsets of each "
150
157
" instance are called LoD. Then "
151
158
" ColToRowMatchDis[i][j] = DisMat[d+LoD[i]][j]" );
152
159
AddComment (R"DOC(
153
160
This operator is a greedy bipartite matching algorithm, which is used to
154
- obtain the matching with the (greedy) maximum distance based on the input
155
- distance matrix. There are two outputs to save matched indices and distance.
156
- And this operator only calculate matched indices from column to row.
161
+ obtain the matching with the maximum distance based on the input
162
+ distance matrix. For input 2D matrix, the bipartite matching algorithm can
163
+ find the matched column for each row, also can find the matched row for
164
+ each column. And this operator only calculate matched indices from column
165
+ to row. For each instance, the number of matched indices is the number of
166
+ of columns of the input ditance matrix.
167
+
168
+ There are two outputs to save matched indices and distance.
157
169
A simple description, this algothrim matched the best (maximum distance)
158
170
row entity to the column entity and the matched indices are not duplicated
159
171
in each row of ColToRowMatchIndices. If the column entity is not matched
0 commit comments