@@ -272,7 +272,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
272
272
int end_pos = static_cast <int >(in_lod[level][i + 1 ]);
273
273
if (end_pos == start_pos) {
274
274
// If an empty input sequence is given, pad 0 for its cost.
275
- log_likelihood[i] = static_cast <T>( 0 .) ;
275
+ log_likelihood[i] = 0 . ;
276
276
continue ;
277
277
}
278
278
@@ -305,7 +305,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
305
305
const size_t tag_num = x_dims[1 ];
306
306
// The 1st row of w are transition weights for start mask.
307
307
// The 2nd row of w are transition weights for end mask.
308
- // Transition weights among other tags begins from the 3rd row of w.
308
+ // Transition weights among other tags begin from the 3rd row of w.
309
309
const size_t state_trans_base_idx = 2 ;
310
310
311
311
for (size_t i = 0 ; i < tag_num; ++i) {
@@ -315,7 +315,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
315
315
316
316
for (size_t k = 1 ; k < seq_length; ++k) {
317
317
for (size_t i = 0 ; i < tag_num; ++i) {
318
- T sum = static_cast <T>( 0 .) ;
318
+ T sum = 0 . ;
319
319
for (size_t j = 0 ; j < tag_num; ++j) {
320
320
sum += alpha_value[(k - 1 ) * tag_num + j] *
321
321
w_exps[(j + state_trans_base_idx) * tag_num + i];
@@ -476,17 +476,17 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
476
476
const size_t tag_num = x_dims[1 ];
477
477
const size_t state_trans_base_idx = 2 ;
478
478
479
- // Calculate the backwark vectors beta.
479
+ // Calculate the backward vectors: beta.
480
480
// First, calculate the initialition state.
481
- for (int i = 0 ; i < tag_num; ++i) {
481
+ for (size_t i = 0 ; i < tag_num; ++i) {
482
482
beta_value[(seq_length - 1 ) * tag_num + i] = w_exps[tag_num + i];
483
483
}
484
484
NormalizeL1<T>(beta_value + (seq_length - 1 ) * tag_num, tag_num);
485
485
486
- for (int k = seq_length - 2 ; k >= 0 ; --k) {
487
- for (int i = 0 ; i < tag_num; ++i) {
488
- T sum = static_cast <T>( 0 .) ;
489
- for (int j = 0 ; j < tag_num; ++j) {
486
+ for (int k = static_cast < int >( seq_length) - 2 ; k >= 0 ; --k) {
487
+ for (size_t i = 0 ; i < tag_num; ++i) {
488
+ T sum = 0 . ;
489
+ for (size_t j = 0 ; j < tag_num; ++j) {
490
490
sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
491
491
x_exps[(k + 1 ) * tag_num + j] *
492
492
beta_value[(k + 1 ) * tag_num + j];
@@ -500,13 +500,14 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
500
500
auto beta_mat = EigenMatrix<T>::From (*beta);
501
501
auto x_grad_mat = EigenMatrix<T>::From (*emission_grad);
502
502
auto * place = ctx.GetEigenDevice <platform::CPUPlace>();
503
- x_grad_mat.device (*place) = alpha_mat * beta_mat;
504
- x_grad_mat /= x_grad_mat.sum (Eigen::DSizes<int , 1 >(1 ))
505
- .reshape (Eigen::DSizes<int , 2 >(seq_length, 1 ))
506
- .broadcast (Eigen::DSizes<int , 2 >(1 , tag_num));
507
-
508
- for (int k = 0 ; k < seq_length; ++k) {
509
- x_grad_mat (k, label_value[k]) -= static_cast <T>(1 );
503
+ auto prob = alpha_mat * beta_mat;
504
+ auto row_sum = prob.sum (Eigen::DSizes<int , 1 >(1 ))
505
+ .reshape (Eigen::DSizes<int , 2 >(seq_length, 1 ))
506
+ .broadcast (Eigen::DSizes<int , 2 >(1 , tag_num));
507
+ x_grad_mat.device (*place) = prob / row_sum;
508
+
509
+ for (size_t k = 0 ; k < seq_length; ++k) {
510
+ x_grad_mat (k, label_value[k]) -= static_cast <T>(1 .);
510
511
}
511
512
512
513
if (transition_grad) {
@@ -518,29 +519,35 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
518
519
}
519
520
520
521
auto x_exps_mat = EigenMatrix<T>::From (*emission_exps);
521
- beta_mat = beta_mat * x_exps_mat;
522
- beta_mat /= beta_mat.sum (Eigen::DSizes<int , 1 >(1 ))
523
- .reshape (Eigen::DSizes<int , 2 >(seq_length, 1 ))
524
- .broadcast (Eigen::DSizes<int , 2 >(1 , tag_num));
525
-
526
- for (int k = 1 ; k < seq_length; ++k) {
527
- T sum = static_cast <T>(0 .);
528
- for (int i = 0 ; i < tag_num; ++i) {
529
- for (int j = 0 ; j < tag_num; ++j) {
522
+
523
+ // TODO(caoying): Fix this to avoid using this local variable.
524
+ Tensor tmp;
525
+ tmp.mutable_data <T>(beta->dims (), platform::CPUPlace ());
526
+ auto tmp_mat = EigenMatrix<T>::From (tmp);
527
+ auto prob = beta_mat * x_exps_mat;
528
+ auto row_sum = prob.sum (Eigen::DSizes<int , 1 >(1 ))
529
+ .reshape (Eigen::DSizes<int , 2 >(seq_length, 1 ))
530
+ .broadcast (Eigen::DSizes<int , 2 >(1 , tag_num));
531
+ tmp_mat.device (*place) = prob / row_sum;
532
+
533
+ for (size_t k = 1 ; k < seq_length; ++k) {
534
+ T sum = 0 .;
535
+ for (size_t i = 0 ; i < tag_num; ++i) {
536
+ for (size_t j = 0 ; j < tag_num; ++j) {
530
537
sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
531
- alpha_mat (k - 1 , i) * beta_mat (k, j);
538
+ alpha_mat (k - 1 , i) * tmp_mat (k, j);
532
539
}
533
540
}
534
- sum = static_cast <T>( 1 .) / sum;
535
- for (int i = 0 ; i < tag_num; ++i) {
536
- for (int j = 0 ; j < tag_num; ++j) {
541
+ sum = 1 . / sum;
542
+ for (size_t i = 0 ; i < tag_num; ++i) {
543
+ for (size_t j = 0 ; j < tag_num; ++j) {
537
544
trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
538
545
sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
539
- alpha_mat (k - 1 , i) * beta_mat (k, j);
546
+ alpha_mat (k - 1 , i) * tmp_mat (k, j);
540
547
}
541
548
}
542
- trans_grad[label_value[k - 1 ] * tag_num + label_value[k]] -=
543
- static_cast <T>(1 .);
549
+ trans_grad[( label_value[k - 1 ] + state_trans_base_idx) * tag_num +
550
+ label_value[k]] -= static_cast <T>(1 .);
544
551
}
545
552
}
546
553
}
@@ -554,9 +561,7 @@ REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker,
554
561
linear_chain_crf_grad, ops::LinearChainCrfGradOp);
555
562
REGISTER_OP_CPU_KERNEL (
556
563
linear_chain_crf,
557
- ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float >,
558
- ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, double >);
564
+ ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float >);
559
565
REGISTER_OP_CPU_KERNEL (
560
566
linear_chain_crf_grad,
561
- ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float >,
562
- ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, double >);
567
+ ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float >);
0 commit comments