@@ -31,15 +31,22 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
31
31
32
32
auto x_dims = ctx->GetInputDim (" X" );
33
33
auto labels_dims = ctx->GetInputDim (" Label" );
34
- PADDLE_ENFORCE_EQ (x_dims.size (), 2 , " Input(X)'s rank should be 2." );
35
- PADDLE_ENFORCE_EQ (labels_dims.size (), 2 ,
36
- " Input(Label)'s rank should be 2." );
37
- PADDLE_ENFORCE_EQ (x_dims[0 ], labels_dims[0 ],
38
- " The 1st dimension of Input(X) and Input(Label) should "
39
- " be equal." );
40
- PADDLE_ENFORCE_EQ (x_dims[1 ], labels_dims[1 ],
41
- " The 2nd dimension of Input(X) and Input(Label) should "
42
- " be equal." );
34
+
35
+ int rank = x_dims.size ();
36
+ PADDLE_ENFORCE_EQ (rank, labels_dims.size (),
37
+ " Input(X) and Input(Label) shall have the same rank." );
38
+ bool check = true ;
39
+ if ((!ctx->IsRuntime ()) && (framework::product (x_dims) <= 0 ||
40
+ framework::product (labels_dims) <= 0 )) {
41
+ check = false ;
42
+ }
43
+
44
+ if (check) {
45
+ PADDLE_ENFORCE_EQ (framework::slice_ddim (x_dims, 0 , rank),
46
+ framework::slice_ddim (labels_dims, 0 , rank),
47
+ " Input(X) and Input(Label) shall have the same shape "
48
+ " except the last dimension." );
49
+ }
43
50
44
51
ctx->ShareDim (" X" , /* ->*/ " Out" );
45
52
ctx->ShareLoD (" X" , /* ->*/ " Out" );
@@ -62,23 +69,24 @@ class SigmoidCrossEntropyWithLogitsGradOp
62
69
auto x_dims = ctx->GetInputDim (" X" );
63
70
auto labels_dims = ctx->GetInputDim (" Label" );
64
71
auto dout_dims = ctx->GetInputDim (framework::GradVarName (" Out" ));
65
- PADDLE_ENFORCE_EQ (x_dims.size (), 2 , " Input(X)'s rank should be 2." );
66
- PADDLE_ENFORCE_EQ (labels_dims.size (), 2 ,
67
- " Input(Label)'s rank should be 2." );
68
- PADDLE_ENFORCE_EQ (dout_dims.size (), 2 ,
69
- " Input(Out@Grad)'s rank should be 2." );
70
- PADDLE_ENFORCE_EQ (x_dims[0 ], labels_dims[0 ],
71
- " The 1st dimension of Input(X) and Input(Label) should "
72
- " be equal." );
73
- PADDLE_ENFORCE_EQ (x_dims[1 ], labels_dims[1 ],
74
- " The 2nd dimension of Input(X) and Input(Label) should "
75
- " be equal." );
76
- PADDLE_ENFORCE_EQ (x_dims[0 ], dout_dims[0 ],
77
- " The 1st dimension of Input(X) and Input(Out@Grad) "
78
- " should be equal." );
79
- PADDLE_ENFORCE_EQ (x_dims[1 ], dout_dims[1 ],
80
- " The 2nd dimension of Input(X) and Input(Out@Grad) "
81
- " should be equal." );
72
+
73
+ int rank = x_dims.size ();
74
+ bool check = true ;
75
+ if ((!ctx->IsRuntime ()) && (framework::product (x_dims) <= 0 ||
76
+ framework::product (labels_dims) <= 0 )) {
77
+ check = false ;
78
+ }
79
+
80
+ if (check) {
81
+ PADDLE_ENFORCE_EQ (framework::slice_ddim (x_dims, 0 , rank),
82
+ framework::slice_ddim (labels_dims, 0 , rank),
83
+ " Input(X) and Input(Label) shall have the same shape." );
84
+
85
+ PADDLE_ENFORCE_EQ (
86
+ framework::slice_ddim (x_dims, 0 , rank),
87
+ framework::slice_ddim (dout_dims, 0 , rank),
88
+ " Input(X) and Input(Out@Grad) shall have the same shape." );
89
+ }
82
90
83
91
ctx->SetOutputDim (framework::GradVarName (" X" ), x_dims);
84
92
}
0 commit comments