Skip to content

Commit 530df1b

Browse files
committed
Fix the naming.
1 parent 0790868 commit 530df1b

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

paddle/operators/bipartite_match_op.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
2828
using framework::OperatorWithKernel::OperatorWithKernel;
2929

3030
void InferShape(framework::InferShapeContext* ctx) const override {
31-
PADDLE_ENFORCE(ctx->HasInput("DisMat"),
32-
"Input(DisMat) of BipartiteMatch should not be null.");
31+
PADDLE_ENFORCE(ctx->HasInput("DistMat"),
32+
"Input(DistMat) of BipartiteMatch should not be null.");
3333

34-
auto dims = ctx->GetInputDim("DisMat");
35-
PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DisMat) must be 2.");
34+
auto dims = ctx->GetInputDim("DistMat");
35+
PADDLE_ENFORCE_EQ(dims.size(), 2, "The rank of Input(DistMat) must be 2.");
3636

3737
ctx->SetOutputDim("ColToRowMatchIndices", dims);
3838
ctx->SetOutputDim("ColToRowMatchDis", dims);
@@ -90,7 +90,7 @@ class BipartiteMatchKernel : public framework::OpKernel<T> {
9090
}
9191

9292
void Compute(const framework::ExecutionContext& context) const override {
93-
auto* dist_mat = context.Input<LoDTensor>("DisMat");
93+
auto* dist_mat = context.Input<LoDTensor>("DistMat");
9494
auto* match_indices = context.Output<Tensor>("ColToRowMatchIndices");
9595
auto* match_dist = context.Output<Tensor>("ColToRowMatchDis");
9696

@@ -132,12 +132,12 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
132132
BipartiteMatchOpMaker(OpProto* proto, OpAttrChecker* op_checker)
133133
: OpProtoAndCheckerMaker(proto, op_checker) {
134134
AddInput(
135-
"DisMat",
135+
"DistMat",
136136
"(LoDTensor or Tensor) this input is a 2-D LoDTensor with shape "
137137
"[K, M]. It is pair-wise distance matrix between the entities "
138138
"represented by each row and each column. For example, assumed one "
139139
"entity is A with shape [K], another entity is B with shape [M]. The "
140-
"DisMat[i][j] is the distance between A[i] and B[j]. The bigger "
140+
"DistMat[i][j] is the distance between A[i] and B[j]. The bigger "
141141
"the distance is, the better macthing the pairs are. Please note, "
142142
"This tensor can contain LoD information to represent a batch of "
143143
"inputs. One instance of this batch can contain different numbers of "
@@ -155,7 +155,7 @@ class BipartiteMatchOpMaker : public framework::OpProtoAndCheckerMaker {
155155
"ColToRowMatchDis[i][j] is also -1.0. Otherwise, assumed "
156156
"ColToRowMatchIndices[i][j] = d, and the row offsets of each "
157157
"instance are called LoD. Then "
158-
"ColToRowMatchDis[i][j] = DisMat[d+LoD[i]][j]");
158+
"ColToRowMatchDis[i][j] = DistMat[d+LoD[i]][j]");
159159
AddComment(R"DOC(
160160
This operator is a greedy bipartite matching algorithm, which is used to
161161
obtain the matching with the maximum distance based on the input
@@ -171,7 +171,7 @@ row entity to the column entity and the matched indices are not duplicated
171171
in each row of ColToRowMatchIndices. If the column entity is not matched
172172
any row entity, set -1 in ColToRowMatchIndices.
173173
174-
Please note that the input DisMat can be LoDTensor (with LoD) or Tensor.
174+
Please note that the input DistMat can be LoDTensor (with LoD) or Tensor.
175175
If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size.
176176
If Tensor, the height of ColToRowMatchIndices is 1.
177177

0 commit comments

Comments
 (0)