@@ -94,6 +94,38 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
94
94
}
95
95
}
96
96
97
+ void ArgMaxMatch (const Tensor& dist, int * match_indices, T* match_dist,
98
+ T overlap_threshold) const {
99
+ constexpr T kEPS = static_cast <T>(1e-6 );
100
+ int64_t row = dist.dims ()[0 ];
101
+ int64_t col = dist.dims ()[1 ];
102
+ auto * dist_data = dist.data <T>();
103
+ for (int64_t j = 0 ; j < col; ++j) {
104
+ if (match_indices[j] != -1 ) {
105
+ // the j-th column has been matched to one entity.
106
+ continue ;
107
+ }
108
+ int max_row_idx = -1 ;
109
+ T max_dist = -1 ;
110
+ for (int i = 0 ; i < row; ++i) {
111
+ T dist = dist_data[i * col + j];
112
+ if (dist < kEPS ) {
113
+ // distance is 0 between m-th row and j-th column
114
+ continue ;
115
+ }
116
+ if (dist >= overlap_threshold && dist > max_dist) {
117
+ max_row_idx = i;
118
+ max_dist = dist;
119
+ }
120
+ }
121
+ if (max_row_idx != -1 ) {
122
+ PADDLE_ENFORCE_EQ (match_indices[j], -1 );
123
+ match_indices[j] = max_row_idx;
124
+ match_dist[j] = max_dist;
125
+ }
126
+ }
127
+ }
128
+
97
129
void Compute (const framework::ExecutionContext& context) const override {
98
130
auto * dist_mat = context.Input <LoDTensor>(" DistMat" );
99
131
auto * match_indices = context.Output <Tensor>(" ColToRowMatchIndices" );
@@ -120,13 +152,21 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
120
152
121
153
int * indices = match_indices->data <int >();
122
154
T* dist = match_dist->data <T>();
155
+ auto type = context.Attr <std::string>(" match_type" );
156
+ auto threshold = context.Attr <float >(" dist_threshold" );
123
157
if (n == 1 ) {
124
158
BipartiteMatch (*dist_mat, indices, dist);
159
+ if (type == " per_prediction" ) {
160
+ ArgMaxMatch (*dist_mat, indices, dist, threshold);
161
+ }
125
162
} else {
126
163
auto lod = dist_mat->lod ().back ();
127
164
for (size_t i = 0 ; i < lod.size () - 1 ; ++i) {
128
165
Tensor one_ins = dist_mat->Slice (lod[i], lod[i + 1 ]);
129
166
BipartiteMatch (one_ins, indices + i * col, dist + i * col);
167
+ if (type == " per_prediction" ) {
168
+ ArgMaxMatch (one_ins, indices + i * col, dist + i * col, threshold);
169
+ }
130
170
}
131
171
}
132
172
}
@@ -147,6 +187,19 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
147
187
" This tensor can contain LoD information to represent a batch of "
148
188
" inputs. One instance of this batch can contain different numbers of "
149
189
" entities." );
190
+ AddAttr<std::string>(
191
+ " match_type" ,
192
+ " (string, defalut: per_prediction) "
193
+ " The type of matching method, should be 'bipartite' or "
194
+ " 'per_prediction', 'bipartite' by defalut." )
195
+ .SetDefault (" bipartite" )
196
+ .InEnum ({" bipartite" , " per_prediction" });
197
+ AddAttr<float >(
198
+ " dist_threshold" ,
199
+ " (float, defalut: 0.5) "
200
+ " If `match_type` is 'per_prediction', this threshold is to determine "
201
+ " the extra matching bboxes based on the maximum distance." )
202
+ .SetDefault (0.5 );
150
203
AddOutput (" ColToRowMatchIndices" ,
151
204
" (Tensor) A 2-D Tensor with shape [N, M] in int type. "
152
205
" N is the batch size. If ColToRowMatchIndices[i][j] is -1, it "
@@ -168,10 +221,10 @@ distance matrix. For input 2D matrix, the bipartite matching algorithm can
168
221
find the matched column for each row, also can find the matched row for
169
222
each column. And this operator only calculates matched indices from column
170
223
to row. For each instance, the number of matched indices is the number of
171
- of columns of the input ditance matrix.
224
+ columns of the input distance matrix.
172
225
173
226
There are two outputs to save matched indices and distance.
174
- A simple description, this algothrim matched the best (maximum distance)
227
+ A simple description, this algorithm matched the best (maximum distance)
175
228
row entity to the column entity and the matched indices are not duplicated
176
229
in each row of ColToRowMatchIndices. If the column entity is not matched
177
230
to any row entity, set -1 in ColToRowMatchIndices.
0 commit comments