@@ -61,23 +61,31 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
61
61
auto query_dim = ctx->GetInputDim (" QueryID" );
62
62
PADDLE_ENFORCE_EQ (score_dim.size (), 2 , " Score should be a 2-D tensor." );
63
63
PADDLE_ENFORCE_EQ (label_dim.size (), 2 , " Label should be a 2-D tensor." );
64
- PADDLE_ENFORCE_EQ (
65
- label_dim[0 ], score_dim[0 ],
66
- " Tensor Score and Label should have the same height (batch size)." );
67
- PADDLE_ENFORCE_EQ (label_dim[1 ], 1 ,
68
- " The width of Label should be 1, i.e. each item should "
69
- " have a scalar label." );
70
- PADDLE_ENFORCE (query_dim == label_dim,
71
- " QueryID should have the same shape as Label." );
72
- if (ctx->HasInput (" Weight" )) {
73
- PADDLE_ENFORCE (ctx->GetInputDim (" Weight" ) == label_dim,
74
- " Weight should have the same shape as Label." );
64
+
65
+ if (ctx->IsRuntime () ||
66
+ (score_dim[0 ] > 0 && label_dim[0 ] > 0 && query_dim[0 ] > 0 )) {
67
+ PADDLE_ENFORCE_EQ (
68
+ label_dim[0 ], score_dim[0 ],
69
+ " Tensor Score and Label should have the same height (batch size)." );
70
+
71
+ PADDLE_ENFORCE_EQ (label_dim[1 ], 1 ,
72
+ " The width of Label should be 1, i.e. each item should "
73
+ " have a scalar label." );
74
+
75
+ PADDLE_ENFORCE (query_dim == label_dim,
76
+ " QueryID should have the same shape as Label." );
77
+
78
+ if (ctx->HasInput (" Weight" )) {
79
+ PADDLE_ENFORCE (ctx->GetInputDim (" Weight" ) == label_dim,
80
+ " Weight should have the same shape as Label." );
81
+ }
82
+
83
+ int column = ctx->Attrs ().Get <int >(" column" );
84
+ auto depth = score_dim[1 ];
85
+ PADDLE_ENFORCE (column < depth && column >= -depth,
86
+ " Attribute column should be in the range of [-%l, %l)" ,
87
+ depth, depth);
75
88
}
76
- int column = ctx->Attrs ().Get <int >(" column" );
77
- auto depth = score_dim[1 ];
78
- PADDLE_ENFORCE (column < depth && column >= -depth,
79
- " Attribute column should be in the range of [-%l, %l)" ,
80
- depth, depth);
81
89
82
90
ctx->SetOutputDim (" PositivePair" , scalar_dim);
83
91
ctx->SetOutputDim (" NegativePair" , scalar_dim);
0 commit comments