Skip to content

Commit bce4f7d

Browse files
committed
follow comments.
1 parent 4c63086 commit bce4f7d

File tree

3 files changed

+34
-32
lines changed

3 files changed

+34
-32
lines changed

paddle/framework/tensor_impl.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,9 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
228228
PADDLE_ENFORCE_GE(begin_idx, 0,
229229
"The start row index must be greater than 0.");
230230
PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is out of bound.");
231-
PADDLE_ENFORCE_LT(begin_idx, end_idx,
232-
"The start row index must be less than the end row index.");
231+
PADDLE_ENFORCE_LT(
232+
begin_idx, end_idx,
233+
"The start row index must be smaller than the end row index.");
233234

234235
if (dims_[0] == 1) {
235236
return *this;

paddle/operators/linear_chain_crf_op.cc

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,19 +26,20 @@ T NormalizeL1(T* x, size_t len) {
2626
// Right now, we just bet that sum won't be zero. If this really happens, we
2727
// will figure out what should be done then.
2828
PADDLE_ENFORCE(sum,
29-
"The unnormalized probabilites of all possible unfinished "
29+
"The unnormalized probabilities of all possible unfinished "
3030
"sequences must be greater than 0.");
31-
for (size_t i = 0; i < len; ++i) x[i] /= sum;
31+
T s = 1. / sum;
32+
for (size_t i = 0; i < len; ++i) x[i] *= s;
3233
return sum;
3334
}
3435
} // namespace
3536

3637
using framework::LoDTensor;
3738
using framework::LoD;
3839

39-
class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
40+
class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
4041
public:
41-
LinearChainCrfOpMaker(framework::OpProto* proto,
42+
LinearChainCRFOpMaker(framework::OpProto* proto,
4243
framework::OpAttrChecker* op_checker)
4344
: OpProtoAndCheckerMaker(proto, op_checker) {
4445
AddInput(
@@ -51,11 +52,11 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
5152
AddInput(
5253
"Transition",
5354
"(Tensor, default: Tensor<float>). A Tensor with shape [(D + 2) x D]. "
54-
"The learnable parameter for linear_chain_crf operator. "
55+
"The learnable parameter for the linear_chain_crf operator. "
5556
"See more details in the operator's comments.");
5657
AddInput(
5758
"Label",
58-
"(LoDTensor, default: LoDTensor<int>). The ground truth which is a 2-D "
59+
"(LoDTensor, default: LoDTensor<int>). The groundtruth which is a 2-D "
5960
"LoDTensor with shape [N x 1], where N is the total element number in "
6061
"a mini-batch.");
6162
AddOutput(
@@ -82,14 +83,11 @@ class LinearChainCrfOpMaker : public framework::OpProtoAndCheckerMaker {
8283
.AsIntermediate();
8384
AddOutput(
8485
"LogLikelihood",
85-
"(Tensor, default: Tensor<float>). The logarithm of the "
86-
"conditional "
86+
"(Tensor, default: Tensor<float>). The logarithm of the conditional "
8787
"likelihood of each training sample in a mini-batch. This is a 2-D "
8888
"tensor with shape [S x 1], where S is the sequence number in a "
89-
"mini-batch. "
90-
"Note: S is equal to the sequence number in a mini-batch. The "
91-
"output "
92-
"is no longer a LoDTensor.");
89+
"mini-batch. Note: S is equal to the sequence number in a mini-batch. "
90+
"The output is no longer a LoDTensor.");
9391
AddComment(R"DOC(
9492
Conditional Random Field defines an undirected probabilistic graph with nodes
9593
denoting random variables and edges denoting dependencies between these
@@ -100,11 +98,11 @@ variables. CRF learns the conditional probability \f$P(Y|X)\f$, where
10098
Linear chain CRF is a special case of CRF that is useful for sequence labeling
10199
task. Sequence labeling tasks do not assume a lot of conditional
102100
independences among inputs. They only concern about the input and the output
103-
being linear sequences. Thus, the graph model of CRF is a simple chain or
104-
a line, which results in a linear chain CRF.
101+
being linear sequences. Thus, the graph model of such a CRF is a simple chain
102+
or a line, which results in the linear chain CRF.
105103
106-
This operator implements the Forward-Backward algorithm for linear chain CRF.
107-
Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
104+
This operator implements the Forward-Backward algorithm for the linear chain
105+
CRF. Please see http://www.cs.columbia.edu/~mcollins/fb.pdf for reference.
108106
109107
Equation:
110108
@@ -144,7 +142,7 @@ nonlinear activation.
144142
}
145143
};
146144

147-
class LinearChainCrfOp : public framework::OperatorWithKernel {
145+
class LinearChainCRFOp : public framework::OperatorWithKernel {
148146
public:
149147
using framework::OperatorWithKernel::OperatorWithKernel;
150148

@@ -211,7 +209,7 @@ class LinearChainCrfOp : public framework::OperatorWithKernel {
211209
};
212210

213211
template <typename T>
214-
class LinearChainCrfOpKernel<platform::CPUPlace, T>
212+
class LinearChainCRFOpKernel<platform::CPUPlace, T>
215213
: public framework::OpKernel<T> {
216214
public:
217215
void Compute(const framework::ExecutionContext& ctx) const override {
@@ -262,11 +260,11 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
262260
w_exps.device(place) = w.exp();
263261

264262
auto* alpha = ctx.Output<LoDTensor>("Alpha");
265-
alpha->mutable_data<T>(ctx.GetPlace());
263+
alpha->mutable_data<T>(platform::CPUPlace());
266264
auto* ll = ctx.Output<LoDTensor>("LogLikelihood");
267265
// resize the output tensor to the correct dimension.
268266
ll->Resize({static_cast<int>(seq_num), 1});
269-
T* log_likelihood = ll->mutable_data<T>(ctx.GetPlace());
267+
T* log_likelihood = ll->mutable_data<T>(platform::CPUPlace());
270268
for (size_t i = 0; i < seq_num; ++i) {
271269
int start_pos = static_cast<int>(in_lod[level][i]);
272270
int end_pos = static_cast<int>(in_lod[level][i + 1]);
@@ -322,6 +320,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
322320
}
323321
alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
324322
}
323+
// NormalizeL1 rescales row k of alpha in place; this avoids underflow or overflow at (*).
325324
ll -= x_row_max[k] +
326325
std::log(NormalizeL1<T>(alpha_value + k * tag_num, tag_num));
327326
}
@@ -330,6 +329,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
330329
sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i];
331330
}
332331
ll -= std::log(sum);
332+
// Now ll is equal to -log(Z).
333333

334334
const int* lbl = label->data<int>();
335335
PADDLE_ENFORCE_LT(
@@ -347,7 +347,7 @@ class LinearChainCrfOpKernel<platform::CPUPlace, T>
347347
}
348348
};
349349

350-
class LinearChainCrfGradOp : public framework::OperatorWithKernel {
350+
class LinearChainCRFGradOp : public framework::OperatorWithKernel {
351351
public:
352352
using framework::OperatorWithKernel::OperatorWithKernel;
353353

@@ -407,11 +407,11 @@ class LinearChainCrfGradOp : public framework::OperatorWithKernel {
407407
};
408408

409409
template <typename T>
410-
class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
410+
class LinearChainCRFGradOpKernel<platform::CPUPlace, T>
411411
: public framework::OpKernel<T> {
412412
public:
413413
void Compute(const framework::ExecutionContext& ctx) const override {
414-
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
414+
PADDLE_ENFORCE(platform::is_cpu_place(platform::CPUPlace()),
415415
"This kernel only runs on CPU.");
416416
auto* label = ctx.Input<LoDTensor>("Label");
417417
auto* emission_exps = ctx.Input<LoDTensor>("EmissionExps");
@@ -493,6 +493,7 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
493493
}
494494
beta_value[k * tag_num + i] = sum;
495495
}
496+
// NormalizeL1 rescales row k of beta in place; this avoids underflow or overflow at (**).
496497
NormalizeL1<T>(beta_value + k * tag_num, tag_num);
497498
}
498499

@@ -534,7 +535,7 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
534535
T sum = 0.;
535536
for (size_t i = 0; i < tag_num; ++i) {
536537
for (size_t j = 0; j < tag_num; ++j) {
537-
sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
538+
sum += w_exps[(i + state_trans_base_idx) * tag_num + j] * // (**)
538539
alpha_mat(k - 1, i) * tmp_mat(k, j);
539540
}
540541
}
@@ -557,11 +558,11 @@ class LinearChainCrfGradOpKernel<platform::CPUPlace, T>
557558
} // namespace paddle
558559

559560
namespace ops = paddle::operators;
560-
REGISTER_OP(linear_chain_crf, ops::LinearChainCrfOp, ops::LinearChainCrfOpMaker,
561-
linear_chain_crf_grad, ops::LinearChainCrfGradOp);
561+
REGISTER_OP(linear_chain_crf, ops::LinearChainCRFOp, ops::LinearChainCRFOpMaker,
562+
linear_chain_crf_grad, ops::LinearChainCRFGradOp);
562563
REGISTER_OP_CPU_KERNEL(
563564
linear_chain_crf,
564-
ops::LinearChainCrfOpKernel<paddle::platform::CPUPlace, float>);
565+
ops::LinearChainCRFOpKernel<paddle::platform::CPUPlace, float>);
565566
REGISTER_OP_CPU_KERNEL(
566567
linear_chain_crf_grad,
567-
ops::LinearChainCrfGradOpKernel<paddle::platform::CPUPlace, float>);
568+
ops::LinearChainCRFGradOpKernel<paddle::platform::CPUPlace, float>);

paddle/operators/linear_chain_crf_op.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
2525
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
2626

2727
template <typename Place, typename T>
28-
class LinearChainCrfOpKernel : public framework::OpKernel<T> {
28+
class LinearChainCRFOpKernel : public framework::OpKernel<T> {
2929
public:
3030
void Compute(const framework::ExecutionContext& ctx) const override;
3131

@@ -37,7 +37,7 @@ class LinearChainCrfOpKernel : public framework::OpKernel<T> {
3737
};
3838

3939
template <typename Place, typename T>
40-
class LinearChainCrfGradOpKernel : public framework::OpKernel<T> {
40+
class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
4141
public:
4242
void Compute(const framework::ExecutionContext& ctx) const override;
4343

0 commit comments

Comments
 (0)