@@ -34,12 +34,14 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel {
34
34
PADDLE_ENFORCE_EQ (x_dims.size (), 2UL , " Input(X)'s rank should be 2." );
35
35
PADDLE_ENFORCE_EQ (label_dims.size (), 2UL ,
36
36
" Input(Label)'s rank should be 2." );
37
- PADDLE_ENFORCE_EQ (x_dims[0 ], label_dims[0 ],
38
- " The 1st dimension of Input(X) and Input(Label) should "
39
- " be equal." );
40
- PADDLE_ENFORCE_EQ (label_dims[1 ], 1UL ,
41
- " The 2nd dimension of "
42
- " Input(Label) should be 1." );
37
+ if (ctx->IsRuntime ()) {
38
+ PADDLE_ENFORCE_EQ (x_dims[0 ], label_dims[0 ],
39
+ " The 1st dimension of Input(X) and Input(Label) should "
40
+ " be equal." );
41
+ PADDLE_ENFORCE_EQ (label_dims[1 ], 1UL ,
42
+ " The 2nd dimension of "
43
+ " Input(Label) should be 1." );
44
+ }
43
45
ctx->SetOutputDim (" Y" , {x_dims[0 ], 1 });
44
46
ctx->ShareLoD (" X" , /* ->*/ " Y" );
45
47
}
@@ -74,17 +76,20 @@ class TeacherStudentSigmoidLossGradientOp
74
76
PADDLE_ENFORCE_EQ (x_dims.size (), 2 , " Input(X)'s rank should be 2." );
75
77
PADDLE_ENFORCE_EQ (dy_dims.size (), 2 , " Input(Y@Grad)'s rank should be 2." );
76
78
PADDLE_ENFORCE_EQ (label_dims.size (), 2 , " Input(Label)'s rank should be 2." );
77
- PADDLE_ENFORCE_EQ (x_dims[0 ], label_dims[0 ],
78
- " The 1st dimension of Input(X) and Input(Label) should "
79
- " be equal." );
80
- PADDLE_ENFORCE_EQ (x_dims[0 ], dy_dims[0 ],
81
- " The 1st dimension of Input(X) and Input(Y@Grad) should "
82
- " be equal." );
83
- PADDLE_ENFORCE_EQ (dy_dims[1 ], 1 ,
84
- " The 2nd dimension of Input(Y@Grad) should be 1." );
85
- PADDLE_ENFORCE_EQ (label_dims[1 ], 1 ,
86
- " When Attr(soft_label) == false, the 2nd dimension of "
87
- " Input(Label) should be 1." );
79
+ if (ctx->IsRuntime ()) {
80
+ PADDLE_ENFORCE_EQ (x_dims[0 ], label_dims[0 ],
81
+ " The 1st dimension of Input(X) and Input(Label) should "
82
+ " be equal." );
83
+ PADDLE_ENFORCE_EQ (
84
+ x_dims[0 ], dy_dims[0 ],
85
+ " The 1st dimension of Input(X) and Input(Y@Grad) should "
86
+ " be equal." );
87
+ PADDLE_ENFORCE_EQ (dy_dims[1 ], 1 ,
88
+ " The 2nd dimension of Input(Y@Grad) should be 1." );
89
+ PADDLE_ENFORCE_EQ (label_dims[1 ], 1 ,
90
+ " When Attr(soft_label) == false, the 2nd dimension of "
91
+ " Input(Label) should be 1." );
92
+ }
88
93
ctx->SetOutputDim (framework::GradVarName (" X" ), x_dims);
89
94
ctx->ShareLoD (" X" , framework::GradVarName (" X" ));
90
95
}
0 commit comments