
Commit cc220ee

add forward computation of crf operator.
1 parent cbcf11d commit cc220ee

7 files changed: +231 -49 lines changed

paddle/framework/tensor.h

Lines changed: 7 additions & 4 deletions
@@ -114,16 +114,19 @@ class Tensor {
                   const platform::DeviceContext& ctx);
 
   /**
-   * @brief Return the slice of the tensor.
+   * @brief Return a sub-tensor of the given tensor.
    *
-   * @param[in] begin_idx The begin index of the slice.
-   * @param[in] end_idx   The end index of the slice.
+   * @param[in] begin_idx The index of the start row(inclusive) to slice.
+   *                      The index number begins from 0.
+   * @param[in] end_idx   The index of the end row(exclusive) to slice.
+   *                      The index number begins from 0.
    */
   template <typename T>
   inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
 
   platform::Place place() const {
-    PADDLE_ENFORCE_NOT_NULL(holder_, "Tensor get place() must contains holder");
+    PADDLE_ENFORCE_NOT_NULL(
+        holder_, "A holder must exist when calling the method place().");
     return holder_->place();
   }

paddle/framework/tensor_impl.h

Lines changed: 4 additions & 3 deletions
@@ -168,10 +168,11 @@ inline void Tensor::CopyFromVector(const std::vector<T>& src,
 template <typename T>
 inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
   check_memory_size<T>();
-  PADDLE_ENFORCE_GE(begin_idx, 0, "Slice begin index is less than zero.");
-  PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound.");
+  PADDLE_ENFORCE_GE(begin_idx, 0,
+                    "The start row index must be greater than 0.");
+  PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is out of bound.");
   PADDLE_ENFORCE_LT(begin_idx, end_idx,
-                    "Begin index must be less than end index.");
+                    "The start row index must be less than the end row index.");
 
   if (dims_[0] == 1) {
     return *this;
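
Under the new comments, Slice selects whole rows: begin_idx is the inclusive start row, end_idx is the exclusive end row, and both count from 0. A minimal usage sketch (illustrative only, not code from this commit, reusing the same calls the commit uses elsewhere: mutable_data, make_ddim, Slice):

    // Build a 6 x 4 float tensor, then view rows 2 and 3 as a [2 x 4] sub-tensor.
    paddle::framework::Tensor t;
    t.mutable_data<float>(paddle::framework::make_ddim({6, 4}),
                          paddle::platform::CPUPlace());
    paddle::framework::Tensor rows = t.Slice<float>(2, 4);  // rows [2, 4)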

paddle/operators/cross_entropy_op.cc

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
     ctx->ShareLoD("X", /*->*/ "Y");
   }
 
-  // Explicitly set data type of output of the cross_entropy operator
+  // Explicitly set that data type of the output of the cross_entropy operator
   // is determined by its input "X".
   framework::DataType IndicateDataType(
       const framework::ExecutionContext& ctx) const override {

paddle/operators/linear_chain_crf_op.cc

Lines changed: 192 additions & 22 deletions
@@ -17,6 +17,9 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using framework::LoDTensor;
+using framework::LoD;
+
 class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   LinearChainCrfOpMaker(framework::OpProto* proto,
@@ -77,14 +80,14 @@ Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
 
 Equation:
 
-- Denote the first input of this operator (Emission) as \f$x\f$ here.
-- The first D values of the second input (Transition) of this operator are for
-  starting weights, denoted as \f$a\f$ here.
-- The next D values of the second input (Transition) of this operator are for
-  ending weights, denoted as \f$b\f$ here.
-- The remaning values of the second input (Transition) are for transition
-  weights, denoted as \f$w\f$ here.
-- Denote the third input of this operator (Label) as \f$s\f$ here.
+- Denote Input(Emission) to this operator as \f$x\f$ here.
+- The first D values of Input(Transition) to this operator are for starting
+  weights, denoted as \f$a\f$ here.
+- The next D values of Input(Transition) of this operator are for ending
+  weights, denoted as \f$b\f$ here.
+- The remaning values of Input(Transition) are for transition weights,
+  denoted as \f$w\f$ here.
+- Denote Input(Label) as \f$s\f$ here.
 
 The probability of a sequence \f$s\f$ of length \f$L\f$ is defined as:
 \f$P(s) = (1/Z) exp(a_{s_1} + b_{s_L}
@@ -107,8 +110,7 @@ sequences internally, it expects UNSCALED emission feature weights.
 Please do not call this op with the emission feature being output of any
 nonlinear activation.
 
-3. The 2nd dimension of the first input of this operator (Emission) MUST be
-   equal to the tag number.
+3. The 2nd dimension of Input(Emission) MUST be equal to the tag number.
 
 )DOC");
   }
@@ -136,33 +138,188 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
     auto label_dims = ctx->GetInputDim("Label");
 
     PADDLE_ENFORCE_EQ(emission_dims.size(), 2UL,
-                      "The input Emission should be a 2-D tensor.");
+                      "The Input(Emission) should be a 2-D tensor.");
     PADDLE_ENFORCE_EQ(transition_dims.size(), 2UL,
-                      "The input Transition should be a 2-D tensor.");
+                      "The Input(Transition) should be a 2-D tensor.");
     PADDLE_ENFORCE_EQ(
-        transition_dims[0] + 2, transition_dims[1],
-        "An invalid dimension for the input Transition, which should "
+        transition_dims[0] - 2, transition_dims[1],
+        "An invalid dimension for the Input(Transition), which should "
         "be a 2-D tensor with shape [D + 2 x D].");
     PADDLE_ENFORCE_EQ(
         emission_dims[1], transition_dims[1],
-        "The 2nd dimension of the input Emission and the input Transition "
+        "The 2nd dimension of the Input(Emission) and the Input(Transition) "
         "should be equal to the tag number.");
     PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
-                   "The input Label should be a 2-D tensor "
-                   "with the 2nd dimensions fixed to 1.");
+                   "The Input(Label) should be a 2-D tensor with the 2nd "
+                   "dimensions fixed to 1.");
+    PADDLE_ENFORCE_EQ(
+        emission_dims[0], label_dims[0],
+        "The height of Input(Emission) and the height of Input(Label) "
+        "should be the same.");
 
     ctx->SetOutputDim("Alpha", emission_dims);
+
+    // (TODO caoying) This is tricky. The 1st dimension of Output(LogLikelihood)
+    // is the sequence number in a mini-batch. The dimension set here should be
+    // resized to its correct size in the function Compute.
     ctx->SetOutputDim("LogLikelihood", {emission_dims[0], 1});
   }
 
-  // Explicitly set data type of output of the linear_chain_crf operator
-  // is determined by its input "Emission".
+  // Explicitly set that the data type of output of the linear_chain_crf
+  // operator is determined by its input "Emission".
   framework::DataType IndicateDataType(
       const framework::ExecutionContext& ctx) const override {
     return framework::ToDataType(ctx.Input<Tensor>("Emission")->type());
   }
 };
 
+template <typename T>
+class LinearChainCrfOpKernel<platform::CPUPlace, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "This kernel only runs on CPU.");
+
+    auto* emission_weights = ctx.Input<LoDTensor>("Emission");
+    auto* transition_weights = ctx.Input<Tensor>("Transition");
+    auto* label = ctx.Input<LoDTensor>("Label");
+
+    auto in_lod = emission_weights->lod();
+    // TODO(caoying) The checks related to LoD information should be
+    // moved into InferShape once after the InferShape is refactored.
+    PADDLE_ENFORCE_EQ(emission_weights->NumLevels(), 1UL,
+                      "The Input(Emission) should be a sequence.");
+    PADDLE_ENFORCE_EQ(label->NumLevels(), 1UL,
+                      "The Input(Label) should be a sequence.");
+    const size_t level = 0;
+
+    auto emission_dims = emission_weights->dims();
+    const size_t seq_num = in_lod[level].size() - 1;
+
+    // TODO(caoying) These local variables seems to be created and destroied
+    // every time this function is called. Will this bring additional overhead?
+    Tensor emission_exps;
+    Tensor emission_row_max;
+    Tensor transition_exps;
+    emission_exps.mutable_data<T>(emission_dims, platform::CPUPlace());
+    emission_row_max.mutable_data<T>(
+        framework::make_ddim({emission_dims[0], 1}), platform::CPUPlace());
+    transition_exps.mutable_data<T>(transition_weights->dims(),
+                                    platform::CPUPlace());
+
+    auto* alpha = ctx.Output<Tensor>("Alpha");
+    alpha->mutable_data<T>(ctx.GetPlace());
+    auto* ll = ctx.Output<Tensor>("LogLikelihood");
+    // resize the output tensor to the correct dimension.
+    ll->Resize({static_cast<int>(seq_num), 1});
+    T* log_likelihood = ll->mutable_data<T>(ctx.GetPlace());
+
+    for (size_t i = 0; i < seq_num; ++i) {
+      int start_pos = static_cast<int>(in_lod[level][i]);
+      int end_pos = static_cast<int>(in_lod[level][i + 1]);
+
+      const Tensor one_seq = emission_weights->Slice<T>(start_pos, end_pos);
+      Tensor one_seq_row_max = emission_row_max.Slice<T>(start_pos, end_pos);
+      Tensor one_seq_exps = emission_exps.Slice<T>(start_pos, end_pos);
+      const Tensor one_seq_label = label->Slice<T>(start_pos, end_pos);
+      Tensor one_seq_alpha = alpha->Slice<T>(start_pos, end_pos);
+
+      log_likelihood[i] = ForwardOneSequence(
+          ctx.device_context(), one_seq, one_seq_row_max, one_seq_exps,
+          (*transition_weights), transition_exps, one_seq_label, one_seq_alpha);
+    }
+  }
+
+ protected:
+  T ForwardOneSequence(const platform::DeviceContext& ctx,
+                       const Tensor& emission, Tensor& emission_row_max,
+                       Tensor& emission_exps, const Tensor& trans_weights,
+                       Tensor& trans_weight_exps, const Tensor& label,
+                       Tensor& alpha) const {
+    // (TODO caoying) Evaluate and optimize this.
+    // The Eigen compution kernel will be invoked for multiple times.
+    // Some computations regardless of sequence inforamtion could be performed
+    // only one time for the entire batch. This potentially could be optimized.
+
+    auto x_dims = emission.dims();
+    const size_t seq_length = x_dims[0];
+    const size_t tag_num = x_dims[1];
+
+    T* alpha_value = alpha.data<T>();
+
+    auto x = EigenMatrix<T>::From(emission);
+    auto x_row_max = EigenMatrix<T>::From(emission_row_max);
+    const int class_dim = 1;
+    x_row_max.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
+        x.maximum(Eigen::DSizes<int, 1>(class_dim))
+            .reshape(Eigen::DSizes<int, 2>(int(seq_length), 1));
+
+    auto x_exps = EigenMatrix<T>::From(emission_exps);
+    x_exps.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
+        (x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
+
+    auto w = EigenMatrix<T>::From(trans_weights);
+    auto w_exps = EigenMatrix<T>::From(trans_weight_exps);
+    w_exps.device(*ctx.GetEigenDevice<platform::CPUPlace>()) = w.exp();
+    // The 1st row of w are transition weights for start mask.
+    const size_t start_ridx = 0;
+    // The 2nd row of w are transition weights for end mask.
+    const size_t end_ridx = 1;
+    // Transition weights among other tags begins from the 3rd row of w.
+    const size_t state_base_ridx = 2;
+
+    for (size_t i = 0; i < tag_num; ++i) {
+      alpha_value[i] = w_exps(start_ridx, i) * x_exps(0, i);
+    }
+    T ll = -x_row_max(0, 1) - std::log(NormalizeL1(alpha_value, tag_num));
+
+    for (size_t k = 1; k < seq_length; ++k) {
+      for (size_t i = 0; i < tag_num; ++i) {
+        T sum = 0.;
+        for (size_t j = 0; j < tag_num; ++j) {
+          sum += alpha_value[(k - 1) * tag_num + j] *
                 w_exps(j + state_base_ridx, i);
+        }
+        alpha_value[k * tag_num + i] = x_exps(k, i) * sum;
+      }
+      ll -= x_row_max(k, 1) +
+            std::log(NormalizeL1(alpha_value + k * tag_num, tag_num));
+    }
+    T sum = 0.;
+    for (size_t i = 0; i < tag_num; ++i) {
+      sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps(end_ridx, i);
+    }
+    ll -= std::log(sum);
+
+    const int* lbl = label.data<int>();
+    PADDLE_ENFORCE_LT(
+        *std::max_element(lbl, lbl + seq_length), tag_num,
+        "An invalid tag label that execesses the largest tag number.");
+
+    // Calculate the nominator part, which depends on the label sequence.
+    ll += w(start_ridx, lbl[0]) + x(start_ridx, lbl[0]) +
+          w(end_ridx, lbl[seq_length - 1]);
+    for (size_t k = 1; k < seq_length; ++k)
+      ll += x(k, lbl[k]) + w(lbl[k - 1], lbl[k]);
+    return -ll;
+  }
+
+ private:
+  T NormalizeL1(T* x, size_t len) const {
+    T sum = 0.;
+    for (size_t i = 0; i < len; ++i) sum += x[i];
+    // (This comment is from the old LinearChainCRFLayer.)
+    // Right now, we just bet that sum won't be zero. If this really happens, we
+    // will figure out what should be done then.
+    PADDLE_ENFORCE(sum,
+                   "The unnormalized probabilites of all possible unfinished "
+                   "sequences must be greater than 0.");
+    for (size_t i = 0; i < len; ++i) x[i] /= sum;
+    return sum;
+  }
+};
+
 class LinearChainCrfGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -171,12 +328,25 @@ class LinearChainCrfGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {}
 };
 
+template <typename T>
+class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "This kernel only runs on CPU.");
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker,
             linear_chain_crf_grad, ops::LinearChainCrfGradOp);
-REGISTER_OP_CPU_KERNEL(linear_chain_crf, ops::LinearChainCrfOpKernel<float>);
-REGISTER_OP_CPU_KERNEL(linear_chain_crf_grad,
-                       ops::LinearChainCrfGradOpKernel<float>);
+REGISTER_OP_CPU_KERNEL(
+    linear_chain_crf,
+    ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    linear_chain_crf_grad,
+    ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float>);
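
ForwardOneSequence above computes the denominator log Z with a scaled forward recursion (alpha is renormalized by NormalizeL1 at every step and the per-step log normalizers are accumulated), then subtracts the unnormalized log score of the gold label path. A simplified standalone sketch of the same computation (illustrative only, not this kernel's code; plain std::vector instead of Paddle tensors, no row-max shifting, and the transition layout from the DOC string: row 0 holds the start weights a, row 1 the end weights b, rows 2..D+1 the tag-to-tag weights w):

    #include <cmath>
    #include <vector>

    double CrfNegLogLikelihood(const std::vector<std::vector<double>>& emission,
                               const std::vector<std::vector<double>>& transition,
                               const std::vector<int>& label) {
      const size_t len = emission.size();        // sequence length
      const size_t tag_num = emission[0].size();

      // L1-normalize v in place and return the normalizer (cf. NormalizeL1).
      auto normalize = [](std::vector<double>& v) -> double {
        double sum = 0.;
        for (double x : v) sum += x;
        for (double& x : v) x /= sum;
        return sum;
      };

      // Denominator: log Z via the scaled forward recursion.
      std::vector<double> alpha(tag_num);
      for (size_t i = 0; i < tag_num; ++i)
        alpha[i] = std::exp(transition[0][i] + emission[0][i]);  // start weights
      double log_z = std::log(normalize(alpha));
      for (size_t k = 1; k < len; ++k) {
        std::vector<double> next(tag_num, 0.);
        for (size_t i = 0; i < tag_num; ++i) {
          for (size_t j = 0; j < tag_num; ++j)
            next[i] += alpha[j] * std::exp(transition[j + 2][i]);
          next[i] *= std::exp(emission[k][i]);
        }
        alpha = next;
        log_z += std::log(normalize(alpha));
      }
      double end_sum = 0.;
      for (size_t i = 0; i < tag_num; ++i)
        end_sum += alpha[i] * std::exp(transition[1][i]);  // end weights
      log_z += std::log(end_sum);

      // Numerator: unnormalized log score of the gold label sequence.
      double gold = transition[0][label[0]] + emission[0][label[0]] +
                    transition[1][label[len - 1]];
      for (size_t k = 1; k < len; ++k)
        gold += emission[k][label[k]] + transition[label[k - 1] + 2][label[k]];

      return log_z - gold;  // negative log-likelihood of the labeled path
    }

Renormalizing alpha at every step keeps the products of exponentials from underflowing on long sequences; the accumulated per-step log normalizers plus the final end-weight sum recover log Z exactly.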

paddle/operators/linear_chain_crf_op.h

Lines changed: 15 additions & 11 deletions
@@ -19,27 +19,31 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
+using framework::Tensor;
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
-template <typename T>
+template <typename Place, typename T>
 class LinearChainCrfOpKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "This kernel only runs on CPU.");
-  }
+  void Compute(const framework::ExecutionContext& ctx) const override;
+
+ protected:
+  T ForwardOneSequence(const platform::DeviceContext& ctx,
+                       const Tensor& emission, Tensor& emission_row_max,
+                       Tensor& emission_exps, const Tensor& trans_weights,
+                       Tensor& trans_weight_exps, const Tensor& label,
+                       Tensor& a) const;
+
+ private:
+  T NormalizeL1(T* x, size_t len) const;
 };
 
-template <typename T>
+template <typename Place, typename T>
 class LinearChainCrfGradOpKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "This kernel only runs on CPU.");
-  }
+  void Compute(const framework::ExecutionContext& ctx) const override;
 };
 
 }  // namespace operators
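
With the Place parameter added, the header only declares Compute; the definition lives in a device-specific partial specialization (the CPUPlace one added in linear_chain_crf_op.cc above), and another device could later be supported by adding its own specialization. A standalone sketch of that pattern in plain C++ (illustrative only; DemoKernel, CPUPlace, and GPUPlace are made-up names, not Paddle types):

    #include <iostream>

    struct CPUPlace {};
    struct GPUPlace {};

    // Primary template: Compute is declared for every Place but defined nowhere.
    template <typename Place, typename T>
    class DemoKernel {
     public:
      void Compute() const;
    };

    // The partial specialization supplies the CPU definition, mirroring how
    // LinearChainCrfOpKernel<platform::CPUPlace, T> is defined in the .cc file.
    template <typename T>
    class DemoKernel<CPUPlace, T> {
     public:
      void Compute() const { std::cout << "CPU implementation\n"; }
    };

    int main() {
      DemoKernel<CPUPlace, float>().Compute();    // OK: CPU specialization exists.
      // DemoKernel<GPUPlace, float>().Compute(); // Would fail to link: no GPU
      //                                          // specialization defined yet.
      return 0;
    }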

paddle/operators/softmax_with_cross_entropy_op.cc

Lines changed: 9 additions & 5 deletions
@@ -60,19 +60,23 @@ Because this operators performs a softmax on logits internally, it expects
 unscaled logits. Please do not call this op with the output of softmax operator,
 which will produce incorrect results.
 
-This operators expects mutually exclusive hard labels, each sample in a batch
-is in exactly one class with probabilities 1. Each sample in the batch with one
-and only one label.
+When the attribute softLabel is set false, this operators expects mutually
+exclusive hard labels, each sample in a batch is in exactly one class with
+probabilities 1. Each sample in the batch with one and only one label.
 
 Equation:
 
 1) hard label (one-hot label)
 
-Loss_j = -\text{Logit}_{Label_j} + \log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right), j = 1, ..., K
+Loss_j = \f$ -\text{Logit}_{Label_j} +
+         \log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right),
+         j = 1, ..., K $\f
 
 2) soft label (a distribution over all classes)
 
-Loss_j = -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i-\log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right), j = 1,...,K
+Loss_j = \f$ -\sum_{i=0}^{K}\text{Label}_i\left(\text{Logit}_i -
+         \log\left(\sum_{i=0}^{K}\exp(\text{Logit}_i)\right)\right),
+         j = 1,...,K $\f
 
 )DOC");
   }
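
The two DOC equations can be checked numerically. A tiny standalone sketch (illustrative only, not the operator's kernel) computing both losses for a single sample with three classes:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<double> logit = {2.0, 1.0, 0.1};
      double log_sum_exp = 0.;
      for (double l : logit) log_sum_exp += std::exp(l);
      log_sum_exp = std::log(log_sum_exp);

      // 1) hard label: the true class is index 0.
      int label = 0;
      double hard_loss = -logit[label] + log_sum_exp;

      // 2) soft label: a distribution over all classes.
      std::vector<double> soft = {0.7, 0.2, 0.1};
      double soft_loss = 0.;
      for (size_t i = 0; i < logit.size(); ++i)
        soft_loss -= soft[i] * (logit[i] - log_sum_exp);

      std::printf("hard: %f  soft: %f\n", hard_loss, soft_loss);
      return 0;
    }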
