@@ -17,6 +17,9 @@ limitations under the License. */
namespace paddle {
namespace operators {

+ using framework::LoDTensor;
+ using framework::LoD;
+
class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LinearChainCrfOpMaker(framework::OpProto* proto,
@@ -77,14 +80,14 @@ Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
Equation:

- - Denote the first input of this operator (Emission) as \f$x\f$ here.
- - The first D values of the second input (Transition) of this operator are for
-   starting weights, denoted as \f$a\f$ here.
- - The next D values of the second input (Transition) of this operator are for
-   ending weights, denoted as \f$b\f$ here.
- - The remaning values of the second input (Transition) are for transition
-   weights, denoted as \f$w\f$ here.
- - Denote the third input of this operator (Label) as \f$s\f$ here.
+ - Denote Input(Emission) to this operator as \f$x\f$ here.
+ - The first D values of Input(Transition) to this operator are for starting
+   weights, denoted as \f$a\f$ here.
+ - The next D values of Input(Transition) to this operator are for ending
+   weights, denoted as \f$b\f$ here.
+ - The remaining values of Input(Transition) are for transition weights,
+   denoted as \f$w\f$ here.
+ - Denote Input(Label) as \f$s\f$ here.

The probability of a sequence \f$s\f$ of length \f$L\f$ is defined as:

\f$P(s) = (1/Z) exp(a_{s_1} + b_{s_L}
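[Note: the hunk ends mid-formula. Completed in the notation just defined (the
standard linear-chain CRF form), the probability reads
\f$P(s) = (1/Z) exp(a_{s_1} + b_{s_L} + \sum_{l=1}^L x_{l, s_l}
        + \sum_{l=2}^L w_{s_{l-1}, s_l})\f$,
where \f$x_{l, s_l}\f$ is the emission weight of tag \f$s_l\f$ at position
\f$l\f$, and \f$Z\f$ sums the same exponential over all possible tag sequences.]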
@@ -107,8 +110,7 @@ sequences internally, it expects UNSCALED emission feature weights.
Please do not call this op with the emission feature being output of any
nonlinear activation.

- 3. The 2nd dimension of the first input of this operator (Emission) MUST be
-    equal to the tag number.
+ 3. The 2nd dimension of Input(Emission) MUST be equal to the tag number.

)DOC");
  }
@@ -136,33 +138,188 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
    auto label_dims = ctx->GetInputDim("Label");

    PADDLE_ENFORCE_EQ(emission_dims.size(), 2UL,
-                     "The input Emission should be a 2-D tensor.");
+                     "The Input(Emission) should be a 2-D tensor.");
    PADDLE_ENFORCE_EQ(transition_dims.size(), 2UL,
-                     "The input Transition should be a 2-D tensor.");
+                     "The Input(Transition) should be a 2-D tensor.");
    PADDLE_ENFORCE_EQ(
-       transition_dims[0] + 2, transition_dims[1],
-       "An invalid dimension for the input Transition, which should "
+       transition_dims[0] - 2, transition_dims[1],
+       "An invalid dimension for the Input(Transition), which should "
        "be a 2-D tensor with shape [D + 2 x D].");
    PADDLE_ENFORCE_EQ(
        emission_dims[1], transition_dims[1],
-       "The 2nd dimension of the input Emission and the input Transition "
+       "The 2nd dimension of the Input(Emission) and the Input(Transition) "
        "should be equal to the tag number.");
    PADDLE_ENFORCE(label_dims.size() == 2UL && label_dims[1] == 1UL,
-                  "The input Label should be a 2-D tensor "
-                  "with the 2nd dimensions fixed to 1.");
+                  "The Input(Label) should be a 2-D tensor with the 2nd "
+                  "dimension fixed to 1.");
+   PADDLE_ENFORCE_EQ(
+       emission_dims[0], label_dims[0],
+       "The height of Input(Emission) and the height of Input(Label) "
+       "should be the same.");

    ctx->SetOutputDim("Alpha", emission_dims);
+
+   // TODO(caoying) This is tricky. The 1st dimension of Output(LogLikelihood)
+   // is the sequence number in a mini-batch. The dimension set here should be
+   // resized to its correct size in the function Compute.
    ctx->SetOutputDim("LogLikelihood", {emission_dims[0], 1});
  }

- // Explicitly set data type of output of the linear_chain_crf operator
- // is determined by its input "Emission".
+ // Explicitly set that the data type of the output of the linear_chain_crf
+ // operator is determined by its input "Emission".
  framework::DataType IndicateDataType(
      const framework::ExecutionContext& ctx) const override {
    return framework::ToDataType(ctx.Input<Tensor>("Emission")->type());
  }
};
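[Note: per the checks above, with D tags Input(Transition) has shape
[(D + 2) x D]: row 0 holds the start weights \f$a\f$, row 1 the end weights
\f$b\f$, and rows 2 through D + 1 the D x D transition matrix \f$w\f$. The CPU
kernel below indexes this same layout via start_ridx, end_ridx, and
state_base_ridx.]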

+ template <typename T>
+ class LinearChainCrfOpKernel<platform::CPUPlace, T>
+     : public framework::OpKernel<T> {
+  public:
+   void Compute(const framework::ExecutionContext& ctx) const override {
+     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                    "This kernel only runs on CPU.");
+
+     auto* emission_weights = ctx.Input<LoDTensor>("Emission");
+     auto* transition_weights = ctx.Input<Tensor>("Transition");
+     auto* label = ctx.Input<LoDTensor>("Label");
+
+     auto in_lod = emission_weights->lod();
+     // TODO(caoying) The checks related to LoD information should be moved
+     // into InferShape once InferShape is refactored.
+     PADDLE_ENFORCE_EQ(emission_weights->NumLevels(), 1UL,
+                       "The Input(Emission) should be a sequence.");
+     PADDLE_ENFORCE_EQ(label->NumLevels(), 1UL,
+                       "The Input(Label) should be a sequence.");
+     const size_t level = 0;
+
+     auto emission_dims = emission_weights->dims();
+     const size_t seq_num = in_lod[level].size() - 1;
+
+     // TODO(caoying) These local variables seem to be created and destroyed
+     // every time this function is called. Will this bring additional
+     // overhead?
+     Tensor emission_exps;
+     Tensor emission_row_max;
+     Tensor transition_exps;
+     emission_exps.mutable_data<T>(emission_dims, platform::CPUPlace());
+     emission_row_max.mutable_data<T>(
+         framework::make_ddim({emission_dims[0], 1}), platform::CPUPlace());
+     transition_exps.mutable_data<T>(transition_weights->dims(),
+                                     platform::CPUPlace());
+
+     auto* alpha = ctx.Output<Tensor>("Alpha");
+     alpha->mutable_data<T>(ctx.GetPlace());
+     auto* ll = ctx.Output<Tensor>("LogLikelihood");
+     // Resize the output tensor to its correct dimension.
+     ll->Resize({static_cast<int>(seq_num), 1});
+     T* log_likelihood = ll->mutable_data<T>(ctx.GetPlace());
+
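+     // Each sequence occupies rows [start_pos, end_pos) of the LoD tensors;
+     // slice one sequence out at a time and run the forward algorithm on it.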
+     for (size_t i = 0; i < seq_num; ++i) {
+       int start_pos = static_cast<int>(in_lod[level][i]);
+       int end_pos = static_cast<int>(in_lod[level][i + 1]);
+
+       const Tensor one_seq = emission_weights->Slice<T>(start_pos, end_pos);
+       Tensor one_seq_row_max = emission_row_max.Slice<T>(start_pos, end_pos);
+       Tensor one_seq_exps = emission_exps.Slice<T>(start_pos, end_pos);
+       const Tensor one_seq_label = label->Slice<T>(start_pos, end_pos);
+       Tensor one_seq_alpha = alpha->Slice<T>(start_pos, end_pos);
+
+       log_likelihood[i] = ForwardOneSequence(
+           ctx.device_context(), one_seq, one_seq_row_max, one_seq_exps,
+           (*transition_weights), transition_exps, one_seq_label,
+           one_seq_alpha);
+     }
+   }
+
+  protected:
+   T ForwardOneSequence(const platform::DeviceContext& ctx,
+                        const Tensor& emission, Tensor& emission_row_max,
+                        Tensor& emission_exps, const Tensor& trans_weights,
+                        Tensor& trans_weight_exps, const Tensor& label,
+                        Tensor& alpha) const {
+     // TODO(caoying) Evaluate and optimize this. The Eigen computation
+     // kernels are invoked multiple times here. Some computations that do not
+     // depend on sequence information could be performed only once for the
+     // entire batch.
+
+     auto x_dims = emission.dims();
+     const size_t seq_length = x_dims[0];
+     const size_t tag_num = x_dims[1];
+
+     T* alpha_value = alpha.data<T>();
+
+     auto x = EigenMatrix<T>::From(emission);
+     auto x_row_max = EigenMatrix<T>::From(emission_row_max);
+     const int class_dim = 1;
+     x_row_max.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
+         x.maximum(Eigen::DSizes<int, 1>(class_dim))
+             .reshape(Eigen::DSizes<int, 2>(int(seq_length), 1));
+
+     auto x_exps = EigenMatrix<T>::From(emission_exps);
+     x_exps.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
+         (x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
+
+     auto w = EigenMatrix<T>::From(trans_weights);
+     auto w_exps = EigenMatrix<T>::From(trans_weight_exps);
+     w_exps.device(*ctx.GetEigenDevice<platform::CPUPlace>()) = w.exp();
+     // The 1st row of w holds the transition weights for the start mask.
+     const size_t start_ridx = 0;
+     // The 2nd row of w holds the transition weights for the end mask.
+     const size_t end_ridx = 1;
+     // Transition weights among the other tags begin at the 3rd row of w.
+     const size_t state_base_ridx = 2;
+
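+     // Initialization: alpha_0(i) = exp(a_i) * exp(x_{0,i} - max_j x_{0,j});
+     // both factors are read from the precomputed exp tables above.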
+     for (size_t i = 0; i < tag_num; ++i) {
+       alpha_value[i] = w_exps(start_ridx, i) * x_exps(0, i);
+     }
+     T ll = -x_row_max(0, 0) - std::log(NormalizeL1(alpha_value, tag_num));
+
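+     // Forward recursion on the max-shifted values, with per-step L1
+     // normalization to avoid underflow:
+     //   alpha_k(i) = x_exps(k, i) * sum_j alpha_{k-1}(j) * w_exps(j + 2, i).
+     // Every shifted-out log factor (row max, normalization constant) is
+     // accumulated into ll, so after the end-weight sum below ll equals
+     // -log Z.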
+     for (size_t k = 1; k < seq_length; ++k) {
+       for (size_t i = 0; i < tag_num; ++i) {
+         T sum = 0.;
+         for (size_t j = 0; j < tag_num; ++j) {
+           sum += alpha_value[(k - 1) * tag_num + j] *
+                  w_exps(j + state_base_ridx, i);
+         }
+         alpha_value[k * tag_num + i] = x_exps(k, i) * sum;
+       }
+       ll -= x_row_max(k, 0) +
+             std::log(NormalizeL1(alpha_value + k * tag_num, tag_num));
+     }
+     T sum = 0.;
+     for (size_t i = 0; i < tag_num; ++i) {
+       sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps(end_ridx, i);
+     }
+     ll -= std::log(sum);
+
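+     // Second stage: add the log score of the gold label sequence, so that
+     // ll = log P(s) = log score(s) - log Z. The function returns -ll, i.e.
+     // the negative log-likelihood of the gold sequence.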
+     const int* lbl = label.data<int>();
+     PADDLE_ENFORCE_LT(
+         *std::max_element(lbl, lbl + seq_length), tag_num,
+         "An invalid tag label that exceeds the largest tag number.");
+
+     // Calculate the numerator part, which depends on the label sequence.
+     ll += w(start_ridx, lbl[0]) + x(0, lbl[0]) +
+           w(end_ridx, lbl[seq_length - 1]);
+     for (size_t k = 1; k < seq_length; ++k) {
+       ll += x(k, lbl[k]) + w(lbl[k - 1] + state_base_ridx, lbl[k]);
+     }
+     return -ll;
+   }
+
+  private:
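+   // Normalizes the len values in x in place by their sum (the L1 norm) and
+   // returns that sum; keeps the forward variables alpha in a stable range.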
+   T NormalizeL1(T* x, size_t len) const {
+     T sum = 0.;
+     for (size_t i = 0; i < len; ++i) sum += x[i];
+     // (This comment is from the old LinearChainCRFLayer.)
+     // Right now, we just bet that sum won't be zero. If this really happens,
+     // we will figure out what should be done then.
+     PADDLE_ENFORCE(sum,
+                    "The unnormalized probabilities of all possible unfinished "
+                    "sequences must be greater than 0.");
+     for (size_t i = 0; i < len; ++i) x[i] /= sum;
+     return sum;
+   }
+ };
+
class LinearChainCrfGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
@@ -171,12 +328,25 @@ class LinearChainCrfGradOp : public framework::OperatorWithKernel {
  void InferShape(framework::InferShapeContext* ctx) const override {}
};

+ template <typename T>
+ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
+     : public framework::OpKernel<T> {
+  public:
+   void Compute(const framework::ExecutionContext& ctx) const override {
+     PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                    "This kernel only runs on CPU.");
+   }
+ };
+
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker,
            linear_chain_crf_grad, ops::LinearChainCrfGradOp);
- REGISTER_OP_CPU_KERNEL(linear_chain_crf, ops::LinearChainCrfOpKernel<float>);
- REGISTER_OP_CPU_KERNEL(linear_chain_crf_grad,
-                        ops::LinearChainCrfGradOpKernel<float>);
+ REGISTER_OP_CPU_KERNEL(
+     linear_chain_crf,
+     ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float>);
+ REGISTER_OP_CPU_KERNEL(
+     linear_chain_crf_grad,
+     ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float>);