Skip to content

Commit 2055c16

Browse files
committed
Merge pull request #16890 from colourful-tree/dev
fix teacher_student op infer
1 parent ece7451 commit 2055c16

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

paddle/fluid/operators/teacher_student_sigmoid_loss_op.cc

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,14 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel {
3434
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2.");
3535
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
3636
"Input(Label)'s rank should be 2.");
37-
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
38-
"The 1st dimension of Input(X) and Input(Label) should "
39-
"be equal.");
40-
PADDLE_ENFORCE_EQ(label_dims[1], 1UL,
41-
"The 2nd dimension of "
42-
"Input(Label) should be 1.");
37+
if (ctx->IsRuntime()) {
38+
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
39+
"The 1st dimension of Input(X) and Input(Label) should "
40+
"be equal.");
41+
PADDLE_ENFORCE_EQ(label_dims[1], 1UL,
42+
"The 2nd dimension of "
43+
"Input(Label) should be 1.");
44+
}
4345
ctx->SetOutputDim("Y", {x_dims[0], 1});
4446
ctx->ShareLoD("X", /*->*/ "Y");
4547
}
@@ -74,17 +76,20 @@ class TeacherStudentSigmoidLossGradientOp
7476
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
7577
PADDLE_ENFORCE_EQ(dy_dims.size(), 2, "Input(Y@Grad)'s rank should be 2.");
7678
PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
77-
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
78-
"The 1st dimension of Input(X) and Input(Label) should "
79-
"be equal.");
80-
PADDLE_ENFORCE_EQ(x_dims[0], dy_dims[0],
81-
"The 1st dimension of Input(X) and Input(Y@Grad) should "
82-
"be equal.");
83-
PADDLE_ENFORCE_EQ(dy_dims[1], 1,
84-
"The 2nd dimension of Input(Y@Grad) should be 1.");
85-
PADDLE_ENFORCE_EQ(label_dims[1], 1,
86-
"When Attr(soft_label) == false, the 2nd dimension of "
87-
"Input(Label) should be 1.");
79+
if (ctx->IsRuntime()) {
80+
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
81+
"The 1st dimension of Input(X) and Input(Label) should "
82+
"be equal.");
83+
PADDLE_ENFORCE_EQ(
84+
x_dims[0], dy_dims[0],
85+
"The 1st dimension of Input(X) and Input(Y@Grad) should "
86+
"be equal.");
87+
PADDLE_ENFORCE_EQ(dy_dims[1], 1,
88+
"The 2nd dimension of Input(Y@Grad) should be 1.");
89+
PADDLE_ENFORCE_EQ(label_dims[1], 1,
90+
"When Attr(soft_label) == false, the 2nd dimension of "
91+
"Input(Label) should be 1.");
92+
}
8893
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
8994
ctx->ShareLoD("X", framework::GradVarName("X"));
9095
}

0 commit comments

Comments
 (0)