Skip to content

Commit 8401039

Browse files
authored
Merge pull request #5084 from lcy-seso/crf
Add the LinearChainCrf operator.
2 parents d3b07a6 + ebd992e commit 8401039

File tree

12 files changed

+1020
-35
lines changed

12 files changed

+1020
-35
lines changed

paddle/framework/operator.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,32 +37,32 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
3737
std::string OperatorBase::Input(const std::string& name) const {
3838
auto& ins = Inputs(name);
3939
PADDLE_ENFORCE_LE(ins.size(), 1UL,
40-
"Op %s input %s should contain only one variable", type_,
41-
name);
40+
"Operator %s's input %s should contain only one variable.",
41+
type_, name);
4242
return ins.empty() ? kEmptyVarName : ins[0];
4343
}
4444

4545
const std::vector<std::string>& OperatorBase::Inputs(
4646
const std::string& name) const {
4747
auto it = inputs_.find(name);
48-
PADDLE_ENFORCE(it != inputs_.end(), "Op %s do not have input %s", type_,
49-
name);
48+
PADDLE_ENFORCE(it != inputs_.end(), "Operator %s does not have the input %s.",
49+
type_, name);
5050
return it->second;
5151
}
5252

5353
std::string OperatorBase::Output(const std::string& name) const {
5454
auto& outs = Outputs(name);
5555
PADDLE_ENFORCE_LE(outs.size(), 1UL,
56-
"Op %s output %s should contain only one variable", type_,
57-
name);
56+
"Operator %s's output %s should contain only one variable.",
57+
type_, name);
5858
return outs.empty() ? kEmptyVarName : outs[0];
5959
}
6060

6161
const std::vector<std::string>& OperatorBase::Outputs(
6262
const std::string& name) const {
6363
auto it = outputs_.find(name);
64-
PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output called %s",
65-
type_, name);
64+
PADDLE_ENFORCE(it != outputs_.end(),
65+
"Operator %s does not have an output called %s.", type_, name);
6666
return it->second;
6767
}
6868

paddle/framework/operator.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,8 @@ class OperatorWithKernel : public OperatorBase {
427427
int tmp = static_cast<int>(ToDataType(t->type()));
428428
VLOG(3) << "Input " << ipt_name << " with data_type " << tmp;
429429
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
430-
"DataType of Paddle Op %s must be same.", Type());
430+
"DataType of Paddle Op %s must be the same.",
431+
Type());
431432
data_type = tmp;
432433
}
433434
}

paddle/framework/tensor.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,12 @@ class Tensor {
118118
const platform::DeviceContext& ctx);
119119

120120
/**
121-
* @brief Return the slice of the tensor.
121+
* @brief Return a sub-tensor of the given tensor.
122122
*
123-
* @param[in] begin_idx The begin index of the slice.
124-
* @param[in] end_idx The end index of the slice.
123+
* @param[in] begin_idx The index of the start row (inclusive) to slice.
124+
* The index number begins from 0.
125+
* @param[in] end_idx The index of the end row (exclusive) to slice.
126+
* The index number begins from 0.
125127
*/
126128
inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
127129

paddle/framework/tensor_impl.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,10 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) {
112112
if (holder_ != nullptr) {
113113
holder_->set_type(type);
114114
}
115-
PADDLE_ENFORCE_GT(numel(), 0,
116-
"Tensor's numel must be larger than zero to call "
117-
"Tensor::mutable_data. Call Tensor::set_dim first.");
115+
PADDLE_ENFORCE_GT(
116+
numel(), 0,
117+
"When calling this method, the Tensor's numel must be larger than zero. "
118+
"Please check Tensor::Resize has been called first.");
118119
int64_t size = numel() * SizeOfType(type);
119120
/* some versions of boost::variant don't have operator!= */
120121
if (holder_ == nullptr || !(holder_->place() == place) ||
@@ -229,10 +230,12 @@ inline void Tensor::CopyFromVector(const std::vector<T>& src,
229230

230231
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
231232
check_memory_size();
232-
PADDLE_ENFORCE_GE(begin_idx, 0, "Slice begin index is less than zero.");
233-
PADDLE_ENFORCE_LE(end_idx, dims_[0], "Slice end index is out of bound.");
234-
PADDLE_ENFORCE_LT(begin_idx, end_idx,
235-
"Begin index must be less than end index.");
233+
PADDLE_ENFORCE_GE(begin_idx, 0,
234+
"The start row index must be greater than 0.");
235+
PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is out of bound.");
236+
PADDLE_ENFORCE_LT(
237+
begin_idx, end_idx,
238+
"The start row index must be lesser than the end row index.");
236239

237240
if (dims_[0] == 1) {
238241
return *this;

paddle/gserver/layers/CRFLayer.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,10 @@ void CRFLayer::backward(const UpdateCallback& callback) {
101101
: real(1.0f);
102102
instanceWeight *= coeff_;
103103

104-
MatrixPtr grad = output.grad->subRowMatrix(starts[i], starts[i + 1]);
105-
grad->add(*crfs_[i].getXGrad(), real(1.0f), instanceWeight);
104+
if (output.grad) {
105+
MatrixPtr grad = output.grad->subRowMatrix(starts[i], starts[i + 1]);
106+
grad->add(*crfs_[i].getXGrad(), real(1.0f), instanceWeight);
107+
}
106108
if (needWGrad) {
107109
weight_->getWGrad()->add(
108110
*crfs_[i].getWGrad(), real(1.0f), instanceWeight);

paddle/gserver/layers/LinearChainCRF.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ real LinearChainCRF::forward(real* x, int* s, int length) {
102102
}
103103

104104
void LinearChainCRF::backward(real* x, int* s, int length, bool needWGrad) {
105-
MatrixPtr matX = Matrix::create(x, length, numClasses_);
106105
Matrix::resizeOrCreate(matGrad_, length, numClasses_);
107106
Matrix::resizeOrCreate(beta_, length, numClasses_);
108107
real* b = b_->getData();

paddle/operators/cross_entropy_op.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,9 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
2828

2929
auto x_dims = ctx->GetInputDim("X");
3030
auto label_dims = ctx->GetInputDim("Label");
31-
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
32-
PADDLE_ENFORCE_EQ(label_dims.size(), 2, "Input(Label)'s rank should be 2.");
31+
PADDLE_ENFORCE_EQ(x_dims.size(), 2UL, "Input(X)'s rank should be 2.");
32+
PADDLE_ENFORCE_EQ(label_dims.size(), 2UL,
33+
"Input(Label)'s rank should be 2.");
3334
PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0],
3435
"The 1st dimension of Input(X) and Input(Label) should "
3536
"be equal.");
@@ -38,8 +39,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
3839
"If Attr(soft_label) == true, the 2nd dimension of "
3940
"Input(X) and Input(Label) should be equal.");
4041
} else {
41-
PADDLE_ENFORCE_EQ(label_dims[1], 1,
42-
"If Attr(soft_label) == false, the 2nd dimension of "
42+
PADDLE_ENFORCE_EQ(label_dims[1], 1UL,
43+
"If Attr(softLabel) == false, the 2nd dimension of "
4344
"Input(Label) should be 1.");
4445
}
4546

@@ -48,7 +49,8 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
4849
}
4950

5051
protected:
51-
// CrossEntropy's data type just determined by "X"
52+
// Explicitly declare that the output data type of the cross_entropy operator
53+
// is determined by its input "X".
5254
framework::DataType IndicateDataType(
5355
const framework::ExecutionContext& ctx) const override {
5456
return framework::ToDataType(ctx.Input<Tensor>("X")->type());

0 commit comments

Comments
 (0)