@@ -28,11 +28,11 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
28
28
using framework::OperatorWithKernel::OperatorWithKernel;
29
29
30
30
void InferShape (framework::InferShapeContext* ctx) const override {
31
- PADDLE_ENFORCE (ctx->HasInput (" DisMat " ),
32
- " Input(DisMat ) of BipartiteMatch should not be null." );
31
+ PADDLE_ENFORCE (ctx->HasInput (" DistMat " ),
32
+ " Input(DistMat ) of BipartiteMatch should not be null." );
33
33
34
- auto dims = ctx->GetInputDim (" DisMat " );
35
- PADDLE_ENFORCE_EQ (dims.size (), 2 , " The rank of Input(DisMat ) must be 2." );
34
+ auto dims = ctx->GetInputDim (" DistMat " );
35
+ PADDLE_ENFORCE_EQ (dims.size (), 2 , " The rank of Input(DistMat ) must be 2." );
36
36
37
37
ctx->SetOutputDim (" ColToRowMatchIndices" , dims);
38
38
ctx->SetOutputDim (" ColToRowMatchDis" , dims);
@@ -90,7 +90,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
90
90
}
91
91
92
92
void Compute (const framework::ExecutionContext& context) const override {
93
- auto * dist_mat = context.Input <LoDTensor>(" DisMat " );
93
+ auto * dist_mat = context.Input <LoDTensor>(" DistMat " );
94
94
auto * match_indices = context.Output <Tensor>(" ColToRowMatchIndices" );
95
95
auto * match_dist = context.Output <Tensor>(" ColToRowMatchDis" );
96
96
@@ -132,12 +132,12 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
132
132
BipartiteMatchOpMaker (OpProto* proto, OpAttrChecker* op_checker)
133
133
: OpProtoAndCheckerMaker(proto, op_checker) {
134
134
AddInput (
135
- " DisMat " ,
135
+ " DistMat " ,
136
136
" (LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
137
137
" [K, M]. It is pair-wise distance matrix between the entities "
138
138
" represented by each row and each column. For example, assumed one "
139
139
" entity is A with shape [K], another entity is B with shape [M]. The "
140
- " DisMat [i][j] is the distance between A[i] and B[j]. The bigger "
140
+ " DistMat [i][j] is the distance between A[i] and B[j]. The bigger "
141
141
" the distance is, the better macthing the pairs are. Please note, "
142
142
" This tensor can contain LoD information to represent a batch of "
143
143
" inputs. One instance of this batch can contain different numbers of "
@@ -155,7 +155,7 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
155
155
" ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed "
156
156
" ColToRowMatchIndices[i][j] = d, and the row offsets of each "
157
157
" instance are called LoD. Then "
158
- " ColToRowMatchDis[i][j] = DisMat [d+LoD[i]][j]" );
158
+ " ColToRowMatchDis[i][j] = DistMat [d+LoD[i]][j]" );
159
159
AddComment (R"DOC(
160
160
This operator is a greedy bipartite matching algorithm, which is used to
161
161
obtain the matching with the maximum distance based on the input
@@ -171,7 +171,7 @@ row entity to the column entity and the matched indices are not duplicated
171
171
in each row of ColToRowMatchIndices. If the column entity is not matched
172
172
any row entity, set -1 in ColToRowMatchIndices.
173
173
174
- Please note that the input DisMat can be LoDTensor (with LoD) or Tensor.
174
+ Please note that the input DistMat can be LoDTensor (with LoD) or Tensor.
175
175
If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size.
176
176
If Tensor, the height of ColToRowMatchIndices is 1.
177
177
0 commit comments