@@ -40,17 +40,27 @@ class CosSimOp : public framework::OperatorWithKernel {
40
40
auto x_dims = ctx->GetInputDim (" X" );
41
41
auto y_dims = ctx->GetInputDim (" Y" );
42
42
43
- PADDLE_ENFORCE_EQ (x_dims.size (), y_dims.size (),
44
- " Ranks of Input(X) and Input(Y) must be equal." );
45
- PADDLE_ENFORCE_GE (x_dims.size (), 2 ,
46
- " Rank of Input(X) must not be less than 2." );
47
- PADDLE_ENFORCE_EQ (framework::slice_ddim (x_dims, 1 , x_dims.size ()),
48
- framework::slice_ddim (y_dims, 1 , y_dims.size ()),
49
- " All dimensions except the 1st of Input(X) and Input(Y) "
50
- " must be equal." );
51
- PADDLE_ENFORCE (x_dims[0 ] == y_dims[0 ] || y_dims[0 ] == 1 ,
52
- " The 1st dimension of Input(Y) must be equal to Input(X) or"
53
- " just 1 (which will be broadcasted to match Input(X))." );
43
+ bool check = true ;
44
+ if ((!ctx->IsRuntime ()) &&
45
+ (framework::product (x_dims) <= 0 || framework::product (y_dims) <= 0 )) {
46
+ check = false ;
47
+ }
48
+
49
+ if (check) {
50
+ PADDLE_ENFORCE_EQ (x_dims.size (), y_dims.size (),
51
+ " Ranks of Input(X) and Input(Y) must be equal." );
52
+ PADDLE_ENFORCE_GE (x_dims.size (), 2 ,
53
+ " Rank of Input(X) must not be less than 2." );
54
+ PADDLE_ENFORCE_EQ (
55
+ framework::slice_ddim (x_dims, 1 , x_dims.size ()),
56
+ framework::slice_ddim (y_dims, 1 , y_dims.size ()),
57
+ " All dimensions except the 1st of Input(X) and Input(Y) "
58
+ " must be equal." );
59
+ PADDLE_ENFORCE (
60
+ x_dims[0 ] == y_dims[0 ] || y_dims[0 ] == 1 ,
61
+ " The 1st dimension of Input(Y) must be equal to Input(X) or"
62
+ " just 1 (which will be broadcasted to match Input(X))." );
63
+ }
54
64
55
65
// resize tensor
56
66
ctx->SetOutputDim (" Out" , {x_dims[0 ], 1 });
0 commit comments