Skip to content

Commit 51c4585

Browse files
authored
Merge pull request #678 from pengli09/fix-crf-weight-and-coeff-bug
Fix bug in processing instance weight and coeff in CRFLayer
2 parents b07be67 + e80a3cf commit 51c4585

File tree

11 files changed

+262
-86
lines changed

11 files changed

+262
-86
lines changed

paddle/gserver/layers/CRFDecodingLayer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ bool CRFDecodingLayer::init(const LayerMap& layerMap,
2424
return false;
2525
}
2626
crf_.reset(new LinearChainCRF(
27-
numClasses_, parameter_->getBuf(PARAMETER_VALUE)->getData(), nullptr));
27+
numClasses_, parameter_->getBuf(PARAMETER_VALUE)->getData()));
2828
return true;
2929
}
3030

paddle/gserver/layers/CRFLayer.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ bool CRFLayer::init(const LayerMap& layerMap,
4242
CHECK_EQ(parameters_[0]->getSize(), numClasses_ * (numClasses_ + 2));
4343

4444
parameter_ = parameters_[0];
45+
weight_.reset(new Weight(numClasses_ + 2, numClasses_, parameter_));
4546

4647
// We don't need sequenceStartPositions because each sample of output_ is
4748
// for the cost of one sequence.
@@ -69,11 +70,7 @@ void CRFLayer::forward(PassType passType) {
6970

7071
for (size_t i = 0; i < numSequences; ++i) {
7172
if (i >= crfs_.size()) {
72-
crfs_.emplace_back(numClasses_,
73-
parameter_->getBuf(PARAMETER_VALUE)->getData(),
74-
parameter_->getBuf(PARAMETER_GRADIENT)
75-
? parameter_->getBuf(PARAMETER_GRADIENT)->getData()
76-
: nullptr);
73+
crfs_.emplace_back(numClasses_, weight_->getW()->getData());
7774
}
7875
output_.value->getData()[i] =
7976
crfs_[i].forward(output.value->getData() + numClasses_ * starts[i],
@@ -93,22 +90,25 @@ void CRFLayer::backward(const UpdateCallback& callback) {
9390
const int* starts = label.sequenceStartPositions->getData(false);
9491
int numSequences = label.sequenceStartPositions->getSize() - 1;
9592

93+
bool needWGrad = weight_->getWGrad() ? true : false;
9694
for (int i = 0; i < numSequences; ++i) {
9795
crfs_[i].backward(output.value->getData() + numClasses_ * starts[i],
98-
output.grad->getData() + numClasses_ * starts[i],
9996
label.ids->getData() + starts[i],
100-
starts[i + 1] - starts[i]);
101-
if (weightLayer_) {
102-
real weight = getInputValue(*weightLayer_)->getElement(i, 0);
103-
MatrixPtr grad = output.grad->subRowMatrix(starts[i], starts[i + 1]);
104-
grad->mulScalar(weight);
97+
starts[i + 1] - starts[i],
98+
needWGrad);
99+
real instanceWeight = weightLayer_
100+
? getInputValue(*weightLayer_)->getElement(i, 0)
101+
: real(1.0f);
102+
instanceWeight *= coeff_;
103+
104+
MatrixPtr grad = output.grad->subRowMatrix(starts[i], starts[i + 1]);
105+
grad->add(*crfs_[i].getXGrad(), real(1.0f), instanceWeight);
106+
if (needWGrad) {
107+
weight_->getWGrad()->add(
108+
*crfs_[i].getWGrad(), real(1.0f), instanceWeight);
105109
}
106110
}
107111

108-
if (coeff_ != real(1.0f)) {
109-
output.grad->mulScalar(coeff_);
110-
}
111-
112112
parameter_->incUpdate(callback);
113113
}
114114

paddle/gserver/layers/CRFLayer.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ class CRFLayer : public Layer {
3838
size_t numClasses_;
3939
ParameterPtr parameter_;
4040
std::vector<LinearChainCRF> crfs_;
41-
LayerPtr weightLayer_; // weight for each sequence
42-
real coeff_; // weight for the layer
41+
LayerPtr weightLayer_; // weight for each sequence
42+
std::unique_ptr<Weight> weight_; // parameters
43+
real coeff_; // weight for the layer
4344
};
4445

4546
} // namespace paddle

paddle/gserver/layers/LinearChainCRF.cpp

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,12 @@ limitations under the License. */
1717

1818
namespace paddle {
1919

20-
LinearChainCRF::LinearChainCRF(int numClasses, real* para, real* grad)
20+
LinearChainCRF::LinearChainCRF(int numClasses, real* para)
2121
: numClasses_(numClasses) {
2222
a_ = Matrix::create(para, 1, numClasses_);
2323
b_ = Matrix::create(para + numClasses_, 1, numClasses_);
2424
w_ = Matrix::create(para + 2 * numClasses_, numClasses_, numClasses_);
2525

26-
if (grad) {
27-
da_ = Matrix::create(grad, 1, numClasses_);
28-
db_ = Matrix::create(grad + numClasses_, 1, numClasses_);
29-
dw_ = Matrix::create(grad + 2 * numClasses_, numClasses_, numClasses_);
30-
}
31-
3226
ones_ = Matrix::create(1, numClasses_);
3327
ones_->one();
3428

@@ -107,19 +101,24 @@ real LinearChainCRF::forward(real* x, int* s, int length) {
107101
return -ll;
108102
}
109103

110-
void LinearChainCRF::backward(real* x, real* dx, int* s, int length) {
104+
void LinearChainCRF::backward(real* x, int* s, int length, bool needWGrad) {
111105
MatrixPtr matX = Matrix::create(x, length, numClasses_);
112-
MatrixPtr matDX = Matrix::create(dx, length, numClasses_);
113-
MatrixPtr matGrad = Matrix::create(length, numClasses_);
106+
Matrix::resizeOrCreate(matGrad_, length, numClasses_);
114107
Matrix::resizeOrCreate(beta_, length, numClasses_);
115108
real* b = b_->getData();
116-
real* dw = dw_ ? dw_->getData() : nullptr;
109+
if (needWGrad) {
110+
Matrix::resizeOrCreate(matWGrad_, numClasses_ + 2, numClasses_);
111+
matWGrad_->zeroMem();
112+
da_ = matWGrad_->subRowMatrix(0, 1);
113+
db_ = matWGrad_->subRowMatrix(1, 2);
114+
dw_ = matWGrad_->subRowMatrix(2, numClasses_ + 2);
115+
}
117116

118117
real* alpha = alpha_->getData();
119118
real* beta = beta_->getData();
120119
real* expW = expW_->getData();
121120
real* expX = expX_->getData();
122-
real* grad = matGrad->getData();
121+
real* grad = matGrad_->getData();
123122

124123
for (int i = 0; i < numClasses_; ++i) {
125124
beta[(length - 1) * numClasses_ + i] = exp(b[i]);
@@ -140,39 +139,38 @@ void LinearChainCRF::backward(real* x, real* dx, int* s, int length) {
140139
normalizeL1(beta + k * numClasses_, numClasses_);
141140
}
142141

143-
matGrad->dotMul(*alpha_, *beta_);
144-
matGrad->rowNormalizeL1(*matGrad);
142+
matGrad_->dotMul(*alpha_, *beta_);
143+
matGrad_->rowNormalizeL1(*matGrad_);
145144
for (int k = 0; k < length; ++k) {
146145
grad[k * numClasses_ + s[k]] -= (real)1;
147146
}
148-
matDX->add(*matGrad);
149-
if (da_) {
150-
da_->add(*matGrad->subMatrix(/* startRow= */ 0, /* numRows= */ 1));
151-
}
152-
if (db_) {
153-
db_->add(*matGrad->subMatrix(/* startRow= */ length - 1, 1));
154-
}
155147

156-
beta_->dotMul(*beta_, *expX_);
157-
beta_->rowNormalizeL1(*beta_);
148+
if (needWGrad) {
149+
da_->add(*matGrad_->subMatrix(/* startRow= */ 0, /* numRows= */ 1));
150+
db_->add(*matGrad_->subMatrix(/* startRow= */ length - 1, 1));
158151

159-
for (int k = 1; dw && k < length; ++k) {
160-
real sum = 0;
161-
for (int i = 0; i < numClasses_; ++i) {
162-
for (int j = 0; j < numClasses_; ++j) {
163-
sum += expW[i * numClasses_ + j] * alpha[(k - 1) * numClasses_ + i] *
164-
beta[k * numClasses_ + j];
152+
beta_->dotMul(*beta_, *expX_);
153+
beta_->rowNormalizeL1(*beta_);
154+
155+
real* dw = dw_->getData();
156+
for (int k = 1; k < length; ++k) {
157+
real sum = 0;
158+
for (int i = 0; i < numClasses_; ++i) {
159+
for (int j = 0; j < numClasses_; ++j) {
160+
sum += expW[i * numClasses_ + j] * alpha[(k - 1) * numClasses_ + i] *
161+
beta[k * numClasses_ + j];
162+
}
165163
}
166-
}
167-
sum = 1 / sum;
168-
for (int i = 0; i < numClasses_; ++i) {
169-
for (int j = 0; j < numClasses_; ++j) {
170-
dw[i * numClasses_ + j] += sum * expW[i * numClasses_ + j] *
171-
alpha[(k - 1) * numClasses_ + i] *
172-
beta[k * numClasses_ + j];
164+
sum = 1 / sum;
165+
for (int i = 0; i < numClasses_; ++i) {
166+
for (int j = 0; j < numClasses_; ++j) {
167+
dw[i * numClasses_ + j] += sum * expW[i * numClasses_ + j] *
168+
alpha[(k - 1) * numClasses_ + i] *
169+
beta[k * numClasses_ + j];
170+
}
173171
}
172+
dw[s[k - 1] * numClasses_ + s[k]] -= (real)1;
174173
}
175-
dw[s[k - 1] * numClasses_ + s[k]] -= (real)1;
176174
}
177175
}
178176

paddle/gserver/layers/LinearChainCRF.h

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ namespace paddle {
2121
class LinearChainCRF {
2222
public:
2323
/**
24-
* The size of para and grad must be \f$(numClasses + 2) * numClasses\f$.
24+
* The size of para must be \f$(numClasses + 2) * numClasses\f$.
2525
* The first numClasses values of para are for starting weights (\f$a\f$).
2626
* The next numClasses values of para are for ending weights (\f$b\f$),
2727
* The remaining values are for transition weights (\f$w\f$).
@@ -34,7 +34,7 @@ class LinearChainCRF {
3434
* all possible
3535
* sequences is \f$1\f$, and \f$x\f$ is the input feature to the CRF.
3636
*/
37-
LinearChainCRF(int numClasses, real* para, real* grad);
37+
LinearChainCRF(int numClasses, real* para);
3838

3939
/**
4040
* Calculate the negative log likelihood of s given x.
@@ -45,29 +45,45 @@ class LinearChainCRF {
4545

4646
/**
4747
* Calculate the gradient with respect to x, a, b, and w.
48-
* The gradient of x will be stored in dx.
4948
* backward() can only be called after a corresponding call to forward() with
5049
* the same x, s and length.
51-
* @note The gradient is added to dx and grad (provided at constructor).
50+
* The gradient with respect to a, b, and w will not be calculated if
51+
* needWGrad is false.
52+
* @note Please call getWGrad() and getXGrad() to get the gradient with
53+
* respect to (a, b, w) and x respectively.
5254
*/
53-
void backward(real* x, real* dx, int* s, int length);
55+
void backward(real* x, int* s, int length, bool needWGrad);
5456

5557
/**
5658
* Find the most probable sequence given x. The result will be stored in s.
5759
*/
5860
void decode(real* x, int* s, int length);
5961

62+
/*
63+
* Return the gradient with respect to (a, b, w). It can only be called after
64+
* a corresponding call to backward().
65+
*/
66+
MatrixPtr getWGrad() { return matWGrad_; }
67+
68+
/*
69+
* Return the gradient with respect to x. It can only be called after a
70+
* corresponding call to backward().
71+
*/
72+
MatrixPtr getXGrad() { return matGrad_; }
73+
6074
protected:
6175
int numClasses_;
6276
MatrixPtr a_;
6377
MatrixPtr b_;
6478
MatrixPtr w_;
79+
MatrixPtr matWGrad_;
6580
MatrixPtr da_;
6681
MatrixPtr db_;
6782
MatrixPtr dw_;
6883
MatrixPtr ones_;
6984

7085
MatrixPtr expX_;
86+
MatrixPtr matGrad_;
7187
MatrixPtr alpha_;
7288
MatrixPtr beta_;
7389
MatrixPtr maxX_;

paddle/gserver/tests/CMakeLists.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@ add_unittest_without_exec(test_LayerGrad
1818
add_test(NAME test_LayerGrad
1919
COMMAND test_LayerGrad)
2020

21+
################ test_CRFLayerGrad ####################
22+
add_unittest_without_exec(test_CRFLayerGrad
23+
test_CRFLayerGrad.cpp
24+
LayerGradUtil.cpp)
25+
add_test(NAME test_CRFLayerGrad
26+
COMMAND test_CRFLayerGrad)
27+
28+
2129
add_unittest_without_exec(test_ActivationGrad
2230
test_ActivationGrad.cpp
2331
LayerGradUtil.cpp)

0 commit comments

Comments
 (0)