Skip to content

Commit 45eabb8

Browse files
lcy-sesowangkuiyi
authored and committed
Add the crf_decoding operator. (#5352)
* proj init. * add unittest and implementation.
1 parent b0b26da commit 45eabb8

File tree

6 files changed

+447
-36
lines changed

6 files changed

+447
-36
lines changed

paddle/operators/crf_decoding_op.cc

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/crf_decoding_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
// Declares the interface of the crf_decoding operator: its inputs
// (Emission, Transition and the optional Label), its single output
// (ViterbiPath) and the user-facing documentation attached to the operator.
class CRFDecodingOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CRFDecodingOpMaker(framework::OpProto* proto,
                     framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("Emission",
             "(LoDTensor, default: LoDTensor<float>). A LoDTensor with shape "
             "[N x D] where N is the size of the mini-batch and D is the total "
             "tag number. This input is the unscaled emission weight matrix of "
             "the linear_chain_crf operator.");
    AddInput(
        "Transition",
        "(Tensor, default: Tensor<float>). A Tensor with shape [(D + 2) x D]. "
        "This input is the transition weights learned by the linear_chain_crf "
        "operator, denoted as w. The 1st row of w are transition weights for "
        "the start mask. The 2nd row of w are transition weights for the end "
        "mask. Transition weights between other tags begin from the 3rd row of "
        "w. See more details in comments of the linear_chain_crf operator.");
    // Label is the only optional input; its presence switches the meaning of
    // the output (see the operator comment below).
    AddInput(
        "Label",
        "(LoDTensor, LoDTensor<int>). The ground truth with shape "
        "[N x 1]. This input is optional. See more details in the operator's "
        "comments.")
        .AsDispensable();
    AddOutput("ViterbiPath",
              "(LoDTensor, LoDTensor<int>). The decoding results. What to "
              "return changes depending on whether the Input(Label) (the ground "
              "truth) is given. See more details in the operator's comment.");
    AddComment(R"DOC(
The crf_decoding operator reads the emission feature weights and the transition
feature weights learned by the linear_chain_crf operator. It implements the
Viterbi algorithm which is a dynamic programming algorithm for finding the most
likely sequence of hidden states, called the Viterbi path, that results in a
sequence of observed tags.

The output of this operator changes according to whether Input(Label) is given:

1. Input(Label) is given:

This happens in training. This operator is used to co-work with the chunk_eval
operator.

When Input(Label) is given, the crf_decoding operator returns a column vector
with shape [N x 1] whose values are fixed to be 0, indicating an incorrect
prediction, or 1 indicating a tag is correctly predicted. Such an output is the
input to chunk_eval operator.

2. Input(Label) is not given:

This is the standard decoding process.

The crf_decoding operator returns a column vector with shape [N x 1] whose values
range from 0 to maximum tag number - 1. Each element indicates an index of a
predicted tag.
)DOC");
  }
};
76+
77+
class CRFDecodingOp : public framework::OperatorWithKernel {
78+
public:
79+
using framework::OperatorWithKernel::OperatorWithKernel;
80+
81+
void InferShape(framework::InferShapeContext* ctx) const override {
82+
PADDLE_ENFORCE(ctx->HasInput("Emission"),
83+
"Input(Emission) should be not null.");
84+
PADDLE_ENFORCE(ctx->HasInput("Transition"),
85+
"Input(Transition) should be not null.");
86+
87+
PADDLE_ENFORCE(ctx->HasOutput("ViterbiPath"),
88+
"Output(ViterbiPath) should be not null.");
89+
90+
auto emission_dims = ctx->GetInputDim("Emission");
91+
PADDLE_ENFORCE_EQ(emission_dims.size(), 2UL,
92+
"The Input(Emission) should be a 2-D tensor.");
93+
PADDLE_ENFORCE(emission_dims[0], "An empty mini-batch is not allowed.");
94+
95+
auto transition_dims = ctx->GetInputDim("Transition");
96+
PADDLE_ENFORCE_EQ(transition_dims.size(), 2UL,
97+
"The Input(Transition) should be a 2-D tensor.");
98+
PADDLE_ENFORCE_EQ(
99+
transition_dims[0] - 2, transition_dims[1],
100+
"An invalid dimension for the Input(Transition), which should "
101+
"be a 2-D tensor with shape [(D + 2) x D].");
102+
PADDLE_ENFORCE_EQ(
103+
emission_dims[1], transition_dims[1],
104+
"The 2nd dimension of the Input(Emission) and the Input(Transition) "
105+
"should be equal to the tag number.");
106+
107+
if (ctx->HasInput("Label")) {
108+
auto label_dims = ctx->GetInputDim("Label");
109+
PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
110+
"The Input(Label) should be a 2-D tensor with the 2nd "
111+
"dimensions fixed to 1.");
112+
PADDLE_ENFORCE_EQ(
113+
emission_dims[0], label_dims[0],
114+
"The height of Input(Emission) and the height of Input(Label) "
115+
"should be the same.");
116+
}
117+
118+
ctx->ShareLoD("Emission", /*->*/ "ViterbiPath");
119+
ctx->SetOutputDim("ViterbiPath", {emission_dims[0], 1});
120+
}
121+
122+
protected:
123+
framework::DataType IndicateDataType(
124+
const framework::ExecutionContext& ctx) const override {
125+
return framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type());
126+
}
127+
};
128+
} // namespace operators
129+
} // namespace paddle
130+
131+
// Short alias used by the registration macros below.
namespace ops = paddle::operators;
132+
// Register the crf_decoding operator together with its proto maker. The macro
// name indicates that no gradient operator is registered for it.
REGISTER_OP_WITHOUT_GRADIENT(crf_decoding, ops::CRFDecodingOp,
133+
ops::CRFDecodingOpMaker);
134+
// Register the CPU kernels of crf_decoding for float and double element types.
REGISTER_OP_CPU_KERNEL(
135+
crf_decoding, ops::CRFDecodingOpKernel<paddle::platform::CPUPlace, float>,
136+
ops::CRFDecodingOpKernel<paddle::platform::CPUPlace, double>);

paddle/operators/crf_decoding_op.h

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/framework/eigen.h"
17+
#include "paddle/framework/op_registry.h"
18+
#include "paddle/operators/math/math_function.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
using framework::LoDTensor;
24+
using framework::LoD;
25+
using framework::Tensor;
26+
27+
template <typename Place, typename T>
28+
class CRFDecodingOpKernel : public framework::OpKernel<T> {
29+
public:
30+
void Compute(const framework::ExecutionContext& ctx) const override {
31+
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
32+
"The crf_decoding operator can only run on CPU.");
33+
34+
auto* emission_weights = ctx.Input<LoDTensor>("Emission");
35+
auto* transition_weights = ctx.Input<Tensor>("Transition");
36+
auto* label = ctx.Input<LoDTensor>("Label");
37+
auto* decoded_path = ctx.Output<Tensor>("ViterbiPath");
38+
39+
PADDLE_ENFORCE_EQ(emission_weights->NumLevels(), 1UL,
40+
"The Input(Emission) should be a sequence.");
41+
auto lod = emission_weights->lod();
42+
PADDLE_ENFORCE(lod.size(), "Input(Emission) must be a sequence.");
43+
const size_t level = 0;
44+
const size_t seq_num = lod[level].size() - 1;
45+
46+
int* path = decoded_path->mutable_data<int>(platform::CPUPlace());
47+
math::SetConstant<platform::CPUPlace, int>()(ctx.device_context(),
48+
decoded_path, 0);
49+
for (size_t i = 0; i < seq_num; ++i) {
50+
int start_pos = static_cast<int>(lod[level][i]);
51+
int end_pos = static_cast<int>(lod[level][i + 1]);
52+
Tensor decoded_path_one_seq = decoded_path->Slice(start_pos, end_pos);
53+
Decode(emission_weights->Slice(start_pos, end_pos), *transition_weights,
54+
&decoded_path_one_seq);
55+
}
56+
57+
if (label) {
58+
PADDLE_ENFORCE_EQ(label->NumLevels(), 1UL,
59+
"The Input(Label) should be a sequence.");
60+
const int* label_value = label->data<int>();
61+
size_t batch_size = emission_weights->dims()[0];
62+
for (size_t i = 0; i < batch_size; ++i) {
63+
path[i] = label_value[i] == path[i] ? 1 : 0;
64+
}
65+
}
66+
}
67+
68+
private:
69+
void Decode(const Tensor& emission_weights, const Tensor& transition_weights,
70+
Tensor* decoded_path) const {
71+
auto emission_dims = emission_weights.dims();
72+
const size_t seq_len = emission_dims[0];
73+
const size_t tag_num = emission_dims[1];
74+
75+
const size_t state_trans_base_idx = 2;
76+
77+
const T* x = emission_weights.data<T>();
78+
const T* w = transition_weights.data<T>();
79+
int* path = decoded_path->data<int>();
80+
81+
// alpha is a memo table. An element alpha(k, v) records the score of the
82+
// best sequence of tags from position 1 to position k with v being the end
83+
// tag.
84+
Tensor alpha;
85+
T* alpha_value = alpha.mutable_data<T>(emission_dims, platform::CPUPlace());
86+
Tensor track;
87+
int* track_value =
88+
track.mutable_data<int>(emission_dims, platform::CPUPlace());
89+
90+
for (size_t i = 0; i < tag_num; ++i) alpha_value[i] = w[i] + x[i];
91+
92+
for (size_t k = 1; k < seq_len; ++k) {
93+
for (size_t i = 0; i < tag_num; ++i) {
94+
T max_score = -std::numeric_limits<T>::max();
95+
int max_j = 0;
96+
for (size_t j = 0; j < tag_num; ++j) {
97+
T score = alpha_value[(k - 1) * tag_num + j] +
98+
w[(j + state_trans_base_idx) * tag_num + i];
99+
if (score > max_score) {
100+
max_score = score;
101+
max_j = j;
102+
}
103+
}
104+
105+
alpha_value[k * tag_num + i] = max_score + x[k * tag_num + i];
106+
track_value[k * tag_num + i] = max_j;
107+
}
108+
}
109+
110+
T max_score = -std::numeric_limits<T>::max();
111+
int max_i = 0;
112+
for (size_t i = 0; i < tag_num; ++i) {
113+
T score = alpha_value[(seq_len - 1) * tag_num + i] + w[tag_num + i];
114+
if (score > max_score) {
115+
max_score = score;
116+
max_i = i;
117+
}
118+
}
119+
path[seq_len - 1] = max_i;
120+
for (int k = seq_len - 1; k >= 1; --k) {
121+
path[k - 1] = max_i = track_value[k * tag_num + max_i];
122+
}
123+
}
124+
};
125+
126+
} // namespace operators
127+
} // namespace paddle

paddle/operators/cross_entropy_op.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
4949
}
5050

5151
protected:
52-
// Explicitly set that data type of the output of the cross_entropy operator
52+
// Explicitly set that the data type of computation kernel of cross_entropy
5353
// is determined by its input "X".
5454
framework::DataType IndicateDataType(
5555
const framework::ExecutionContext& ctx) const override {
@@ -96,7 +96,8 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
9696
}
9797

9898
protected:
99-
// CrossEntropy's data type just determined by "X"
99+
// Explicitly set that the data type of computation kernel of cross_entropy
100+
// is determined by its input "X".
100101
framework::DataType IndicateDataType(
101102
const framework::ExecutionContext& ctx) const override {
102103
return framework::ToDataType(ctx.Input<Tensor>("X")->type());

paddle/operators/linear_chain_crf_op.cc

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,43 +22,44 @@ class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
2222
LinearChainCRFOpMaker(framework::OpProto* proto,
2323
framework::OpAttrChecker* op_checker)
2424
: OpProtoAndCheckerMaker(proto, op_checker) {
25-
AddInput(
26-
"Emission",
27-
"(LoDTensor, default: LoDTensor<float>). "
28-
"The unscaled emission weight matrix for the linear chain CRF. "
29-
"This input is a LoDTensor with shape [N x D] where N is the size of "
30-
"the mini-batch and D is the total tag number.");
31-
AddInput(
32-
"Transition",
33-
"(Tensor, default: Tensor<float>). A Tensor with shape [(D + 2) x D]. "
34-
"The learnable parameter for the linear_chain_crf operator. "
35-
"See more details in the operator's comments.");
36-
AddInput(
37-
"Label",
38-
"(LoDTensor, default: LoDTensor<int>). The ground truth which is a 2-D "
39-
"LoDTensor with shape [N x 1], where N is the total element number in "
40-
"a mini-batch.");
25+
AddInput("Emission",
26+
"(LoDTensor, default: LoDTensor<float>). "
27+
"A 2-D LoDTensor with shape [N x D] where N is the size of the "
28+
"mini-batch and D is the total tag number. The unscaled emission "
29+
"weight matrix for the linear chain CRF. ");
30+
AddInput("Transition",
31+
"(Tensor, default: Tensor<float>). A 2-D Tensor with shape "
32+
"[(D + 2) x D]. The learnable parameter for the linear_chain_crf "
33+
"operator. See more details in the operator's comments.");
34+
AddInput("Label",
35+
"(LoDTensor, default: LoDTensor<int>). A LoDTensor with shape "
36+
"[N x 1], where N is the total element number in a mini-batch. "
37+
"The ground truth.");
4138
AddOutput(
4239
"Alpha",
43-
"Tensor, default: Tensor<float>. The forward vectors for the entire "
44-
"batch. A two dimensional tensor with shape [N x D], "
45-
"denoted as \f$\alpha\f$. \f$\alpha$\f is a memo table used to "
46-
"calculate the normalization factor in CRF. \f$\alpha[k, v]$\f stores "
47-
"the unnormalized probabilites of all possible unfinished sequences of "
48-
"tags that end at position \f$k$\f with tag \f$v$\f. For each \f$k$\f, "
40+
"(Tensor, default: Tensor<float>). A 2-D Tensor with shape [N x D]. "
41+
"The forward vectors for the entire batch. Denote it as \f$\alpha\f$. "
42+
"\f$\alpha$\f is a memo table used to calculate the normalization "
43+
"factor in CRF. \f$\alpha[k, v]$\f stores the unnormalized "
44+
"probabilites of all possible unfinished sequences of tags that end at "
45+
"position \f$k$\f with tag \f$v$\f. For each \f$k$\f, "
4946
"\f$\alpha[k, v]$\f is a vector of length \f$D$\f with a component for "
5047
"each tag value \f$v$\f. This vector is called a forward vecotr and "
5148
"will also be used in backward computations.")
5249
.AsIntermediate();
53-
AddOutput("EmissionExps",
54-
"The exponentials of Input(Emission). This is an intermediate "
55-
"computational result in forward computation, and will be reused "
56-
"in backward computation.")
50+
AddOutput(
51+
"EmissionExps",
52+
"(Tensor, default: Tensor<float>). A 2-D Tensor with shape [N x D]. "
53+
"The exponentials of Input(Emission). This is an intermediate "
54+
"computational result in forward computation, and will be reused in "
55+
"backward computation.")
5756
.AsIntermediate();
58-
AddOutput("TransitionExps",
59-
"The exponentials of Input(Transition). This is an intermediate "
60-
"computational result in forward computation, and will be reused "
61-
"in backward computation.")
57+
AddOutput(
58+
"TransitionExps",
59+
"(Tensor, default: Tensor<float>). A 2-D Tensor with shape "
60+
"[(D + 2) x D]. The exponentials of Input(Transition). This is an "
61+
"intermediate computational result in forward computation, and "
62+
"will be reused in backward computation.")
6263
.AsIntermediate();
6364
AddOutput(
6465
"LogLikelihood",
@@ -179,8 +180,8 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
179180
}
180181

181182
protected:
182-
// Explicitly set that the data type of output of the linear_chain_crf
183-
// operator is determined by its input "Emission".
183+
// Explicitly set that the data type of computation kernel of linear_chain_crf
184+
// is determined by its input "Emission".
184185
framework::DataType IndicateDataType(
185186
const framework::ExecutionContext& ctx) const override {
186187
return framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type());

paddle/operators/linear_chain_crf_op.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class LinearChainCRFOpKernel : public framework::OpKernel<T> {
134134

135135
Tensor emission_row_max;
136136
emission_row_max.mutable_data<T>(
137-
framework::make_ddim({static_cast<int>(batch_size), 1}),
137+
framework::make_ddim({static_cast<int64_t>(batch_size), 1}),
138138
platform::CPUPlace());
139139

140140
auto place = ctx.GetEigenDevice<platform::CPUPlace>();
@@ -273,7 +273,7 @@ class LinearChainCRFOpKernel : public framework::OpKernel<T> {
273273

274274
const int* lbl = label.data<int>();
275275
PADDLE_ENFORCE_LT(
276-
*std::max_element(lbl, lbl + seq_length), tag_num,
276+
static_cast<size_t>(*std::max_element(lbl, lbl + seq_length)), tag_num,
277277
"An invalid tag label that execesses the largest tag number.");
278278

279279
// Calculate the nominator part, which depends on the label sequence.

0 commit comments

Comments
 (0)