@@ -17,18 +17,12 @@ limitations under the License. */
17
17
18
18
namespace paddle {
19
19
20
- LinearChainCRF::LinearChainCRF (int numClasses, real* para, real* grad )
20
+ LinearChainCRF::LinearChainCRF (int numClasses, real* para)
21
21
: numClasses_(numClasses) {
22
22
a_ = Matrix::create (para, 1 , numClasses_);
23
23
b_ = Matrix::create (para + numClasses_, 1 , numClasses_);
24
24
w_ = Matrix::create (para + 2 * numClasses_, numClasses_, numClasses_);
25
25
26
- if (grad) {
27
- da_ = Matrix::create (grad, 1 , numClasses_);
28
- db_ = Matrix::create (grad + numClasses_, 1 , numClasses_);
29
- dw_ = Matrix::create (grad + 2 * numClasses_, numClasses_, numClasses_);
30
- }
31
-
32
26
ones_ = Matrix::create (1 , numClasses_);
33
27
ones_->one ();
34
28
@@ -107,19 +101,24 @@ real LinearChainCRF::forward(real* x, int* s, int length) {
107
101
return -ll;
108
102
}
109
103
110
- void LinearChainCRF::backward (real* x, real* dx, int * s, int length) {
104
+ void LinearChainCRF::backward (real* x, int * s, int length, bool needWGrad ) {
111
105
MatrixPtr matX = Matrix::create (x, length, numClasses_);
112
- MatrixPtr matDX = Matrix::create (dx, length, numClasses_);
113
- MatrixPtr matGrad = Matrix::create (length, numClasses_);
106
+ Matrix::resizeOrCreate (matGrad_, length, numClasses_);
114
107
Matrix::resizeOrCreate (beta_, length, numClasses_);
115
108
real* b = b_->getData ();
116
- real* dw = dw_ ? dw_->getData () : nullptr ;
109
+ if (needWGrad) {
110
+ Matrix::resizeOrCreate (matWGrad_, numClasses_ + 2 , numClasses_);
111
+ matWGrad_->zeroMem ();
112
+ da_ = matWGrad_->subRowMatrix (0 , 1 );
113
+ db_ = matWGrad_->subRowMatrix (1 , 2 );
114
+ dw_ = matWGrad_->subRowMatrix (2 , numClasses_ + 2 );
115
+ }
117
116
118
117
real* alpha = alpha_->getData ();
119
118
real* beta = beta_->getData ();
120
119
real* expW = expW_->getData ();
121
120
real* expX = expX_->getData ();
122
- real* grad = matGrad ->getData ();
121
+ real* grad = matGrad_ ->getData ();
123
122
124
123
for (int i = 0 ; i < numClasses_; ++i) {
125
124
beta[(length - 1 ) * numClasses_ + i] = exp (b[i]);
@@ -140,39 +139,38 @@ void LinearChainCRF::backward(real* x, real* dx, int* s, int length) {
140
139
normalizeL1 (beta + k * numClasses_, numClasses_);
141
140
}
142
141
143
- matGrad ->dotMul (*alpha_, *beta_);
144
- matGrad ->rowNormalizeL1 (*matGrad );
142
+ matGrad_ ->dotMul (*alpha_, *beta_);
143
+ matGrad_ ->rowNormalizeL1 (*matGrad_ );
145
144
for (int k = 0 ; k < length; ++k) {
146
145
grad[k * numClasses_ + s[k]] -= (real)1 ;
147
146
}
148
- matDX->add (*matGrad);
149
- if (da_) {
150
- da_->add (*matGrad->subMatrix (/* startRow= */ 0 , /* numRows= */ 1 ));
151
- }
152
- if (db_) {
153
- db_->add (*matGrad->subMatrix (/* startRow= */ length - 1 , 1 ));
154
- }
155
147
156
- beta_->dotMul (*beta_, *expX_);
157
- beta_->rowNormalizeL1 (*beta_);
148
+ if (needWGrad) {
149
+ da_->add (*matGrad_->subMatrix (/* startRow= */ 0 , /* numRows= */ 1 ));
150
+ db_->add (*matGrad_->subMatrix (/* startRow= */ length - 1 , 1 ));
158
151
159
- for (int k = 1 ; dw && k < length; ++k) {
160
- real sum = 0 ;
161
- for (int i = 0 ; i < numClasses_; ++i) {
162
- for (int j = 0 ; j < numClasses_; ++j) {
163
- sum += expW[i * numClasses_ + j] * alpha[(k - 1 ) * numClasses_ + i] *
164
- beta[k * numClasses_ + j];
152
+ beta_->dotMul (*beta_, *expX_);
153
+ beta_->rowNormalizeL1 (*beta_);
154
+
155
+ real* dw = dw_->getData ();
156
+ for (int k = 1 ; k < length; ++k) {
157
+ real sum = 0 ;
158
+ for (int i = 0 ; i < numClasses_; ++i) {
159
+ for (int j = 0 ; j < numClasses_; ++j) {
160
+ sum += expW[i * numClasses_ + j] * alpha[(k - 1 ) * numClasses_ + i] *
161
+ beta[k * numClasses_ + j];
162
+ }
165
163
}
166
- }
167
- sum = 1 / sum;
168
- for (int i = 0 ; i < numClasses_; ++i ) {
169
- for ( int j = 0 ; j < numClasses_; ++j) {
170
- dw[i * numClasses_ + j] += sum * expW[i * numClasses_ + j ] *
171
- alpha[(k - 1 ) * numClasses_ + i] *
172
- beta[k * numClasses_ + j];
164
+ sum = 1 / sum;
165
+ for ( int i = 0 ; i < numClasses_; ++i) {
166
+ for (int j = 0 ; j < numClasses_; ++j ) {
167
+ dw[i * numClasses_ + j] += sum * expW[i * numClasses_ + j] *
168
+ alpha[(k - 1 ) * numClasses_ + i ] *
169
+ beta[k * numClasses_ + j];
170
+ }
173
171
}
172
+ dw[s[k - 1 ] * numClasses_ + s[k]] -= (real)1 ;
174
173
}
175
- dw[s[k - 1 ] * numClasses_ + s[k]] -= (real)1 ;
176
174
}
177
175
}
178
176
0 commit comments