
Commit b32c13d

Add cudnn ctc loss (#12366)
* add cudnn ctc loss
* wip add test test=develop
* wip
* wip
* done test=develop
* move include cudnn test=develop
* test test=develop
* fix build test=develop
* fix build test=develop
* fix build on cudnn5 test=develop
* fix cudnn5 build test=develop
* fix cudnn5 build test=develop
* merge develop softmax functor change test=develop
1 parent b984c70 commit b32c13d

File tree

8 files changed: +279 -8 lines changed


paddle/fluid/API.spec

Lines changed: 1 addition & 1 deletion
@@ -93,7 +93,7 @@ paddle.fluid.layers.edit_distance ArgSpec(args=['input', 'label', 'normalized',
paddle.fluid.layers.l2_normalize ArgSpec(args=['x', 'axis', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(1e-12, None))
paddle.fluid.layers.matmul ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y', 'alpha', 'name'], varargs=None, keywords=None, defaults=(False, False, 1.0, None))
paddle.fluid.layers.topk ArgSpec(args=['input', 'k', 'name'], varargs=None, keywords=None, defaults=(None,))
-paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_times'], varargs=None, keywords=None, defaults=(0, False))
+paddle.fluid.layers.warpctc ArgSpec(args=['input', 'label', 'blank', 'norm_by_times', 'use_cudnn'], varargs=None, keywords=None, defaults=(0, False, False))
paddle.fluid.layers.sequence_reshape ArgSpec(args=['input', 'new_dim'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))

paddle/fluid/operators/CMakeLists.txt

Lines changed: 8 additions & 1 deletion
@@ -300,7 +300,6 @@ if (NOT WIN32)
    op_library(gru_op DEPS sequence2batch gru_compute)
endif(NOT WIN32)
op_library(recurrent_op DEPS executor)
-op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
op_library(cos_sim_op DEPS cos_sim_functor)
op_library(parallel_do_op DEPS executor)
op_library(unsqueeze_op DEPS reshape_op)

@@ -331,6 +330,14 @@ op_library(load_combine_op DEPS lod_tensor)
op_library(concat_op DEPS concat_and_split)
op_library(tensor_array_to_tensor_op DEPS concat_op)

+set(DEPS_OPS ${DEPS_OPS} warpctc_op)
+if (WITH_GPU)
+    if (${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
+        op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc warpctc_op.cu.cc)
+    endif()
+endif()
+op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
+
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})

foreach(src ${GENERAL_OPS})
Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/operators/warpctc_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace operators {

#if CUDNN_VERSION >= 7001
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedCTCLossDescriptor = platform::ScopedCTCLossDescriptor;
using DataLayout = platform::DataLayout;

template <typename DeviceContext, typename T>
class CudnnCTCKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // =====================Copied code from warpctc===========================
    auto* logits = ctx.Input<LoDTensor>("Logits");
    auto* label = ctx.Input<LoDTensor>("Label");
    auto* warpctc_grad = ctx.Output<LoDTensor>("WarpCTCGrad");
    auto* loss = ctx.Output<LoDTensor>("Loss");

    const size_t level = 0;

    auto logits_lod = framework::ToAbsOffset(logits->lod());
    auto logits_dims = logits->dims();
    PADDLE_ENFORCE_EQ(logits_dims[0],
                      static_cast<int64_t>(logits_lod[level].back()),
                      "The first dimension of Input(Logits) should be equal to "
                      "the sum of all sequences' lengths.");

    auto label_lod = framework::ToAbsOffset(label->lod());
    auto label_dims = label->dims();
    PADDLE_ENFORCE_EQ(
        label_dims[0], label->numel(),
        "The width of each timestep in Input(Label) should be 1.");

    const size_t num_sequences = logits_lod[level].size() - 1;
    PADDLE_ENFORCE_EQ(num_sequences, label_lod[level].size() - 1,
                      "The number of sequences of Input(Logits) should be "
                      "equal to that of Input(Label).");
    PADDLE_ENFORCE_LE(num_sequences, 256,
                      "The labelLengths must less than 256 for cudnn call.");

    const size_t sequence_width = logits->numel() / logits_dims[0];
    auto loss_dims =
        framework::make_ddim({static_cast<int64_t>(num_sequences), 1});

    // NOTE: cudnn takes softmax input, calculate softmax first, then do padding
    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    LoDTensor softmax_logits;
    softmax_logits.mutable_data<T>(logits->dims(), ctx.GetPlace());
    softmax_logits.set_lod(logits_lod);
    int rank = logits->dims().size();
    Tensor in_2d = framework::ReshapeToMatrix(*logits, rank - 1);
    Tensor out_2d = framework::ReshapeToMatrix(softmax_logits, rank - 1);
    math::SoftmaxFunctor<DeviceContext, T, false>()(dev_ctx, &in_2d, &out_2d);

    // ctc needs sequences data stored in transposed padding format
    // logits and grad using padding data of layout 'TNC'
    // T: max_sequence_length
    // N: batch_size (num_sequences)
    // C: width
    LoDTensor warpctc_logits;
    const size_t max_sequence_length =
        math::MaximumSequenceLength(logits_lod[level]);
    auto warpctc_logits_dims =
        framework::make_ddim({static_cast<int64_t>(max_sequence_length),
                              static_cast<int64_t>(num_sequences),
                              static_cast<int64_t>(sequence_width)});
    warpctc_logits.mutable_data<T>(warpctc_logits_dims, ctx.GetPlace());

    LoDTensor cpu_pad_value;
    T* pad_value_data =
        cpu_pad_value.mutable_data<T>({1}, platform::CPUPlace());
    *pad_value_data = static_cast<T>(0);
    LoDTensor pad_value;
    if (platform::is_cpu_place(ctx.GetPlace())) {
      pad_value = cpu_pad_value;
    } else {
      TensorCopySync(cpu_pad_value, ctx.GetPlace(), &pad_value);
    }

    math::PaddingLoDTensorFunctor<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), softmax_logits,
        &warpctc_logits, pad_value, -1, 0, false /* norm_by_times */,
        math::kLengthBatchWidth);
    const T* warpctc_logits_data = warpctc_logits.data<T>();

    std::vector<int> warpctc_label_lengths(num_sequences);
    std::vector<int> warpctc_logits_lengths(num_sequences);

    for (size_t i = 0; i < num_sequences; ++i) {
      warpctc_label_lengths[i] = label_lod[level][i + 1] - label_lod[level][i];
      warpctc_logits_lengths[i] =
          logits_lod[level][i + 1] - logits_lod[level][i];
    }

    T* warpctc_grad_data =
        warpctc_grad->mutable_data<T>(warpctc_logits.dims(), ctx.GetPlace());

    math::SetConstant<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), warpctc_grad,
        static_cast<T>(0));

    Tensor warpctc_label;
    TensorCopySync(*label, platform::CPUPlace(), &warpctc_label);
    const int* warpctc_label_data = warpctc_label.data<int>();
    // ========================================================================

    ScopedTensorDescriptor logits_desc;
    ScopedTensorDescriptor grad_desc;
    ScopedCTCLossDescriptor ctcloss_desc;
    // layout here doesn't have effect.
    DataLayout layout = DataLayout::kNCHW;

    auto cu_logits_desc = logits_desc.descriptor<T>(
        layout, framework::vectorize2int(warpctc_logits.dims()));
    auto cu_grad_desc = grad_desc.descriptor<T>(
        layout, framework::vectorize2int(warpctc_grad->dims()));
    auto cu_ctcloss_desc = ctcloss_desc.descriptor<T>();

    auto handle = dev_ctx.cudnn_handle();
    size_t workspace_size;

    CUDNN_ENFORCE(platform::dynload::cudnnGetCTCLossWorkspaceSize(
        handle, cu_logits_desc, cu_grad_desc, warpctc_label_data,
        warpctc_label_lengths.data(), warpctc_logits_lengths.data(),
        CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, &workspace_size));

    T* loss_data = loss->mutable_data<T>(loss_dims, ctx.GetPlace());

    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
    auto cudnn_func = [&](void* cudnn_workspace) {
      CUDNN_ENFORCE(platform::dynload::cudnnCTCLoss(
          handle, cu_logits_desc, warpctc_logits_data, warpctc_label_data,
          warpctc_label_lengths.data(), warpctc_logits_lengths.data(),
          loss_data, cu_grad_desc, warpctc_grad_data,
          CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, cudnn_workspace,
          workspace_size));
    };
    workspace_handle.RunFunc(cudnn_func, workspace_size);
  }
};

template <typename DeviceContext, typename T>
class CudnnCTCGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* warpctc_grad = ctx.Input<LoDTensor>("WarpCTCGrad");
    auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
    const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));

    logits_grad->mutable_data<T>(ctx.GetPlace());
    bool norm_by_times = ctx.Attr<bool>("norm_by_times");
    math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), *warpctc_grad,
        logits_grad, -1, 0, norm_by_times, math::kLengthBatchWidth);

    const T* loss_grad_data = loss_grad->data<T>();
    math::ScaleLoDTensorFunctor<DeviceContext, T>()(
        ctx.template device_context<DeviceContext>(), loss_grad_data,
        logits_grad);
  }
};

#endif
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
#if CUDNN_VERSION >= 7001
REGISTER_OP_KERNEL(
    warpctc, CUDNN, plat::CUDAPlace,
    ops::CudnnCTCKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_KERNEL(
    warpctc_grad, CUDNN, plat::CUDAPlace,
    ops::CudnnCTCGradKernel<paddle::platform::CUDADeviceContext, float>);
#endif

paddle/fluid/operators/warpctc_op.cc

Lines changed: 16 additions & 1 deletion
@@ -14,6 +14,10 @@ limitations under the License. */

#include "paddle/fluid/operators/warpctc_op.h"

+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cudnn_helper.h"
+#endif
+
namespace paddle {
namespace operators {

@@ -45,9 +49,16 @@ class WarpCTCOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+#ifdef PADDLE_WITH_CUDA
+    if (platform::CanCUDNNBeUsed(ctx)) {
+      library_ = framework::LibraryType::kCUDNN;
+    }
+#endif
+    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
-        ctx.device_context());
+        ctx.device_context(), layout_, library_);
  }
};

@@ -86,6 +97,10 @@ class WarpCTCOpMaker : public framework::OpProtoAndCheckerMaker {
                  "normalize the gradients by the number of time-step, "
                  "which is also the sequence's length.")
        .SetDefault(false);
+    AddAttr<bool>("use_cudnn",
+                  "(bool, default: false), whether to "
+                  "use cudnn kernel.")
+        .SetDefault(false);
    AddComment(R"DOC(
An operator integrating the open-source
[warp-ctc](https://github.com/baidu-research/warp-ctc) library, which is used in

paddle/fluid/platform/cudnn_helper.h

Lines changed: 23 additions & 0 deletions
@@ -380,5 +380,28 @@ inline bool CanCUDNNBeUsed(const framework::ExecutionContext& ctx) {
  return use_cudnn;
}

+#if CUDNN_VERSION >= 7001
+class ScopedCTCLossDescriptor {
+ public:
+  ScopedCTCLossDescriptor() {
+    PADDLE_ENFORCE(dynload::cudnnCreateCTCLossDescriptor(&desc_));
+  }
+  ~ScopedCTCLossDescriptor() {
+    PADDLE_ENFORCE(dynload::cudnnDestroyCTCLossDescriptor(desc_));
+  }
+
+  template <typename T>
+  inline cudnnCTCLossDescriptor_t descriptor() {
+    PADDLE_ENFORCE(
+        dynload::cudnnSetCTCLossDescriptor(desc_, CudnnDataType<T>::type));
+    return desc_;
+  }
+
+ private:
+  cudnnCTCLossDescriptor_t desc_;
+  DISABLE_COPY_AND_ASSIGN(ScopedCTCLossDescriptor);
+};
+#endif
+
}  // namespace platform
}  // namespace paddle

paddle/fluid/platform/dynload/cudnn.h

Lines changed: 7 additions & 1 deletion
@@ -154,7 +154,13 @@ CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#if CUDNN_VERSION >= 7001
#define CUDNN_DNN_ROUTINE_EACH_R7(__macro) \
  __macro(cudnnSetConvolutionGroupCount);  \
-  __macro(cudnnSetConvolutionMathType);
+  __macro(cudnnSetConvolutionMathType);    \
+  __macro(cudnnCreateCTCLossDescriptor);   \
+  __macro(cudnnDestroyCTCLossDescriptor);  \
+  __macro(cudnnGetCTCLossDescriptor);      \
+  __macro(cudnnSetCTCLossDescriptor);      \
+  __macro(cudnnGetCTCLossWorkspaceSize);   \
+  __macro(cudnnCTCLoss);
CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif

python/paddle/fluid/layers/nn.py

Lines changed: 7 additions & 3 deletions
@@ -4187,7 +4187,7 @@ def ctc_greedy_decoder(input, blank, name=None):
    return ctc_out


-def warpctc(input, label, blank=0, norm_by_times=False):
+def warpctc(input, label, blank=0, norm_by_times=False, use_cudnn=False):
    """
    An operator integrating the open source Warp-CTC library
    (https://github.com/baidu-research/warp-ctc)

@@ -4212,6 +4212,7 @@ def warpctc(input, label, blank=0, norm_by_times=False):
           by the number of time-step, which is also the sequence's length.
           There is no need to normalize the gradients if warpctc layer was
           follewed by a mean_op.
+        use_cudnn (bool, default false): Whether to use cudnn.

    Returns:
        Variable: The Connectionist Temporal Classification (CTC) loss,

@@ -4235,8 +4236,11 @@ def warpctc(input, label, blank=0, norm_by_times=False):
                'Label': [label]},
        outputs={'WarpCTCGrad': [grad_out],
                 'Loss': [loss_out]},
-        attrs={'blank': blank,
-               'norm_by_times': norm_by_times})
+        attrs={
+            'blank': blank,
+            'norm_by_times': norm_by_times,
+            'use_cudnn': use_cudnn
+        })
return loss_out
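
For reference, a minimal usage sketch of the extended Python API (not part of the commit): the only user-visible change is the new use_cudnn flag. It assumes a CUDA build of PaddlePaddle with cuDNN >= 7; tensor shapes and variable names here are illustrative only.

import paddle.fluid as fluid

# Sequence inputs: logits hold one row of class scores per time step
# (including the blank class), labels hold one integer id per label position.
logits = fluid.layers.data(
    name='logits', shape=[8], dtype='float32', lod_level=1)
label = fluid.layers.data(
    name='label', shape=[1], dtype='int32', lod_level=1)

# New in this commit: use_cudnn=True selects the cuDNN CTC loss kernel;
# the default (False) keeps the original warp-ctc path.
cost = fluid.layers.warpctc(
    input=logits, label=label, blank=0, norm_by_times=False, use_cudnn=True)
avg_cost = fluid.layers.mean(cost)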
python/paddle/fluid/tests/unittests/test_warpctc_op.py

Lines changed: 22 additions & 1 deletion
@@ -183,6 +183,7 @@ def config(self):
        self.labels_lod = [[3, 1, 4, 4]]
        self.blank = self.num_classes - 1
        self.norm_by_times = False
+        self.use_cudnn = False

    def setUp(self):
        self.op_type = "warpctc"

@@ -215,7 +216,11 @@ def setUp(self):
            "Label": (labels, self.labels_lod)
        }
        self.outputs = {"Loss": loss}
-        self.attrs = {"blank": self.blank, "norm_by_times": self.norm_by_times}
+        self.attrs = {
+            "blank": self.blank,
+            "norm_by_times": self.norm_by_times,
+            "use_cudnn": self.use_cudnn
+        }

    def test_check_output(self):
        self.check_output()

@@ -233,6 +238,22 @@ def config(self):
        self.labels_lod = [[3, 1, 4, 4]]
        self.blank = 0
        self.norm_by_times = False
+        self.use_cudnn = False
+
+
+class TestCudnnCTCOp(TestWarpCTCOp):
+    def config(self):
+        self.batch_size = 4
+        self.num_classes = 8
+        self.logits_lod = [[4, 1, 3, 3]]
+        self.labels_lod = [[3, 1, 4, 4]]
+        self.blank = 0
+        self.norm_by_times = False
+        self.use_cudnn = True
+
+    def test_check_grad(self):
+        self.outputs['WarpCTCGrad'] = self.gradient
+        self.check_grad(["Logits"], "Loss", max_relative_error=0.01)

if __name__ == "__main__":
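
A short sketch of how the new case could be exercised on its own (not part of the commit; it assumes the test file is importable from the working directory and that PaddlePaddle was built with CUDA and cuDNN >= 7):

import unittest

from test_warpctc_op import TestCudnnCTCOp  # the test case added above

if __name__ == "__main__":
    # Load and run only the cuDNN-backed warpctc test case.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestCudnnCTCOp)
    unittest.TextTestRunner(verbosity=2).run(suite)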
