|
| 1 | +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include "paddle/operators/linear_chain_crf_op.h" |
| 16 | + |
| 17 | +namespace paddle { |
| 18 | +namespace operators { |
| 19 | + |
| 20 | +class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker { |
| 21 | + public: |
| 22 | + LinearChainCrfOpMaker(framework::OpProto* proto, |
| 23 | + framework::OpAttrChecker* op_checker) |
| 24 | + : OpProtoAndCheckerMaker(proto, op_checker) { |
| 25 | + AddInput( |
| 26 | + "Emission", |
| 27 | + "(LoDTensor, default: LoDTensor<float>). " |
| 28 | + "The unscaled emission weight matrix for the linear chain CRF. " |
| 29 | + "This input is a LoDTensor with shape [N x D] where N is the total " |
| 30 | + "element number of all input squences in a mini-batch, " |
| 31 | + "and D is the total tag number."); |
| 32 | + AddInput( |
| 33 | + "Transition", |
| 34 | + "(Tensor, default: Tensor<float>). A Tensor with shape [(D + 2) x D]. " |
| 35 | + "The learnable parameter for linear_chain_crf operator. " |
| 36 | + "See more details in the operator's comments."); |
| 37 | + AddInput( |
| 38 | + "Label", |
| 39 | + "(LoDTensor, default: LoDTensor<int>). The ground truth which is a 2-D " |
| 40 | + "LoDTensor with shape [N x 1], where N is the total element number in " |
| 41 | + "a mini-batch."); |
| 42 | + AddOutput( |
| 43 | + "Alpha", |
| 44 | + "Tensor, default: Tensor<float>. The forward vectors for the entire " |
| 45 | + "batch. A two dimensional tensor with shape [N x D], " |
| 46 | + "denoted as \f$\alpha\f$. \f$\alpha$\f is a memo table used to " |
| 47 | + "calculate the normalization factor in CRF. \f$\alpha[k, v]$\f stores " |
| 48 | + "the unnormalized probabilites of all possible unfinished sequences of " |
| 49 | + "tags that end at position \f$k$\f with tag \f$v$\f. For each \f$k$\f, " |
| 50 | + "\f$\alpha[k, v]$\f is a vector of length \f$D$\f with a component for " |
| 51 | + "each tag value \f$v$\f. This vector is called a forward vecotr and " |
| 52 | + "will also be used in backward computations.") |
| 53 | + .AsIntermediate(); |
| 54 | + AddOutput( |
| 55 | + "LogLikelihood", |
| 56 | + "(Tensor, default: Tensor<float>). The logarithm of the conditional " |
| 57 | + "likelihood of each training sample in a mini-batch. This is a 2-D " |
| 58 | + "tensor with shape [S x 1], where S is the sequence number in a " |
| 59 | + "mini-batch. " |
| 60 | + "Note: S is equal to the sequence number in a mini-batch. The output " |
| 61 | + "is no longer a LoDTensor."); |
| 62 | + AddComment(R"DOC( |
| 63 | +Conditional Random Field defines an undirected probabilistic graph with nodes |
| 64 | +denoting random variables and edges denoting dependencies between these |
| 65 | +variables. CRF learns the conditional probability \f$P(Y|X)\f$, where |
| 66 | +\f$X = (x_1, x_2, ... , x_n)\f$ are structured inputs and |
| 67 | +\f$Y = (y_1, y_2, ... , y_n)\f$ are labels for the inputs. |
| 68 | +
|
| 69 | +Linear chain CRF is a special case of CRF that is useful for sequence labeling |
| 70 | +task. Sequence labeling tasks do not assume a lot of conditional |
| 71 | +independences among inputs. They only concern about the input and the output |
| 72 | +being linear sequences. Thus, the graph model of CRF is a simple chain or |
| 73 | +a line, which results in a linear chain CRF. |
| 74 | +
|
| 75 | +This operator implements the Forward-Backward algorithm for linear chain CRF. |
| 76 | +Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference. |
| 77 | +
|
| 78 | +Equation: |
| 79 | +
|
| 80 | +- Denote the first input of this operator (Emission) as \f$x\f$ here. |
| 81 | +- The first D values of the second input (Transition) of this operator are for |
| 82 | +starting weights, denoted as \f$a\f$ here. |
| 83 | +- The next D values of the second input (Transition) of this operator are for |
| 84 | +ending weights, denoted as \f$b\f$ here. |
| 85 | +- The remaning values of the second input (Transition) are for transition |
| 86 | +weights, denoted as \f$w\f$ here. |
| 87 | +- Denote the third input of this operator (Label) as \f$s\f$ here. |
| 88 | +
|
| 89 | +The probability of a sequence \f$s\f$ of length \f$L\f$ is defined as: |
| 90 | +\f$P(s) = (1/Z) exp(a_{s_1} + b_{s_L} |
| 91 | + + \sum_{l=1}^L x_{s_l} |
| 92 | + + \sum_{l=2}^L w_{s_{l-1},s_l})\f$ |
| 93 | +where \f$Z\f$ is a normalization value so that the sum of \f$P(s)\f$ over |
| 94 | +all possible sequences is \f$1\f$, and \f$x\f$ is the emission feature weight |
| 95 | +to the linear chain CRF. |
| 96 | +
|
| 97 | +Finaly, the linear chain CRF operator outputs the logarithm of the conditional |
| 98 | +likelihood of each training sample in a mini-batch. |
| 99 | +
|
| 100 | +NOTE: |
| 101 | +1. The feature function for a CRF is made up of the emission features and the |
| 102 | +transition features. The emission feature weights are NOT computed in |
| 103 | +this operator. They MUST be computed first before this operator is called. |
| 104 | +
|
| 105 | +2. Because this operator performs globally normaliztion over all possible |
| 106 | +sequences internally, it expects UNSCALED emission feature weights. |
| 107 | +Please do not call this op with the emission feature being output of any |
| 108 | +nonlinear activation. |
| 109 | +
|
| 110 | +3. The 2nd dimension of the first input of this operator (Emission) MUST be |
| 111 | +equal to the tag number. |
| 112 | +
|
| 113 | +)DOC"); |
| 114 | + } |
| 115 | +}; |
| 116 | + |
| 117 | +class LinearChainCrfOp : public framework::OperatorWithKernel { |
| 118 | + public: |
| 119 | + using framework::OperatorWithKernel::OperatorWithKernel; |
| 120 | + |
| 121 | + protected: |
| 122 | + void InferShape(framework::InferShapeContextBase* ctx) const override {} |
| 123 | +}; |
| 124 | + |
| 125 | +class LinearChainCrfGradOp : public framework::OperatorWithKernel { |
| 126 | + public: |
| 127 | + using framework::OperatorWithKernel::OperatorWithKernel; |
| 128 | + |
| 129 | + protected: |
| 130 | + void InferShape(framework::InferShapeContextBase* ctx) const override {} |
| 131 | +}; |
| 132 | + |
| 133 | +} // namespace operators |
| 134 | +} // namespace paddle |
| 135 | + |
| 136 | +namespace ops = paddle::operators; |
| 137 | +REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker, |
| 138 | + linear_chain_crf_grad, ops::LinearChainCrfGradOp); |
| 139 | +REGISTER_OP_CPU_KERNEL(linear_chain_crf, ops::LinearChainCrfOpKernel<float>); |
| 140 | +REGISTER_OP_CPU_KERNEL(linear_chain_crf_grad, |
| 141 | + ops::LinearChainCrfGradOpKernel<float>); |
0 commit comments