@@ -26,19 +26,20 @@ T NormalizeL1(T* x, size_t len) {
26
26
// Right now, we just bet that sum won't be zero. If this really happens, we
27
27
// will figure out what should be done then.
28
28
PADDLE_ENFORCE (sum,
29
- " The unnormalized probabilites of all possible unfinished "
29
+ " The unnormalized probabilities of all possible unfinished "
30
30
" sequences must be greater than 0." );
31
- for (size_t i = 0 ; i < len; ++i) x[i] /= sum;
31
+ T s = 1 . / sum;
32
+ for (size_t i = 0 ; i < len; ++i) x[i] *= s;
32
33
return sum;
33
34
}
34
35
} // namespace
35
36
36
37
using framework::LoDTensor;
37
38
using framework::LoD;
38
39
39
- class LinearChainCrfOpMaker : public framework ::OpProtoAndCheckerMaker {
40
+ class LinearChainCRFOpMaker : public framework ::OpProtoAndCheckerMaker {
40
41
public:
41
- LinearChainCrfOpMaker (framework::OpProto* proto,
42
+ LinearChainCRFOpMaker (framework::OpProto* proto,
42
43
framework::OpAttrChecker* op_checker)
43
44
: OpProtoAndCheckerMaker(proto, op_checker) {
44
45
AddInput (
@@ -51,11 +52,11 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
51
52
AddInput (
52
53
" Transition" ,
53
54
" (Tensor, default: Tensor<float>). A Tensor with shape [(D + 2) x D]. "
54
- " The learnable parameter for linear_chain_crf operator. "
55
+ " The learnable parameter for the linear_chain_crf operator. "
55
56
" See more details in the operator's comments." );
56
57
AddInput (
57
58
" Label" ,
58
- " (LoDTensor, default: LoDTensor<int>). The ground truth which is a 2-D "
59
+ " (LoDTensor, default: LoDTensor<int>). The groundtruth which is a 2-D "
59
60
" LoDTensor with shape [N x 1], where N is the total element number in "
60
61
" a mini-batch." );
61
62
AddOutput (
@@ -82,14 +83,11 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
82
83
.AsIntermediate ();
83
84
AddOutput (
84
85
" LogLikelihood" ,
85
- " (Tensor, default: Tensor<float>). The logarithm of the "
86
- " conditional "
86
+ " (Tensor, default: Tensor<float>). The logarithm of the conditional "
87
87
" likelihood of each training sample in a mini-batch. This is a 2-D "
88
88
" tensor with shape [S x 1], where S is the sequence number in a "
89
- " mini-batch. "
90
- " Note: S is equal to the sequence number in a mini-batch. The "
91
- " output "
92
- " is no longer a LoDTensor." );
89
+ " mini-batch. Note: S is equal to the sequence number in a mini-batch. "
90
+ " The output is no longer a LoDTensor." );
93
91
AddComment (R"DOC(
94
92
Conditional Random Field defines an undirected probabilistic graph with nodes
95
93
denoting random variables and edges denoting dependencies between these
@@ -100,11 +98,11 @@ variables. CRF learns the conditional probability \f$P(Y|X)\f$, where
100
98
Linear chain CRF is a special case of CRF that is useful for sequence labeling
101
99
task. Sequence labeling tasks do not assume a lot of conditional
102
100
independences among inputs. They only concern about the input and the output
103
- being linear sequences. Thus, the graph model of CRF is a simple chain or
104
- a line, which results in a linear chain CRF.
101
+ being linear sequences. Thus, the graph model of such a CRF is a simple chain
102
+ or a line, which results in the linear chain CRF.
105
103
106
- This operator implements the Forward-Backward algorithm for linear chain CRF.
107
- Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
104
+ This operator implements the Forward-Backward algorithm for the linear chain
105
+ CRF. Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
108
106
109
107
Equation:
110
108
@@ -144,7 +142,7 @@ nonlinear activation.
144
142
}
145
143
};
146
144
147
- class LinearChainCrfOp : public framework ::OperatorWithKernel {
145
+ class LinearChainCRFOp : public framework ::OperatorWithKernel {
148
146
public:
149
147
using framework::OperatorWithKernel::OperatorWithKernel;
150
148
@@ -211,7 +209,7 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
211
209
};
212
210
213
211
template <typename T>
214
- class LinearChainCrfOpKernel <platform::CPUPlace, T>
212
+ class LinearChainCRFOpKernel <platform::CPUPlace, T>
215
213
: public framework::OpKernel<T> {
216
214
public:
217
215
void Compute (const framework::ExecutionContext& ctx) const override {
@@ -262,11 +260,11 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
262
260
w_exps.device (place) = w.exp ();
263
261
264
262
auto * alpha = ctx.Output <LoDTensor>(" Alpha" );
265
- alpha->mutable_data <T>(ctx. GetPlace ());
263
+ alpha->mutable_data <T>(platform::CPUPlace ());
266
264
auto * ll = ctx.Output <LoDTensor>(" LogLikelihood" );
267
265
// resize the output tensor to the correct dimension.
268
266
ll->Resize ({static_cast <int >(seq_num), 1 });
269
- T* log_likelihood = ll->mutable_data <T>(ctx. GetPlace ());
267
+ T* log_likelihood = ll->mutable_data <T>(platform::CPUPlace ());
270
268
for (size_t i = 0 ; i < seq_num; ++i) {
271
269
int start_pos = static_cast <int >(in_lod[level][i]);
272
270
int end_pos = static_cast <int >(in_lod[level][i + 1 ]);
@@ -322,6 +320,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
322
320
}
323
321
alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
324
322
}
323
+ // NormalizeL1 rescales the k-th slice of alpha_value in place to avoid underflow or overflow at (*).
325
324
ll -= x_row_max[k] +
326
325
std::log (NormalizeL1<T>(alpha_value + k * tag_num, tag_num));
327
326
}
@@ -330,6 +329,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
330
329
sum += alpha_value[(seq_length - 1 ) * tag_num + i] * w_exps[tag_num + i];
331
330
}
332
331
ll -= std::log (sum);
332
+ // Now ll is equal to -log(Z).
333
333
334
334
const int * lbl = label->data <int >();
335
335
PADDLE_ENFORCE_LT (
@@ -347,7 +347,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
347
347
}
348
348
};
349
349
350
- class LinearChainCrfGradOp : public framework ::OperatorWithKernel {
350
+ class LinearChainCRFGradOp : public framework ::OperatorWithKernel {
351
351
public:
352
352
using framework::OperatorWithKernel::OperatorWithKernel;
353
353
@@ -407,11 +407,11 @@ class LinearChainCrfGradOp : public framework::OperatorWithKernel {
407
407
};
408
408
409
409
template <typename T>
410
- class LinearChainCrfGradOpKernel <platform::CPUPlace, T>
410
+ class LinearChainCRFGradOpKernel <platform::CPUPlace, T>
411
411
: public framework::OpKernel<T> {
412
412
public:
413
413
void Compute (const framework::ExecutionContext& ctx) const override {
414
- PADDLE_ENFORCE (platform::is_cpu_place (ctx. GetPlace ()),
414
+ PADDLE_ENFORCE (platform::is_cpu_place (platform::CPUPlace ()),
415
415
" This kernel only runs on CPU." );
416
416
auto * label = ctx.Input <LoDTensor>(" Label" );
417
417
auto * emission_exps = ctx.Input <LoDTensor>(" EmissionExps" );
@@ -493,6 +493,7 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
493
493
}
494
494
beta_value[k * tag_num + i] = sum;
495
495
}
496
+ // NormalizeL1 rescales the k-th slice of beta_value in place to avoid underflow or overflow at (**).
496
497
NormalizeL1<T>(beta_value + k * tag_num, tag_num);
497
498
}
498
499
@@ -534,7 +535,7 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
534
535
T sum = 0 .;
535
536
for (size_t i = 0 ; i < tag_num; ++i) {
536
537
for (size_t j = 0 ; j < tag_num; ++j) {
537
- sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
538
+ sum += w_exps[(i + state_trans_base_idx) * tag_num + j] * // (**)
538
539
alpha_mat (k - 1 , i) * tmp_mat (k, j);
539
540
}
540
541
}
@@ -557,11 +558,11 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
557
558
} // namespace paddle
558
559
559
560
namespace ops = paddle::operators;
560
- REGISTER_OP (linear_chain_crf, ops::LinearChainCrfOp , ops::LinearChainCrfOpMaker ,
561
- linear_chain_crf_grad, ops::LinearChainCrfGradOp );
561
+ REGISTER_OP (linear_chain_crf, ops::LinearChainCRFOp , ops::LinearChainCRFOpMaker ,
562
+ linear_chain_crf_grad, ops::LinearChainCRFGradOp );
562
563
REGISTER_OP_CPU_KERNEL (
563
564
linear_chain_crf,
564
- ops::LinearChainCrfOpKernel <paddle::platform::CPUPlace, float >);
565
+ ops::LinearChainCRFOpKernel <paddle::platform::CPUPlace, float >);
565
566
REGISTER_OP_CPU_KERNEL (
566
567
linear_chain_crf_grad,
567
- ops::LinearChainCrfGradOpKernel <paddle::platform::CPUPlace, float >);
568
+ ops::LinearChainCRFGradOpKernel <paddle::platform::CPUPlace, float >);
0 commit comments