Skip to content

Commit dc202c2

Browse files
seiriosPlusphlrain
authored andcommitted
fix/positive negative pair op (#16895)
* fix infershape in runtime * fix infershape in runtime test=develop * fix infershape in runtime
1 parent b20586c commit dc202c2

File tree

1 file changed

+24
-16
lines changed

1 file changed

+24
-16
lines changed

paddle/fluid/operators/positive_negative_pair_op.cc

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -61,23 +61,31 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
6161
auto query_dim = ctx->GetInputDim("QueryID");
6262
PADDLE_ENFORCE_EQ(score_dim.size(), 2, "Score should be a 2-D tensor.");
6363
PADDLE_ENFORCE_EQ(label_dim.size(), 2, "Label should be a 2-D tensor.");
64-
PADDLE_ENFORCE_EQ(
65-
label_dim[0], score_dim[0],
66-
"Tensor Score and Label should have the same height (batch size).");
67-
PADDLE_ENFORCE_EQ(label_dim[1], 1,
68-
"The width of Label should be 1, i.e. each item should "
69-
"have a scalar label.");
70-
PADDLE_ENFORCE(query_dim == label_dim,
71-
"QueryID should have the same shape as Label.");
72-
if (ctx->HasInput("Weight")) {
73-
PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim,
74-
"Weight should have the same shape as Label.");
64+
65+
if (ctx->IsRuntime() ||
66+
(score_dim[0] > 0 && label_dim[0] > 0 && query_dim[0] > 0)) {
67+
PADDLE_ENFORCE_EQ(
68+
label_dim[0], score_dim[0],
69+
"Tensor Score and Label should have the same height (batch size).");
70+
71+
PADDLE_ENFORCE_EQ(label_dim[1], 1,
72+
"The width of Label should be 1, i.e. each item should "
73+
"have a scalar label.");
74+
75+
PADDLE_ENFORCE(query_dim == label_dim,
76+
"QueryID should have the same shape as Label.");
77+
78+
if (ctx->HasInput("Weight")) {
79+
PADDLE_ENFORCE(ctx->GetInputDim("Weight") == label_dim,
80+
"Weight should have the same shape as Label.");
81+
}
82+
83+
int column = ctx->Attrs().Get<int>("column");
84+
auto depth = score_dim[1];
85+
PADDLE_ENFORCE(column < depth && column >= -depth,
86+
"Attribute column should be in the range of [-%l, %l)",
87+
depth, depth);
7588
}
76-
int column = ctx->Attrs().Get<int>("column");
77-
auto depth = score_dim[1];
78-
PADDLE_ENFORCE(column < depth && column >= -depth,
79-
"Attribute column should be in the range of [-%l, %l)",
80-
depth, depth);
8189

8290
ctx->SetOutputDim("PositivePair", scalar_dim);
8391
ctx->SetOutputDim("NegativePair", scalar_dim);

0 commit comments

Comments
 (0)