Skip to content

Commit c74107b

Browse files
committed
fix backward computation.
1 parent 6a630f2 commit c74107b

File tree

4 files changed

+54
-44
lines changed

4 files changed

+54
-44
lines changed

paddle/gserver/layers/CRFLayer.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,10 @@ void CRFLayer::backward(const UpdateCallback& callback) {
101101
: real(1.0f);
102102
instanceWeight *= coeff_;
103103

104-
MatrixPtr grad = output.grad->subRowMatrix(starts[i], starts[i + 1]);
105-
grad->add(*crfs_[i].getXGrad(), real(1.0f), instanceWeight);
104+
if (output.grad) {
105+
MatrixPtr grad = output.grad->subRowMatrix(starts[i], starts[i + 1]);
106+
grad->add(*crfs_[i].getXGrad(), real(1.0f), instanceWeight);
107+
}
106108
if (needWGrad) {
107109
weight_->getWGrad()->add(
108110
*crfs_[i].getWGrad(), real(1.0f), instanceWeight);

paddle/gserver/layers/LinearChainCRF.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ real LinearChainCRF::forward(real* x, int* s, int length) {
102102
}
103103

104104
void LinearChainCRF::backward(real* x, int* s, int length, bool needWGrad) {
105-
MatrixPtr matX = Matrix::create(x, length, numClasses_);
106105
Matrix::resizeOrCreate(matGrad_, length, numClasses_);
107106
Matrix::resizeOrCreate(beta_, length, numClasses_);
108107
real* b = b_->getData();

paddle/operators/linear_chain_crf_op.cc

Lines changed: 41 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
272272
int end_pos = static_cast<int>(in_lod[level][i + 1]);
273273
if (end_pos == start_pos) {
274274
// If an empty input sequence is given, pad 0 for its cost.
275-
log_likelihood[i] = static_cast<T>(0.);
275+
log_likelihood[i] = 0.;
276276
continue;
277277
}
278278

@@ -305,7 +305,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
305305
const size_t tag_num = x_dims[1];
306306
// The 1st row of w are transition weights for start mask.
307307
// The 2nd row of w are transition weights for end mask.
308-
// Transition weights among other tags begins from the 3rd row of w.
308+
// Transition weights among other tags begin from the 3rd row of w.
309309
const size_t state_trans_base_idx = 2;
310310

311311
for (size_t i = 0; i < tag_num; ++i) {
@@ -315,7 +315,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
315315

316316
for (size_t k = 1; k < seq_length; ++k) {
317317
for (size_t i = 0; i < tag_num; ++i) {
318-
T sum = static_cast<T>(0.);
318+
T sum = 0.;
319319
for (size_t j = 0; j < tag_num; ++j) {
320320
sum += alpha_value[(k - 1) * tag_num + j] *
321321
w_exps[(j + state_trans_base_idx) * tag_num + i];
@@ -476,17 +476,17 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
476476
const size_t tag_num = x_dims[1];
477477
const size_t state_trans_base_idx = 2;
478478

479-
// Calculate the backwark vectors beta.
479+
// Calculate the backward vectors: beta.
480480
// First, calculate the initialization state.
481-
for (int i = 0; i < tag_num; ++i) {
481+
for (size_t i = 0; i < tag_num; ++i) {
482482
beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i];
483483
}
484484
NormalizeL1<T>(beta_value + (seq_length - 1) * tag_num, tag_num);
485485

486-
for (int k = seq_length - 2; k >= 0; --k) {
487-
for (int i = 0; i < tag_num; ++i) {
488-
T sum = static_cast<T>(0.);
489-
for (int j = 0; j < tag_num; ++j) {
486+
for (int k = static_cast<int>(seq_length) - 2; k >= 0; --k) {
487+
for (size_t i = 0; i < tag_num; ++i) {
488+
T sum = 0.;
489+
for (size_t j = 0; j < tag_num; ++j) {
490490
sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
491491
x_exps[(k + 1) * tag_num + j] *
492492
beta_value[(k + 1) * tag_num + j];
@@ -500,13 +500,14 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
500500
auto beta_mat = EigenMatrix<T>::From(*beta);
501501
auto x_grad_mat = EigenMatrix<T>::From(*emission_grad);
502502
auto* place = ctx.GetEigenDevice<platform::CPUPlace>();
503-
x_grad_mat.device(*place) = alpha_mat * beta_mat;
504-
x_grad_mat /= x_grad_mat.sum(Eigen::DSizes<int, 1>(1))
505-
.reshape(Eigen::DSizes<int, 2>(seq_length, 1))
506-
.broadcast(Eigen::DSizes<int, 2>(1, tag_num));
507-
508-
for (int k = 0; k < seq_length; ++k) {
509-
x_grad_mat(k, label_value[k]) -= static_cast<T>(1);
503+
auto prob = alpha_mat * beta_mat;
504+
auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
505+
.reshape(Eigen::DSizes<int, 2>(seq_length, 1))
506+
.broadcast(Eigen::DSizes<int, 2>(1, tag_num));
507+
x_grad_mat.device(*place) = prob / row_sum;
508+
509+
for (size_t k = 0; k < seq_length; ++k) {
510+
x_grad_mat(k, label_value[k]) -= static_cast<T>(1.);
510511
}
511512

512513
if (transition_grad) {
@@ -518,29 +519,35 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
518519
}
519520

520521
auto x_exps_mat = EigenMatrix<T>::From(*emission_exps);
521-
beta_mat = beta_mat * x_exps_mat;
522-
beta_mat /= beta_mat.sum(Eigen::DSizes<int, 1>(1))
523-
.reshape(Eigen::DSizes<int, 2>(seq_length, 1))
524-
.broadcast(Eigen::DSizes<int, 2>(1, tag_num));
525-
526-
for (int k = 1; k < seq_length; ++k) {
527-
T sum = static_cast<T>(0.);
528-
for (int i = 0; i < tag_num; ++i) {
529-
for (int j = 0; j < tag_num; ++j) {
522+
523+
// TODO(caoying): Fix this to avoid using this local variable.
524+
Tensor tmp;
525+
tmp.mutable_data<T>(beta->dims(), platform::CPUPlace());
526+
auto tmp_mat = EigenMatrix<T>::From(tmp);
527+
auto prob = beta_mat * x_exps_mat;
528+
auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
529+
.reshape(Eigen::DSizes<int, 2>(seq_length, 1))
530+
.broadcast(Eigen::DSizes<int, 2>(1, tag_num));
531+
tmp_mat.device(*place) = prob / row_sum;
532+
533+
for (size_t k = 1; k < seq_length; ++k) {
534+
T sum = 0.;
535+
for (size_t i = 0; i < tag_num; ++i) {
536+
for (size_t j = 0; j < tag_num; ++j) {
530537
sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
531-
alpha_mat(k - 1, i) * beta_mat(k, j);
538+
alpha_mat(k - 1, i) * tmp_mat(k, j);
532539
}
533540
}
534-
sum = static_cast<T>(1.) / sum;
535-
for (int i = 0; i < tag_num; ++i) {
536-
for (int j = 0; j < tag_num; ++j) {
541+
sum = 1. / sum;
542+
for (size_t i = 0; i < tag_num; ++i) {
543+
for (size_t j = 0; j < tag_num; ++j) {
537544
trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
538545
sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
539-
alpha_mat(k - 1, i) * beta_mat(k, j);
546+
alpha_mat(k - 1, i) * tmp_mat(k, j);
540547
}
541548
}
542-
trans_grad[label_value[k - 1] * tag_num + label_value[k]] -=
543-
static_cast<T>(1.);
549+
trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num +
550+
label_value[k]] -= static_cast<T>(1.);
544551
}
545552
}
546553
}
@@ -554,9 +561,7 @@ REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker,
554561
linear_chain_crf_grad, ops::LinearChainCrfGradOp);
555562
REGISTER_OP_CPU_KERNEL(
556563
linear_chain_crf,
557-
ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float>,
558-
ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, double>);
564+
ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float>);
559565
REGISTER_OP_CPU_KERNEL(
560566
linear_chain_crf_grad,
561-
ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float>,
562-
ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, double>);
567+
ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float>);

python/paddle/v2/framework/tests/test_linear_chain_crf_op.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,10 @@ def crf_forward_compute(self):
8383

8484
class TestLinearChainCrfOp(OpTest):
8585
def set_test_data(self):
86-
SEQ_NUM = 2
86+
SEQ_NUM = 3
8787
TAG_NUM = 17
8888
MAX_SEQ_LEN = 5
8989

90-
random.seed(1)
9190
# the linear_chain_crf operator only supports sequence (LoD level = 1)
9291
lod = [[0]]
9392
for i in range(SEQ_NUM):
@@ -109,7 +108,6 @@ def set_test_data(self):
109108
"Transition": transition,
110109
"Label": (labels, lod)
111110
}
112-
113111
crf = LinearChainCrfForward(lod[0], emission, emission_row_max,
114112
emission_exps, transition, transition_exps,
115113
labels)
@@ -130,11 +128,17 @@ def test_check_output(self):
130128
self.check_output()
131129

132130
def test_check_grad(self):
133-
self.check_grad(["Emission", "Transition"], "LogLikelihood")
131+
self.check_grad(
132+
["Emission", "Transition"],
133+
"LogLikelihood",
134+
max_relative_error=0.05)
134135

135136
def test_check_grad_ignore_transition(self):
136137
self.check_grad(
137-
["Emission"], "LogLikelihood", no_grad_set=set("Transition"))
138+
["Emission"],
139+
"LogLikelihood",
140+
max_relative_error=0.05,
141+
no_grad_set=set("Transition"))
138142

139143

140144
if __name__ == "__main__":

0 commit comments

Comments
 (0)