Skip to content

Commit 10dd632

Browse files
committed
Rename 'ctc_greedy_decode' to 'ctc_decode'
1 parent 281e93b commit 10dd632

File tree

4 files changed

+15
-17
lines changed

4 files changed

+15
-17
lines changed

paddle/operators/ctc_decode_op.cc

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "paddle/operators/ctc_greedy_decode_op.h"
15+
#include "paddle/operators/ctc_decode_op.h"
1616

1717
namespace paddle {
1818
namespace operators {
1919

20-
class CTCGreedyDecodeOp : public framework::OperatorWithKernel {
20+
class CTCDecodeOp : public framework::OperatorWithKernel {
2121
public:
2222
using framework::OperatorWithKernel::OperatorWithKernel;
2323

2424
void InferShape(framework::InferShapeContext* ctx) const override {
2525
PADDLE_ENFORCE(ctx->HasInput("Input"),
26-
"Input of CTCGreedyDecodeOp should not be null.");
26+
"Input of CTCDecodeOp should not be null.");
2727
PADDLE_ENFORCE(ctx->HasOutput("Output"),
28-
"Output of CTCGreedyDecodeOp should not be null.");
28+
"Output of CTCDecodeOp should not be null.");
2929

3030
auto input_dims = ctx->GetInputDim("Input");
3131

@@ -42,9 +42,9 @@ class CTCGreedyDecodeOp : public framework::OperatorWithKernel {
4242
}
4343
};
4444

45-
class CTCGreedyDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
45+
class CTCDecodeOpMaker : public framework::OpProtoAndCheckerMaker {
4646
public:
47-
CTCGreedyDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
47+
CTCDecodeOpMaker(OpProto* proto, OpAttrChecker* op_checker)
4848
: OpProtoAndCheckerMaker(proto, op_checker) {
4949
AddInput("Input",
5050
"(LodTensor, default: LoDTensor<int>), Its shape is "
@@ -86,9 +86,7 @@ and then delete all blanks in sequence.
8686
} // namespace paddle
8787

8888
namespace ops = paddle::operators;
89-
REGISTER_OPERATOR(ctc_greedy_decode, ops::CTCGreedyDecodeOp,
90-
ops::CTCGreedyDecodeOpMaker,
89+
REGISTER_OPERATOR(ctc_decode, ops::CTCDecodeOp, ops::CTCDecodeOpMaker,
9190
paddle::framework::EmptyGradOpMaker);
9291
REGISTER_OP_CPU_KERNEL(
93-
ctc_greedy_decode,
94-
ops::CTCGreedyDecodeKernel<paddle::platform::CPUDeviceContext, int>);
92+
ctc_decode, ops::CTCDecodeKernel<paddle::platform::CPUDeviceContext, int>);

paddle/operators/ctc_decode_op.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ limitations under the License. */
1515
#include <stdio.h>
1616
#include <thrust/device_vector.h>
1717
#include <thrust/host_vector.h>
18-
#include "paddle/operators/ctc_greedy_decode_op.h"
18+
#include "paddle/operators/ctc_decode_op.h"
1919

2020
namespace paddle {
2121
namespace operators {
@@ -42,7 +42,7 @@ __global__ void MergeAndDelCudaKernel(const int64_t num_token, const T* tokens,
4242
}
4343

4444
template <typename T>
45-
class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> {
45+
class CTCDecodeOpCUDAKernel : public framework::OpKernel<T> {
4646
public:
4747
void Compute(const framework::ExecutionContext& ctx) const override {
4848
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
@@ -87,5 +87,5 @@ class CTCGreedyDecodeOpCUDAKernel : public framework::OpKernel<T> {
8787
} // namespace operators
8888
} // namespace paddle
8989

90-
REGISTER_OP_CUDA_KERNEL(ctc_greedy_decode,
91-
paddle::operators::CTCGreedyDecodeOpCUDAKernel<int>);
90+
REGISTER_OP_CUDA_KERNEL(ctc_decode,
91+
paddle::operators::CTCDecodeOpCUDAKernel<int>);

paddle/operators/ctc_decode_op.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using Tensor = framework::Tensor;
2323
using LoDTensor = framework::LoDTensor;
2424

2525
template <typename DeviceContext, typename T>
26-
class CTCGreedyDecodeKernel : public framework::OpKernel<T> {
26+
class CTCDecodeKernel : public framework::OpKernel<T> {
2727
public:
2828
void Compute(const framework::ExecutionContext& ctx) const override {
2929
auto* input = ctx.Input<LoDTensor>("Input");

python/paddle/v2/fluid/tests/test_ctc_decode.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def CTCDecode(input, lod, blank, merge_repeated):
2222

2323
class TestCTCDecodeOp(OpTest):
2424
def config(self):
25-
self.op_type = "ctc_greedy_decode"
25+
self.op_type = "ctc_decode"
2626
self.input_lod = [[0, 11, 18]]
2727
self.blank = 0
2828
self.merge_repeated = False
@@ -49,7 +49,7 @@ def test_check_output(self):
4949

5050
class TestCTCDecodeOpCase1(TestCTCDecodeOp):
5151
def config(self):
52-
self.op_type = "ctc_greedy_decode"
52+
self.op_type = "ctc_decode"
5353
self.input_lod = [[0, 11, 18]]
5454
self.blank = 0
5555
self.merge_repeated = True

0 commit comments

Comments
 (0)