4
4
you may not use this file except in compliance with the License.
5
5
You may obtain a copy of the License at
6
6
7
- http://www.apache.org/licenses/LICENSE-2.0
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
8
9
9
Unless required by applicable law or agreed to in writing, software
10
10
distributed under the License is distributed on an "AS IS" BASIS,
@@ -35,9 +35,10 @@ class RankLossOp : public framework::OperatorWithKernel {
35
35
auto right_dims = ctx->GetInputDim (" Right" );
36
36
37
37
PADDLE_ENFORCE ((label_dims == left_dims) && (left_dims == right_dims),
38
- " All inputs must have the same size" );
39
- PADDLE_ENFORCE ((label_dims.size () == 2 ) && (label_dims[1 ] == 1 ),
40
- " All inputs must be row vector with size batch_size x 1." );
38
+ " All inputs must have the same size." );
39
+ PADDLE_ENFORCE (
40
+ (label_dims.size () == 2 ) && (label_dims[1 ] == 1 ),
41
+ " All inputs must be 2-D tensors with shape [batch_size x 1]." );
41
42
ctx->SetOutputDim (" Out" , label_dims);
42
43
}
43
44
};
@@ -48,10 +49,17 @@ class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
48
49
framework::OpAttrChecker *op_checker)
49
50
: OpProtoAndCheckerMaker(proto, op_checker) {
50
51
AddInput (" Label" ,
51
- " The label indicating A ranked higher than B or not, row vector." );
52
- AddInput (" Left" , " The output of RankNet for doc A, vector." );
53
- AddInput (" Right" , " The output of RankNet for doc B, vetor." );
54
- AddOutput (" Out" , " The output loss of RankLoss operator, vector." );
52
+ " (2-D Tensor with shape [batch_size x 1]) "
53
+ " The label indicating A ranked higher than B or not." );
54
+ AddInput (" Left" ,
55
+ " (2-D Tensor with shape [batch_size x 1]) "
56
+ " The output of RankNet for doc A." );
57
+ AddInput (" Right" ,
58
+ " (2-D Tensor with shape [batch_size x 1]) "
59
+ " The output of RankNet for doc B." );
60
+ AddOutput (" Out" ,
61
+ " (2-D Tensor with shape [batch_size x 1]) "
62
+ " The output loss of RankLoss operator." );
55
63
AddComment (R"DOC(
56
64
RankLoss Operator.
57
65
@@ -65,16 +73,17 @@ P = {0, 1} or {0, 0.5, 1}, where 0.5 means no information about the rank of
65
73
the input pair.
66
74
67
75
The RankLoss operator takes three inputs: Left (o_i), Right (o_j) and Label
68
- (P_{i,j}), which represent the output of RankNet for the two docs and the label,
69
- respectively, and yields the rank loss C_{i,j} using the following equation:
76
+ (P_{i,j}), which represent the output score of RankNet for the two docs and
77
+ the label respectively, and yields the rank loss C_{i,j} using the following
78
+ equation:
70
79
71
- \f $$
72
- C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + log(1 + e^{o_{i,j}}) \\
80
+ $$
81
+ C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + \ log(1 + e^{o_{i,j}}) \\
73
82
o_{i,j} = o_i - o_j \\
74
83
\tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \}
75
- \f $$
84
+ $$
76
85
77
- The operator can take inputs of one sample or in batch .
86
+ The operator can take batch inputs with size batch_size (batch_size >= 1) .
78
87
79
88
)DOC" );
80
89
}
0 commit comments