Skip to content

Commit 735737d

Browse files
committed
initialize crf opreator.
1 parent cf92802 commit 735737d

File tree

3 files changed

+102
-0
lines changed

3 files changed

+102
-0
lines changed

paddle/operators/crf_op.cc

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/crf_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
class CrfOpMaker : public framework::OpProtoAndCheckerMaker {
21+
public:
22+
CrfOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
23+
: OpProtoAndCheckerMaker(proto, op_checker) {}
24+
};
25+
26+
class CrfOp : public framework::OperatorWithKernel {
27+
public:
28+
using framework::OperatorWithKernel::OperatorWithKernel;
29+
30+
protected:
31+
void InferShape(framework::InferShapeContextBase* ctx) const override {}
32+
};
33+
34+
class CrfGradOp : public framework::OperatorWithKernel {
35+
public:
36+
using framework::OperatorWithKernel::OperatorWithKernel;
37+
38+
protected:
39+
void InferShape(framework::InferShapeContextBase* ctx) const override {}
40+
};
41+
42+
} // namespace operators
43+
} // namespace paddle
44+
45+
namespace ops = paddle::operators;
46+
REGISTER_OP(crf, ops::CrfOp, ops::CrfOpMaker, crf_grad, ops::CrfGradOp);
47+
REGISTER_OP_CPU_KERNEL(crf, ops::CrfOpKernel<float>);
48+
REGISTER_OP_CPU_KERNEL(crf_grad, ops::CrfGradOpKernel<float>);

paddle/operators/crf_op.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/framework/eigen.h"
17+
#include "paddle/framework/op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
template <typename T>
23+
class CrfOpKernel : public framework::OpKernel<T> {
24+
public:
25+
void Compute(const framework::ExecutionContext& ctx) const override {
26+
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
27+
"This kernel only runs on CPU.");
28+
}
29+
};
30+
31+
template <typename T>
32+
class CrfGradOpKernel : public framework::OpKernel<T> {
33+
public:
34+
void Compute(const framework::ExecutionContext& ctx) const override {
35+
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
36+
"This kernel only runs on CPU.");
37+
}
38+
};
39+
40+
} // namespace operators
41+
} // namespace paddle
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import unittest
2+
import numpy as np
3+
4+
5+
class TestCrfOp(OpTest):
6+
def setUp(self):
7+
self.op_type = "crf"
8+
batch_size = 3
9+
class_num = 37
10+
11+
12+
if __name__ == "__main__":
13+
unittest.main()

0 commit comments

Comments
 (0)