@@ -165,11 +165,11 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
165
165
" Output(LogLikelihood) should be not null." );
166
166
167
167
auto emission_dims = ctx->GetInputDim (" Emission" );
168
- auto transition_dims = ctx->GetInputDim (" Transition" );
169
- auto label_dims = ctx->GetInputDim (" Label" );
170
-
171
168
PADDLE_ENFORCE_EQ (emission_dims.size (), 2UL ,
172
169
" The Input(Emission) should be a 2-D tensor." );
170
+ PADDLE_ENFORCE (emission_dims[0 ], " An empty mini-batch is not allowed." );
171
+
172
+ auto transition_dims = ctx->GetInputDim (" Transition" );
173
173
PADDLE_ENFORCE_EQ (transition_dims.size (), 2UL ,
174
174
" The Input(Transition) should be a 2-D tensor." );
175
175
PADDLE_ENFORCE_EQ (
@@ -180,6 +180,8 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
180
180
emission_dims[1 ], transition_dims[1 ],
181
181
" The 2nd dimension of the Input(Emission) and the Input(Transition) "
182
182
" should be equal to the tag number." );
183
+
184
+ auto label_dims = ctx->GetInputDim (" Label" );
183
185
PADDLE_ENFORCE (label_dims.size () == 2UL && label_dims[1 ] == 1UL ,
184
186
" The Input(Label) should be a 2-D tensor with the 2nd "
185
187
" dimensions fixed to 1." );
@@ -204,7 +206,7 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
204
206
// operator is determined by its input "Emission".
205
207
framework::DataType IndicateDataType (
206
208
const framework::ExecutionContext& ctx) const override {
207
- return framework::ToDataType (ctx.Input <Tensor >(" Emission" )->type ());
209
+ return framework::ToDataType (ctx.Input <LoDTensor >(" Emission" )->type ());
208
210
}
209
211
};
210
212
@@ -224,6 +226,8 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
224
226
auto * label = ctx.Input <LoDTensor>(" Label" );
225
227
226
228
auto in_lod = emission_weights->lod ();
229
+ PADDLE_ENFORCE (in_lod.size (), " Input(Emission) is not a sequence." );
230
+
227
231
// TODO(caoying) The checks related to LoD information should be
228
232
// moved into InferShape once after the InferShape is refactored.
229
233
PADDLE_ENFORCE_EQ (emission_weights->NumLevels (), 1UL ,
@@ -266,12 +270,17 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
266
270
for (size_t i = 0 ; i < seq_num; ++i) {
267
271
int start_pos = static_cast <int >(in_lod[level][i]);
268
272
int end_pos = static_cast <int >(in_lod[level][i + 1 ]);
273
+ if (end_pos == start_pos) {
274
+ // If an empty input sequence is given, pad 0 for its cost.
275
+ log_likelihood[i] = static_cast <T>(0 .);
276
+ continue ;
277
+ }
269
278
270
- const Tensor one_seq = emission_weights->Slice <T> (start_pos, end_pos);
271
- Tensor one_seq_row_max = emission_row_max.Slice <T> (start_pos, end_pos);
272
- Tensor one_seq_exps = emission_exps->Slice <T> (start_pos, end_pos);
273
- const Tensor one_seq_label = label->Slice <T> (start_pos, end_pos);
274
- Tensor one_seq_alpha = alpha->Slice <T> (start_pos, end_pos);
279
+ const Tensor one_seq = emission_weights->Slice (start_pos, end_pos);
280
+ Tensor one_seq_row_max = emission_row_max.Slice (start_pos, end_pos);
281
+ Tensor one_seq_exps = emission_exps->Slice (start_pos, end_pos);
282
+ const Tensor one_seq_label = label->Slice (start_pos, end_pos);
283
+ Tensor one_seq_alpha = alpha->Slice (start_pos, end_pos);
275
284
276
285
log_likelihood[i] = ForwardOneSequence (
277
286
&one_seq, &one_seq_row_max, &one_seq_exps, transition_weights,
@@ -306,7 +315,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
306
315
307
316
for (size_t k = 1 ; k < seq_length; ++k) {
308
317
for (size_t i = 0 ; i < tag_num; ++i) {
309
- T sum = 0 . ;
318
+ T sum = static_cast <T>( 0 .) ;
310
319
for (size_t j = 0 ; j < tag_num; ++j) {
311
320
sum += alpha_value[(k - 1 ) * tag_num + j] *
312
321
w_exps[(j + state_trans_base_idx) * tag_num + i];
@@ -326,11 +335,14 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
326
335
PADDLE_ENFORCE_LT (
327
336
*std::max_element (lbl, lbl + seq_length), tag_num,
328
337
" An invalid tag label that execesses the largest tag number." );
338
+
329
339
// Calculate the nominator part, which depends on the label sequence.
330
340
ll += w[lbl[0 ]] /* start transition*/ + x[lbl[0 ]] +
331
341
w[tag_num + lbl[seq_length - 1 ]] /* end transition*/ ;
332
- for (size_t k = 1 ; k < seq_length; ++k)
333
- ll += x[k * tag_num + lbl[k]] + w[lbl[k - 1 ] * tag_num + lbl[k]];
342
+ for (size_t k = 1 ; k < seq_length; ++k) {
343
+ ll += x[k * tag_num + lbl[k]] +
344
+ w[(lbl[k - 1 ] + state_trans_base_idx) * tag_num + lbl[k]];
345
+ }
334
346
return -ll;
335
347
}
336
348
};
@@ -353,12 +365,13 @@ class LinearChainCrfGradOp : public framework::OperatorWithKernel {
353
365
" Output(Transition@GRAD) should be not null." );
354
366
355
367
auto emission_exps_dims = ctx->GetInputDim (" EmissionExps" );
356
- auto transition_exps_dims =
357
- ctx->GetInputDim (framework::GradVarName (" TransitionExps" ));
358
- auto label_dims = ctx->GetInputDim (" Label" );
359
-
360
368
PADDLE_ENFORCE_EQ (emission_exps_dims.size (), 2UL ,
361
369
" The Input(EmissionExps) should be a 2-D tensor." );
370
+ PADDLE_ENFORCE (emission_exps_dims[0 ],
371
+ " An empty mini-batch is not allowed." );
372
+
373
+ auto transition_exps_dims =
374
+ ctx->GetInputDim (framework::GradVarName (" TransitionExps" ));
362
375
PADDLE_ENFORCE_EQ (transition_exps_dims.size (), 2UL ,
363
376
" The Input(TransitionExps) should be a 2-D tensor." );
364
377
PADDLE_ENFORCE_EQ (
@@ -369,6 +382,8 @@ class LinearChainCrfGradOp : public framework::OperatorWithKernel {
369
382
emission_exps_dims[1 ], transition_exps_dims[1 ],
370
383
" The 2nd dimension of the Input(EmissionExps) and the "
371
384
" Input(TransitionExps) should be equal to the tag number." );
385
+
386
+ auto label_dims = ctx->GetInputDim (" Label" );
372
387
PADDLE_ENFORCE (label_dims.size () == 2UL && label_dims[1 ] == 1UL ,
373
388
" The Input(Label) should be a 2-D tensor with the 2nd "
374
389
" dimensions fixed to 1." );
@@ -381,6 +396,14 @@ class LinearChainCrfGradOp : public framework::OperatorWithKernel {
381
396
ctx->SetOutputDim (framework::GradVarName (" Transition" ),
382
397
transition_exps_dims);
383
398
}
399
+
400
+ protected:
401
+ // Explicitly set that the data type of output of the linear_chain_crf_grad
402
+ // operator is determined by its input "EmissionExps".
403
+ framework::DataType IndicateDataType (
404
+ const framework::ExecutionContext& ctx) const override {
405
+ return framework::ToDataType (ctx.Input <LoDTensor>(" EmissionExps" )->type ());
406
+ }
384
407
};
385
408
386
409
template <typename T>
@@ -390,12 +413,12 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
390
413
void Compute (const framework::ExecutionContext& ctx) const override {
391
414
PADDLE_ENFORCE (platform::is_cpu_place (ctx.GetPlace ()),
392
415
" This kernel only runs on CPU." );
393
- auto * ll_grad =
394
- ctx.Input <LoDTensor>(framework::GradVarName (" LogLikelihood" ));
395
416
auto * label = ctx.Input <LoDTensor>(" Label" );
396
417
auto * emission_exps = ctx.Input <LoDTensor>(" EmissionExps" );
397
418
auto * transition_exps = ctx.Input <Tensor>(" TransitionExps" );
398
- auto * alpha = ctx.Input <Tensor>(" Alpha" );
419
+ auto * alpha = ctx.Input <LoDTensor>(" Alpha" );
420
+ const T* ll_grad =
421
+ ctx.Input <Tensor>(framework::GradVarName (" LogLikelihood" ))->data <T>();
399
422
400
423
auto * emission_grad =
401
424
ctx.Output <Tensor>(framework::GradVarName (" Emission" ));
@@ -413,34 +436,31 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
413
436
Tensor beta;
414
437
beta.mutable_data <T>(emission_dims, platform::CPUPlace ());
415
438
416
- auto place = ctx.GetEigenDevice <platform::CPUPlace>();
417
- auto x_grad = EigenMatrix<T>::From (*emission_grad);
418
- auto out_grad = EigenMatrix<T>::From (*ll_grad);
419
- x_grad.device (place) =
420
- x_grad * out_grad.broadcast (Eigen::DSizes<int , 2 >(1 , emission_dims[1 ]));
421
-
422
439
const size_t level = 0 ; // currently, only support sequence.
423
- auto lod = emission_exps->lod ();
440
+ auto lod = label->lod ();
441
+ PADDLE_ENFORCE (lod.size (), " Input(Label) is not a sequence." );
442
+
424
443
for (size_t i = 0 ; i < lod[level].size () - 1 ; ++i) {
425
444
int start_pos = static_cast <int >(lod[level][i]);
426
445
int end_pos = static_cast <int >(lod[level][i + 1 ]);
446
+ if (end_pos == start_pos) continue ;
427
447
428
448
const Tensor one_seq_emission_exps =
429
- emission_exps->Slice <T> (start_pos, end_pos);
430
- const Tensor one_seq_label = label->Slice <T> (start_pos, end_pos);
431
- const Tensor one_seq_alpha = alpha->Slice <T> (start_pos, end_pos);
432
- Tensor one_seq_beta = beta.Slice <T> (start_pos, end_pos);
433
- Tensor one_seq_emission_grad =
434
- emission_grad-> Slice <T>(start_pos, end_pos);
435
-
436
- BackwardOneSequence (ctx. device_context (), &one_seq_emission_exps,
437
- transition_exps, &one_seq_alpha, &one_seq_label,
438
- &one_seq_beta, trans_grad, &one_seq_emission_grad);
449
+ emission_exps->Slice (start_pos, end_pos);
450
+ const Tensor one_seq_label = label->Slice (start_pos, end_pos);
451
+ const Tensor one_seq_alpha = alpha->Slice (start_pos, end_pos);
452
+ Tensor one_seq_beta = beta.Slice (start_pos, end_pos);
453
+ Tensor one_seq_emission_grad = emission_grad-> Slice (start_pos, end_pos);
454
+
455
+ BackwardOneSequence (ctx. device_context (), ll_grad[i],
456
+ &one_seq_emission_exps, transition_exps ,
457
+ &one_seq_alpha, &one_seq_label, &one_seq_beta ,
458
+ trans_grad, &one_seq_emission_grad);
439
459
}
440
460
}
441
461
442
462
protected:
443
- void BackwardOneSequence (const platform::DeviceContext& ctx,
463
+ void BackwardOneSequence (const platform::DeviceContext& ctx, const T ll_grad,
444
464
const Tensor* emission_exps,
445
465
const Tensor* transition_exps, const Tensor* alpha,
446
466
const Tensor* label, Tensor* beta,
@@ -457,12 +477,15 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
457
477
const size_t state_trans_base_idx = 2 ;
458
478
459
479
// Calculate the backwark vectors beta.
460
- for (int i = 0 ; i < tag_num; ++i)
480
+ // First, calculate the initialition state.
481
+ for (int i = 0 ; i < tag_num; ++i) {
461
482
beta_value[(seq_length - 1 ) * tag_num + i] = w_exps[tag_num + i];
483
+ }
462
484
NormalizeL1<T>(beta_value + (seq_length - 1 ) * tag_num, tag_num);
485
+
463
486
for (int k = seq_length - 2 ; k >= 0 ; --k) {
464
487
for (int i = 0 ; i < tag_num; ++i) {
465
- T sum = 0 . ;
488
+ T sum = static_cast <T>( 0 .) ;
466
489
for (int j = 0 ; j < tag_num; ++j) {
467
490
sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
468
491
x_exps[(k + 1 ) * tag_num + j] *
@@ -476,15 +499,17 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
476
499
auto alpha_mat = EigenMatrix<T>::From (*alpha);
477
500
auto beta_mat = EigenMatrix<T>::From (*beta);
478
501
auto x_grad_mat = EigenMatrix<T>::From (*emission_grad);
502
+ x_grad_mat.setConstant (ll_grad);
479
503
480
504
auto * place = ctx.GetEigenDevice <platform::CPUPlace>();
481
505
x_grad_mat.device (*place) = alpha_mat * beta_mat;
482
506
x_grad_mat /= x_grad_mat.sum (Eigen::DSizes<int , 1 >(1 ))
483
507
.reshape (Eigen::DSizes<int , 2 >(seq_length, 1 ))
484
508
.broadcast (Eigen::DSizes<int , 2 >(1 , tag_num));
485
509
486
- for (int k = 0 ; k < seq_length; ++k)
510
+ for (int k = 0 ; k < seq_length; ++k) {
487
511
x_grad_mat (k, label_value[k]) -= static_cast <T>(1 );
512
+ }
488
513
489
514
if (transition_grad) {
490
515
T* trans_grad = transition_grad->data <T>();
@@ -501,20 +526,23 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
501
526
.broadcast (Eigen::DSizes<int , 2 >(1 , tag_num));
502
527
503
528
for (int k = 1 ; k < seq_length; ++k) {
504
- T sum = 0 . ;
529
+ T sum = static_cast <T>( 0 .) ;
505
530
for (int i = 0 ; i < tag_num; ++i) {
506
- for (int j = 0 ; j < tag_num; ++j)
507
- sum += x_exps_mat (i, j) * alpha_mat (k - 1 , i) * beta_mat (k, j);
531
+ for (int j = 0 ; j < tag_num; ++j) {
532
+ sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
533
+ alpha_mat (k - 1 , i) * beta_mat (k, j);
534
+ }
508
535
}
509
- sum = static_cast <T>(1 ) / sum;
536
+ sum = static_cast <T>(1 . ) / sum;
510
537
for (int i = 0 ; i < tag_num; ++i) {
511
538
for (int j = 0 ; j < tag_num; ++j) {
512
- trans_grad[(i + 2 ) * tag_num + j] +=
513
- sum * x_exps_mat (i, j) * alpha_mat (k - 1 , i) * beta_mat (k, j);
539
+ trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
540
+ sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
541
+ alpha_mat (k - 1 , i) * beta_mat (k, j);
514
542
}
515
543
}
516
544
trans_grad[label_value[k - 1 ] * tag_num + label_value[k]] -=
517
- static_cast <T>(1 );
545
+ static_cast <T>(1 . );
518
546
}
519
547
}
520
548
}
0 commit comments