Commit 128adf5

[Speed] Implement cudnn sequence softmax (#8978)
* "add softmax cudnn functor support" * "add testing" * "refine cmakelist" * "sequence softmax forward speed up" * "add softmax grad" * "fix sequence softmax test" * "add double precision' * "fix softmax test" * "add softmax cudnn support" * "fix softmax cudnn test" * "add softmax to nn.py" * "fix compile bug" * "refine cmakelist" * "fix ci" * "fix based on comment" * "fix based on comments" * "fix ci"
1 parent 9b9f3f0 commit 128adf5

File tree

14 files changed: +481 additions, -19 deletions

paddle/fluid/operators/math/softmax.cu

Lines changed: 73 additions & 0 deletions
@@ -14,13 +14,86 @@ limitations under the License. */

 #define EIGEN_USE_GPU

+#include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/operators/math/softmax.h"
 #include "paddle/fluid/operators/math/softmax_impl.h"
+#include "paddle/fluid/platform/cudnn_helper.h"

 namespace paddle {
 namespace operators {
 namespace math {

+using Tensor = framework::Tensor;
+using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
+using DataLayout = platform::DataLayout;
+template <typename T>
+using CudnnDataType = platform::CudnnDataType<T>;
+
+template <typename T>
+void SoftmaxCUDNNFunctor<T>::operator()(
+    const platform::CUDADeviceContext& context, const framework::Tensor* X,
+    framework::Tensor* Y) {
+  // ------------------- cudnn descriptors ---------------------
+  ScopedTensorDescriptor xDesc;
+  ScopedTensorDescriptor yDesc;
+  std::vector<int> cudnn_tensor_dims = framework::vectorize2int(X->dims());
+  DataLayout layout = DataLayout::kNCHW;
+  if (cudnn_tensor_dims.size() == 5) {
+    layout = DataLayout::kNCDHW;
+  }
+  // NOTE(*): cudnn softmax only supports tensors of rank >= 4,
+  // so fill unused dims with 1.
+  if (cudnn_tensor_dims.size() <= 2) {
+    cudnn_tensor_dims.resize(4, 1);
+  }
+  cudnnTensorDescriptor_t cudnn_x_desc =
+      xDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  cudnnTensorDescriptor_t cudnn_y_desc =
+      yDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  PADDLE_ENFORCE(platform::dynload::cudnnSoftmaxForward(
+      context.cudnn_handle(), CUDNN_SOFTMAX_ACCURATE,
+      CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType<T>::kOne(), cudnn_x_desc,
+      X->data<T>(), CudnnDataType<T>::kZero(), cudnn_y_desc,
+      Y->mutable_data<T>(context.GetPlace())));
+}
+
+template <typename T>
+void SoftmaxGradCUDNNFunctor<T>::operator()(
+    const platform::CUDADeviceContext& context, const framework::Tensor* Y,
+    const framework::Tensor* YGrad, framework::Tensor* XGrad) {
+  // ------------------- cudnn descriptors ---------------------
+  ScopedTensorDescriptor yDesc;
+  ScopedTensorDescriptor dyDesc;
+  ScopedTensorDescriptor dxDesc;
+  std::vector<int> cudnn_tensor_dims = framework::vectorize2int(Y->dims());
+  DataLayout layout = DataLayout::kNCHW;
+  if (cudnn_tensor_dims.size() == 5) {
+    layout = DataLayout::kNCDHW;
+  }
+  // NOTE(*): cudnn softmax only supports tensors of rank >= 4,
+  // so fill unused dims with 1.
+  if (cudnn_tensor_dims.size() <= 2) {
+    cudnn_tensor_dims.resize(4, 1);
+  }
+  cudnnTensorDescriptor_t cudnn_y_desc =
+      yDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  cudnnTensorDescriptor_t cudnn_xgrad_desc =
+      dxDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  cudnnTensorDescriptor_t cudnn_ygrad_desc =
+      dyDesc.descriptor<T>(layout, cudnn_tensor_dims);
+  PADDLE_ENFORCE(platform::dynload::cudnnSoftmaxBackward(
+      context.cudnn_handle(), CUDNN_SOFTMAX_ACCURATE,
+      CUDNN_SOFTMAX_MODE_INSTANCE, CudnnDataType<T>::kOne(), cudnn_y_desc,
+      Y->data<T>(), cudnn_ygrad_desc, YGrad->data<T>(),
+      CudnnDataType<T>::kZero(), cudnn_xgrad_desc,
+      XGrad->mutable_data<T>(context.GetPlace())));
+}
+
+template class SoftmaxCUDNNFunctor<float>;
+template class SoftmaxCUDNNFunctor<double>;
+template class SoftmaxGradCUDNNFunctor<float>;
+template class SoftmaxGradCUDNNFunctor<double>;
+
 template class SoftmaxFunctor<platform::CUDADeviceContext, float>;
 template class SoftmaxFunctor<platform::CUDADeviceContext, double>;
 template class SoftmaxGradFunctor<platform::CUDADeviceContext, float>;
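Note on the hunk above: cuDNN's softmax descriptors require tensors of rank 4 (or 5 for NCDHW), so the functor pads 1-D and 2-D shapes with trailing 1s before building the descriptors. A minimal standalone sketch of just that shape handling (PadToCudnnRank is a hypothetical name, not part of the commit):

#include <cstdio>
#include <vector>

// Pad a rank <= 2 shape to rank 4 the way SoftmaxCUDNNFunctor does:
// std::vector::resize(4, 1) appends 1s for the missing trailing dims.
std::vector<int> PadToCudnnRank(std::vector<int> dims) {
  if (dims.size() <= 2) {
    dims.resize(4, 1);
  }
  return dims;
}

int main() {
  for (int d : PadToCudnnRank({8, 10})) std::printf("%d ", d);
  // Prints: 8 10 1 1  -- i.e. {N, C} becomes {N, C, 1, 1}, so
  // CUDNN_SOFTMAX_MODE_INSTANCE computes one softmax per row over C.
  return 0;
}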

paddle/fluid/operators/math/softmax.h

Lines changed: 17 additions & 0 deletions
@@ -33,6 +33,23 @@ class SoftmaxGradFunctor {
                   const framework::Tensor* y_grad, framework::Tensor* x_grad);
 };

+#ifdef PADDLE_WITH_CUDA
+template <typename T>
+class SoftmaxCUDNNFunctor {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor* X, framework::Tensor* Y);
+};
+
+template <typename T>
+class SoftmaxGradCUDNNFunctor {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor* Y, const framework::Tensor* y_grad,
+                  framework::Tensor* x_grad);
+};
+#endif
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
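Because these declarations sit behind PADDLE_WITH_CUDA, call sites need the same guard. A hedged sketch of how a caller might invoke the forward functor, assuming a CUDA build and tensors already shaped and allocated (RunCudnnSoftmax is a hypothetical helper; the actual call sites are the kernels added later in this commit):

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/math/softmax.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {

void RunCudnnSoftmax(const platform::CUDADeviceContext& ctx,
                     const framework::Tensor* x, framework::Tensor* y) {
  // Forward pass: y = softmax(x), computed by cuDNN on the GPU.
  math::SoftmaxCUDNNFunctor<float>()(ctx, x, y);
}

}  // namespace operators
}  // namespace paddle
#endif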
paddle/fluid/operators/sequence_softmax_cudnn_op.cu.cc

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/softmax.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+
+template <typename T>
+class SequenceSoftmaxCUDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* out = ctx.Output<LoDTensor>("Out");
+
+    auto lod = x->lod();
+    auto dims = x->dims();
+
+    const size_t level = lod.size() - 1;
+    PADDLE_ENFORCE_EQ(dims[0], static_cast<int64_t>(lod[level].back()),
+                      "The first dimension of Input(X) should be equal to the "
+                      "sum of all sequences' lengths.");
+    PADDLE_ENFORCE_EQ(dims[0], x->numel(),
+                      "The width of each timestep in Input(X) of "
+                      "SequenceSoftmaxOp should be 1.");
+
+    out->mutable_data<T>(ctx.GetPlace());
+    for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
+      int start_pos = static_cast<int>(lod[level][i]);
+      int end_pos = static_cast<int>(lod[level][i + 1]);
+      Tensor x_i = x->Slice(start_pos, end_pos);
+      Tensor out_i = out->Slice(start_pos, end_pos);
+
+      // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
+      framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
+      x_i.Resize(dims_i);
+      out_i.Resize(dims_i);
+      math::SoftmaxCUDNNFunctor<T>()(
+          ctx.template device_context<platform::CUDADeviceContext>(), &x_i,
+          &out_i);
+    }
+  }
+};
+
+template <typename T>
+class SequenceSoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* out = ctx.Input<LoDTensor>("Out");
+    auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
+
+    auto lod = x->lod();
+    const size_t level = lod.size() - 1;
+
+    x_grad->mutable_data<T>(ctx.GetPlace());
+    for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
+      int start_pos = static_cast<int>(lod[level][i]);
+      int end_pos = static_cast<int>(lod[level][i + 1]);
+
+      Tensor out_i = out->Slice(start_pos, end_pos);
+      Tensor out_grad_i = out_grad->Slice(start_pos, end_pos);
+      Tensor x_grad_i = x_grad->Slice(start_pos, end_pos);
+
+      // Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
+      framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
+      out_i.Resize(dims_i);
+      out_grad_i.Resize(dims_i);
+      x_grad_i.Resize(dims_i);
+      math::SoftmaxGradCUDNNFunctor<T>()(
+          ctx.template device_context<platform::CUDADeviceContext>(), &out_i,
+          &out_grad_i, &x_grad_i);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_KERNEL(sequence_softmax, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::SequenceSoftmaxCUDNNKernel<float>,
+                   ops::SequenceSoftmaxCUDNNKernel<double>)
+REGISTER_OP_KERNEL(sequence_softmax_grad, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::SequenceSoftmaxGradCUDNNKernel<float>,
+                   ops::SequenceSoftmaxGradCUDNNKernel<double>)
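The per-sequence loop in these kernels is driven entirely by the LoD offsets: lod[level] stores cumulative row offsets, so sequence i occupies rows [lod[level][i], lod[level][i+1]) and is reshaped to 1 x len so that cuDNN treats the whole sequence as a single softmax instance. A standalone sketch of that bookkeeping, with a plain std::vector standing in for the LoD type:

#include <cstdio>
#include <vector>

int main() {
  // Cumulative offsets for three sequences of lengths 2, 3 and 4.
  std::vector<int> lod = {0, 2, 5, 9};
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    int start_pos = lod[i];
    int end_pos = lod[i + 1];
    // The kernel slices rows [start_pos, end_pos) and reshapes the slice
    // to {1, end_pos - start_pos} before calling the cuDNN functor.
    std::printf("sequence %zu: rows [%d, %d), softmax over %d values\n",
                i, start_pos, end_pos, end_pos - start_pos);
  }
  return 0;
}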

paddle/fluid/operators/sequence_softmax_op.cc

Lines changed: 61 additions & 2 deletions
@@ -29,6 +29,29 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    // Choose the cudnn kernel if the runtime supports it.
+    bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+    bool runtime_cudnn_support = false;
+#ifdef PADDLE_WITH_CUDA
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      auto& dev_ctx =
+          ctx.template device_context<platform::CUDADeviceContext>();
+      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr;
+    }
+#endif
+    framework::LibraryType library_ = framework::LibraryType::kPlain;
+    if (use_cudnn && runtime_cudnn_support) {
+      library_ = framework::LibraryType::kCUDNN;
+    }
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+        framework::StringToDataLayout(data_format), library_);
+  }
 };

 class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {

@@ -41,6 +64,17 @@ class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("Out",
               "(LoDTensor) 1-D or 2-D output LoDTensor with the 2-nd dimension "
               "of length 1.");
+    AddAttr<bool>(
+        "use_cudnn",
+        "(bool, default false) Only used by the cudnn kernel; requires cudnn "
+        "to be installed.")
+        .SetDefault(false);
+    AddAttr<std::string>(
+        "data_format",
+        "(string, default AnyLayout) An optional string from: \"NHWC\", "
+        "\"NCHW\". Specify the data format of the output data; the input "
+        "will be transformed automatically.")
+        .SetDefault("AnyLayout");
     AddComment(R"DOC(
 Sequence Softmax Operator.

@@ -91,6 +125,29 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {

     ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    // Choose the cudnn kernel if the runtime supports it.
+    bool use_cudnn = ctx.Attr<bool>("use_cudnn");
+    bool runtime_cudnn_support = false;
+#ifdef PADDLE_WITH_CUDA
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      auto& dev_ctx =
+          ctx.template device_context<platform::CUDADeviceContext>();
+      runtime_cudnn_support = dev_ctx.cudnn_handle() != nullptr;
+    }
+#endif
+    framework::LibraryType library_ = framework::LibraryType::kPlain;
+    if (use_cudnn && runtime_cudnn_support) {
+      library_ = framework::LibraryType::kCUDNN;
+    }
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("X")->type()), ctx.GetPlace(),
+        framework::StringToDataLayout(data_format), library_);
+  }
 };

 }  // namespace operators

@@ -102,7 +159,9 @@ REGISTER_OP(sequence_softmax, ops::SequenceSoftmaxOp,
             ops::SequenceSoftmaxGradOp);
 REGISTER_OP_CPU_KERNEL(
     sequence_softmax,
-    ops::SequenceSoftmaxKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SequenceSoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceSoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     sequence_softmax_grad,
-    ops::SequenceSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SequenceSoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceSoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
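Both GetExpectedKernelType overrides above reduce to one rule: dispatch to the cuDNN kernel only when the use_cudnn attribute is set and a cuDNN handle actually exists at runtime. The same decision distilled into a standalone predicate (hypothetical names, for illustration only):

enum class LibraryType { kPlain, kCUDNN };

// use_cudnn comes from the op attribute; runtime_cudnn_support is true only
// when the CUDA device context exposes a non-null cudnn_handle().
LibraryType ChooseLibrary(bool use_cudnn, bool runtime_cudnn_support) {
  return (use_cudnn && runtime_cudnn_support) ? LibraryType::kCUDNN
                                              : LibraryType::kPlain;
}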

paddle/fluid/operators/sequence_softmax_op.cu.cc

Lines changed: 5 additions & 2 deletions
@@ -17,7 +17,10 @@ limitations under the License. */
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     sequence_softmax,
-    ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>)
+    ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceSoftmaxKernel<paddle::platform::CUDADeviceContext, double>)
 REGISTER_OP_CUDA_KERNEL(
     sequence_softmax_grad,
-    ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceSoftmaxGradKernel<paddle::platform::CUDADeviceContext,
+                                   double>);
paddle/fluid/operators/softmax_cudnn_op.cu.cc

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/math/softmax.h"
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* X = context.Input<Tensor>("X");
+    auto* Out = context.Output<Tensor>("Out");
+
+    // allocate memory on device.
+    Out->mutable_data<T>(context.GetPlace());
+
+    math::SoftmaxCUDNNFunctor<T>()(
+        context.template device_context<platform::CUDADeviceContext>(), X,
+        Out);
+  }
+};
+
+template <typename T>
+class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* Out = context.Input<Tensor>("Out");
+    auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
+
+    // allocate memory on device.
+    dX->mutable_data<T>(context.GetPlace());
+
+    math::SoftmaxGradCUDNNFunctor<T>()(
+        context.template device_context<platform::CUDADeviceContext>(), Out,
+        dOut, dX);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_KERNEL(softmax, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::SoftmaxCUDNNKernel<float>);
+REGISTER_OP_KERNEL(softmax_grad, CUDNN, ::paddle::platform::CUDAPlace,
+                   ops::SoftmaxGradCUDNNKernel<float>);
