
Commit cca383c

follow comments.
1 parent 3afb9dc commit cca383c

2 files changed: 295 additions & 326 deletions

paddle/operators/linear_chain_crf_op.cc

Lines changed: 8 additions & 316 deletions
@@ -17,26 +17,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-namespace {
-template <typename T>
-T NormalizeL1(T* x, size_t len) {
-  T sum = 0.;
-  for (size_t i = 0; i < len; ++i) sum += x[i];
-  // (This comment is from the old LinearChainCRFLayer.)
-  // Right now, we just bet that sum won't be zero. If this really happens, we
-  // will figure out what should be done then.
-  PADDLE_ENFORCE(sum,
-                 "The unnormalized probabilities of all possible unfinished "
-                 "sequences must be greater than 0.");
-  T s = 1. / sum;
-  for (size_t i = 0; i < len; ++i) x[i] *= s;
-  return sum;
-}
-}  // namespace
-
-using framework::LoDTensor;
-using framework::LoD;
-
 class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   LinearChainCRFOpMaker(framework::OpProto* proto,
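For reference, the NormalizeL1 helper removed above rescales a vector in place so that its entries sum to one and returns the original sum; the CRF kernels take the log of that returned sum, so the log-likelihood stays exact while the per-step values stay in a safe numeric range. A minimal standalone sketch of the same idea (hypothetical function name, plain C++, no Paddle dependencies):

#include <cassert>
#include <cstddef>

// In-place L1 normalization: scales x so its entries sum to 1 and returns the
// original sum, which the caller can fold back in via std::log.
template <typename T>
T NormalizeL1Sketch(T* x, size_t len) {
  T sum = 0.;
  for (size_t i = 0; i < len; ++i) sum += x[i];
  assert(sum > 0 && "unnormalized probabilities must be positive");
  T scale = 1. / sum;
  for (size_t i = 0; i < len; ++i) x[i] *= scale;
  return sum;
}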
@@ -206,145 +186,6 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
   }
 };
 
-template <typename T>
-class LinearChainCRFOpKernel<platform::CPUPlace, T>
-    : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "This kernel only runs on CPU.");
-    auto* emission_weights = ctx.Input<LoDTensor>("Emission");
-    auto* transition_weights = ctx.Input<Tensor>("Transition");
-    auto* emission_exps = ctx.Output<LoDTensor>("EmissionExps");
-    emission_exps->mutable_data<T>(platform::CPUPlace());
-    auto* transition_exps = ctx.Output<Tensor>("TransitionExps");
-    transition_exps->mutable_data<T>(platform::CPUPlace());
-    auto* label = ctx.Input<LoDTensor>("Label");
-
-    auto in_lod = emission_weights->lod();
-    PADDLE_ENFORCE(in_lod.size(), "Input(Emission) is not a sequence.");
-
-    // TODO(caoying) The checks related to LoD information should be
-    // moved into InferShape once after the InferShape is refactored.
-    PADDLE_ENFORCE_EQ(emission_weights->NumLevels(), 1UL,
-                      "The Input(Emission) should be a sequence.");
-    PADDLE_ENFORCE_EQ(label->NumLevels(), 1UL,
-                      "The Input(Label) should be a sequence.");
-    const size_t level = 0;
-
-    auto emission_dims = emission_weights->dims();
-    const size_t batch_size = emission_dims[0];
-    const size_t tag_num = emission_dims[1];
-    const size_t seq_num = in_lod[level].size() - 1;
-
-    Tensor emission_row_max;
-    emission_row_max.mutable_data<T>(
-        framework::make_ddim({static_cast<int>(batch_size), 1}),
-        platform::CPUPlace());
-
-    auto place = ctx.GetEigenDevice<platform::CPUPlace>();
-    auto x = EigenMatrix<T>::From(*emission_weights);
-    auto x_row_max = EigenMatrix<T>::From(emission_row_max);
-    x_row_max.device(place) =
-        x.maximum(Eigen::DSizes<int, 1>(1))
-            .reshape(Eigen::DSizes<int, 2>(int(batch_size), 1));
-
-    auto x_exps = EigenMatrix<T>::From(*emission_exps);
-    x_exps.device(place) =
-        (x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
-
-    auto w = EigenMatrix<T>::From(*transition_weights);
-    auto w_exps = EigenMatrix<T>::From(*transition_exps);
-    w_exps.device(place) = w.exp();
-
-    auto* alpha = ctx.Output<LoDTensor>("Alpha");
-    alpha->mutable_data<T>(platform::CPUPlace());
-    auto* ll = ctx.Output<LoDTensor>("LogLikelihood");
-    // resize the output tensor to the correct dimension.
-    ll->Resize({static_cast<int>(seq_num), 1});
-    T* log_likelihood = ll->mutable_data<T>(platform::CPUPlace());
-    for (size_t i = 0; i < seq_num; ++i) {
-      int start_pos = static_cast<int>(in_lod[level][i]);
-      int end_pos = static_cast<int>(in_lod[level][i + 1]);
-      if (end_pos == start_pos) {
-        // If an empty input sequence is given, pad 0 for its cost.
-        log_likelihood[i] = 0.;
-        continue;
-      }
-
-      const Tensor one_seq = emission_weights->Slice(start_pos, end_pos);
-      Tensor one_seq_row_max = emission_row_max.Slice(start_pos, end_pos);
-      Tensor one_seq_exps = emission_exps->Slice(start_pos, end_pos);
-      const Tensor one_seq_label = label->Slice(start_pos, end_pos);
-      Tensor one_seq_alpha = alpha->Slice(start_pos, end_pos);
-
-      log_likelihood[i] = ForwardOneSequence(
-          &one_seq, &one_seq_row_max, &one_seq_exps, transition_weights,
-          transition_exps, &one_seq_label, &one_seq_alpha);
-    }
-  }
-
- protected:
-  T ForwardOneSequence(const Tensor* emission, const Tensor* emission_row_max,
-                       const Tensor* emission_exps, const Tensor* trans_weights,
-                       const Tensor* trans_weight_exps, const Tensor* label,
-                       Tensor* alpha) const {
-    const T* x = emission->data<T>();
-    const T* x_row_max = emission_row_max->data<T>();
-    const T* x_exps = emission_exps->data<T>();
-    const T* w = trans_weights->data<T>();
-    const T* w_exps = trans_weight_exps->data<T>();
-    T* alpha_value = alpha->data<T>();
-
-    auto x_dims = emission->dims();
-    const size_t seq_length = x_dims[0];
-    const size_t tag_num = x_dims[1];
-    // The 1st row of w are transition weights for start mask.
-    // The 2nd row of w are transition weights for end mask.
-    // Transition weights among other tags begin from the 3rd row of w.
-    const size_t state_trans_base_idx = 2;
-
-    for (size_t i = 0; i < tag_num; ++i) {
-      alpha_value[i] = w_exps[i] * x_exps[i];
-    }
-    T ll = -x_row_max[0] - std::log(NormalizeL1<T>(alpha_value, tag_num));
-
-    for (size_t k = 1; k < seq_length; ++k) {
-      for (size_t i = 0; i < tag_num; ++i) {
-        T sum = 0.;
-        for (size_t j = 0; j < tag_num; ++j) {
-          sum += alpha_value[(k - 1) * tag_num + j] *
-                 w_exps[(j + state_trans_base_idx) * tag_num + i];
-        }
-        alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
-      }
-      // NormalizeL1 is to avoid underflow or overflow at (*).
-      ll -= x_row_max[k] +
-            std::log(NormalizeL1<T>(alpha_value + k * tag_num, tag_num));
-    }
-    T sum = 0.;
-    for (size_t i = 0; i < tag_num; ++i) {
-      sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i];
-    }
-    ll -= std::log(sum);
-    // Now ll is equal to -log(Z).
-
-    const int* lbl = label->data<int>();
-    PADDLE_ENFORCE_LT(
-        *std::max_element(lbl, lbl + seq_length), tag_num,
-        "An invalid tag label that execesses the largest tag number.");
-
-    // Calculate the nominator part, which depends on the label sequence.
-    ll += w[lbl[0]] /*start transition*/ + x[lbl[0]] +
-          w[tag_num + lbl[seq_length - 1]] /*end transition*/;
-    for (size_t k = 1; k < seq_length; ++k) {
-      ll += x[k * tag_num + lbl[k]] +
-            w[(lbl[k - 1] + state_trans_base_idx) * tag_num + lbl[k]];
-    }
-    return -ll;
-  }
-};
-
 class LinearChainCRFGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
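The LinearChainCRFOpKernel removed in the hunk above computes log Z per sequence with the forward algorithm, shifting each emission row by its maximum and L1-normalizing the alpha row at every step to avoid overflow and underflow. A self-contained sketch of that recursion, assuming row-major std::vector inputs and hypothetical names (no Paddle types):

#include <cmath>
#include <cstddef>
#include <vector>

// x_exps: seq_length x tag_num, holding exp(x - row_max) per row.
// x_row_max: the per-row maxima that were subtracted.
// w_exps: (tag_num + 2) x tag_num; rows 0 and 1 are start/end transitions.
double LogPartitionSketch(const std::vector<double>& x_exps,
                          const std::vector<double>& x_row_max,
                          const std::vector<double>& w_exps, size_t seq_length,
                          size_t tag_num) {
  const size_t base = 2;  // state-to-state transitions start at row 2 of w
  std::vector<double> alpha(seq_length * tag_num, 0.);
  // L1-normalize one alpha row in place and return its original sum.
  auto normalize = [&](double* v) {
    double s = 0.;
    for (size_t i = 0; i < tag_num; ++i) s += v[i];
    for (size_t i = 0; i < tag_num; ++i) v[i] /= s;
    return s;
  };

  for (size_t i = 0; i < tag_num; ++i) alpha[i] = w_exps[i] * x_exps[i];
  double log_z = x_row_max[0] + std::log(normalize(&alpha[0]));

  for (size_t k = 1; k < seq_length; ++k) {
    for (size_t i = 0; i < tag_num; ++i) {
      double sum = 0.;
      for (size_t j = 0; j < tag_num; ++j)
        sum += alpha[(k - 1) * tag_num + j] * w_exps[(j + base) * tag_num + i];
      alpha[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
    }
    // The shift and the per-row normalization are folded back in here, so the
    // accumulated log_z stays exact.
    log_z += x_row_max[k] + std::log(normalize(&alpha[k * tag_num]));
  }

  double end_sum = 0.;
  for (size_t i = 0; i < tag_num; ++i)
    end_sum += alpha[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i];
  return log_z + std::log(end_sum);
}

The kernel itself accumulates the negative of this quantity (ll starts as -log Z) and then adds the path score of the labeled sequence, so the final return value is the negative log-likelihood.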
@@ -357,11 +198,6 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("LogLikelihood")),
                    "Input(LogLikelihood@GRAD) shoudl be not null.");
 
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Emission")),
-                   "Output(Emission@GRAD) should be not null.");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Transition")),
-                   "Output(Transition@GRAD) should be not null.");
-
     auto emission_exps_dims = ctx->GetInputDim("EmissionExps");
     PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2UL,
                       "The Input(EmissionExps) should be a 2-D tensor.");
@@ -390,168 +226,24 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
         "The height of Input(EmissionExps) and the height of Input(Label) "
         "should be the same.");
 
-    ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims);
-    ctx->SetOutputDim(framework::GradVarName("Transition"),
-                      transition_exps_dims);
+    if (ctx->HasOutput(framework::GradVarName("Emission"))) {
+      ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims);
+    }
+    if (ctx->HasOutput(framework::GradVarName("Transition"))) {
+      ctx->SetOutputDim(framework::GradVarName("Transition"),
+                        transition_exps_dims);
+    }
   }
 
  protected:
   // Explicitly set that the data type of output of the linear_chain_crf_grad
-  // operator is determined by its input "EmissionExps".
+  // operator is determined by its input: graidents of LogLikelihood.
   framework::DataType IndicateDataType(
       const framework::ExecutionContext& ctx) const override {
     return framework::ToDataType(ctx.Input<LoDTensor>("LogLikelihood")->type());
   }
 };
 
-template <typename T>
-class LinearChainCRFGradOpKernel<platform::CPUPlace, T>
-    : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(platform::CPUPlace()),
-                   "This kernel only runs on CPU.");
-    auto* label = ctx.Input<LoDTensor>("Label");
-    auto* emission_exps = ctx.Input<LoDTensor>("EmissionExps");
-    auto* transition_exps = ctx.Input<Tensor>("TransitionExps");
-    auto* alpha = ctx.Input<LoDTensor>("Alpha");
-    const T* ll_grad =
-        ctx.Input<Tensor>(framework::GradVarName("LogLikelihood"))->data<T>();
-
-    auto* emission_grad =
-        ctx.Output<Tensor>(framework::GradVarName("Emission"));
-    emission_grad->mutable_data<T>(platform::CPUPlace());
-
-    auto* trans_grad = ctx.Output<Tensor>(framework::GradVarName("Transition"));
-    if (trans_grad) trans_grad->mutable_data<T>(platform::CPUPlace());
-
-    auto emission_dims = emission_exps->dims();
-
-    // Beta is the memo table used in dynamic programming to calculate the
-    // backwark vectors. For a backward vector i (the i-th row of beta), it
-    // captures the unnormalized probabilities of partial sequences starting at
-    // position i.
-    Tensor beta;
-    beta.mutable_data<T>(emission_dims, platform::CPUPlace());
-
-    const size_t level = 0;  // currently, only support sequence.
-    auto lod = label->lod();
-    PADDLE_ENFORCE(lod.size(), "Input(Label) is not a sequence.");
-
-    for (size_t i = 0; i < lod[level].size() - 1; ++i) {
-      int start_pos = static_cast<int>(lod[level][i]);
-      int end_pos = static_cast<int>(lod[level][i + 1]);
-      if (end_pos == start_pos) continue;
-
-      const Tensor one_seq_emission_exps =
-          emission_exps->Slice(start_pos, end_pos);
-      const Tensor one_seq_label = label->Slice(start_pos, end_pos);
-      const Tensor one_seq_alpha = alpha->Slice(start_pos, end_pos);
-      Tensor one_seq_beta = beta.Slice(start_pos, end_pos);
-      Tensor one_seq_emission_grad = emission_grad->Slice(start_pos, end_pos);
-
-      BackwardOneSequence(ctx.device_context(), ll_grad[i],
-                          &one_seq_emission_exps, transition_exps,
-                          &one_seq_alpha, &one_seq_label, &one_seq_beta,
-                          trans_grad, &one_seq_emission_grad);
-    }
-  }
-
- protected:
-  void BackwardOneSequence(const platform::DeviceContext& ctx, const T ll_grad,
-                           const Tensor* emission_exps,
-                           const Tensor* transition_exps, const Tensor* alpha,
-                           const Tensor* label, Tensor* beta,
-                           Tensor* transition_grad,
-                           Tensor* emission_grad) const {
-    const T* w_exps = transition_exps->data<T>();
-    const T* x_exps = emission_exps->data<T>();
-    const int* label_value = label->data<int>();
-    T* beta_value = beta->data<T>();
-
-    auto x_dims = emission_exps->dims();
-    const size_t seq_length = x_dims[0];
-    const size_t tag_num = x_dims[1];
-    const size_t state_trans_base_idx = 2;
-
-    // Calculate the backward vectors: beta.
-    // First, calculate the initialition state.
-    for (size_t i = 0; i < tag_num; ++i) {
-      beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i];
-    }
-    NormalizeL1<T>(beta_value + (seq_length - 1) * tag_num, tag_num);
-
-    for (int k = static_cast<int>(seq_length) - 2; k >= 0; --k) {
-      for (size_t i = 0; i < tag_num; ++i) {
-        T sum = 0.;
-        for (size_t j = 0; j < tag_num; ++j) {
-          sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
-                 x_exps[(k + 1) * tag_num + j] *
-                 beta_value[(k + 1) * tag_num + j];
-        }
-        beta_value[k * tag_num + i] = sum;
-      }
-      // NormalizeL1 is to avoid underflow or overflow at (**).
-      NormalizeL1<T>(beta_value + k * tag_num, tag_num);
-    }
-
-    auto alpha_mat = EigenMatrix<T>::From(*alpha);
-    auto beta_mat = EigenMatrix<T>::From(*beta);
-    auto x_grad_mat = EigenMatrix<T>::From(*emission_grad);
-    auto* place = ctx.GetEigenDevice<platform::CPUPlace>();
-    auto prob = alpha_mat * beta_mat;
-    auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
-                       .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
-                       .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
-    x_grad_mat.device(*place) = prob / row_sum;
-
-    for (size_t k = 0; k < seq_length; ++k) {
-      x_grad_mat(k, label_value[k]) -= static_cast<T>(1.);
-    }
-
-    if (transition_grad) {
-      T* trans_grad = transition_grad->data<T>();
-      for (size_t k = 0; k < tag_num; ++k) {
-        trans_grad[k] += x_grad_mat(/*from start state*/ 0, k);
-        trans_grad[tag_num + k] +=
-            x_grad_mat(/*to end state*/ seq_length - 1, k);
-      }
-
-      auto x_exps_mat = EigenMatrix<T>::From(*emission_exps);
-
-      // TODO(caoying): Fix this to avoid using this local variable.
-      Tensor tmp;
-      tmp.mutable_data<T>(beta->dims(), platform::CPUPlace());
-      auto tmp_mat = EigenMatrix<T>::From(tmp);
-      auto prob = beta_mat * x_exps_mat;
-      auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
-                         .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
-                         .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
-      tmp_mat.device(*place) = prob / row_sum;
-
-      for (size_t k = 1; k < seq_length; ++k) {
-        T sum = 0.;
-        for (size_t i = 0; i < tag_num; ++i) {
-          for (size_t j = 0; j < tag_num; ++j) {
-            sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *  // (**)
-                   alpha_mat(k - 1, i) * tmp_mat(k, j);
-          }
-        }
-        sum = 1. / sum;
-        for (size_t i = 0; i < tag_num; ++i) {
-          for (size_t j = 0; j < tag_num; ++j) {
-            trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
-                sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
-                alpha_mat(k - 1, i) * tmp_mat(k, j);
-          }
-        }
-        trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num +
-                   label_value[k]] -= static_cast<T>(1.);
-      }
-    }
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
 
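The LinearChainCRFGradOpKernel removed in the hunk above builds the backward vectors beta with the mirror-image recursion: the last row is seeded with the end-transition weights and every row is L1-normalized for numeric stability before alpha and beta are combined into the gradients. A self-contained sketch of that recursion under the same assumptions as the forward sketch (hypothetical names, plain C++, no Paddle types):

#include <cstddef>
#include <vector>

// x_exps: seq_length x tag_num, exp-shifted emissions; w_exps as before.
// On return, (*beta)[k * tag_num + i] is the L1-normalized backward score of
// the partial sequence that starts at position k in tag i.
void BackwardVectorsSketch(const std::vector<double>& x_exps,
                           const std::vector<double>& w_exps,
                           size_t seq_length, size_t tag_num,
                           std::vector<double>* beta) {
  const size_t base = 2;  // rows 0 and 1 of w are start/end transitions
  auto normalize = [&](double* v) {
    double s = 0.;
    for (size_t i = 0; i < tag_num; ++i) s += v[i];
    for (size_t i = 0; i < tag_num; ++i) v[i] /= s;
  };

  beta->assign(seq_length * tag_num, 0.);
  // Initial state: seed the last row with the end-transition weights.
  for (size_t i = 0; i < tag_num; ++i)
    (*beta)[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i];
  normalize(&(*beta)[(seq_length - 1) * tag_num]);

  for (int k = static_cast<int>(seq_length) - 2; k >= 0; --k) {
    for (size_t i = 0; i < tag_num; ++i) {
      double sum = 0.;
      for (size_t j = 0; j < tag_num; ++j)
        sum += w_exps[(i + base) * tag_num + j] *
               x_exps[(k + 1) * tag_num + j] * (*beta)[(k + 1) * tag_num + j];
      (*beta)[k * tag_num + i] = sum;
    }
    // Per-row L1 normalization guards against underflow/overflow, exactly as
    // in the forward pass.
    normalize(&(*beta)[k * tag_num]);
  }
}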

0 commit comments
