Commit fd7e643

Convolution fusion operator. (#14449)

* Convolution fusion operator.
* Clean code test=develop

1 parent d7bd036 commit fd7e643

File tree

11 files changed: +530, -41 lines


cmake/operators.cmake

Lines changed: 1 addition & 1 deletion

@@ -111,7 +111,7 @@ function(op_library TARGET)
 
   # Define operators that don't need pybind here.
   foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
-      "tensor_array_read_write_op" "tensorrt_engine_op")
+      "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op")
     if ("${TARGET}" STREQUAL "${manual_pybind_op}")
       set(pybind_flag 1)
     endif()

paddle/fluid/operators/CMakeLists.txt

Lines changed: 3 additions & 1 deletion

@@ -34,7 +34,7 @@ if (WITH_GPU AND TENSORRT_FOUND)
   add_subdirectory(tensorrt)
 endif()
 
-register_operators(EXCLUDES warpctc_op)
+register_operators(EXCLUDES warpctc_op conv_fusion_op)
 
 # warpctc_cudnn need cudnn 7 above
 if (WITH_GPU)
@@ -43,6 +43,8 @@ if (WITH_GPU)
   else()
     op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
   endif()
+  op_library(conv_fusion_op)
+  file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(conv2d_fusion);\n")
 else()
   op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
 endif()
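
The file(APPEND ...) line above writes a USE_CUDA_ONLY_OP(conv2d_fusion) declaration into the generated pybind source. A minimal sketch of the "touch the registrar" linker idiom that such use-op macros rely on follows; the symbol names here are illustrative assumptions, not Paddle's exact macro expansion.

// Defined in the op's object file as a side effect of REGISTER_OPERATOR /
// REGISTER_OP_CUDA_KERNEL (illustrative name):
int TouchOpRegistrar_conv2d_fusion() { return 0; }

// What the generated pybind line conceptually declares: referencing the
// symbol forces the linker to keep the op's object file, so its static
// registration runs even though nothing calls the op by name.
extern int TouchOpRegistrar_conv2d_fusion();
static int use_conv2d_fusion = TouchOpRegistrar_conv2d_fusion();

int main() { return use_conv2d_fusion; }

Without such a reference, a static library's unused object files are dropped at link time and the operator would never appear in the registry.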

paddle/fluid/operators/conv_cudnn_op.cu.cc

Lines changed: 0 additions & 20 deletions

@@ -43,26 +43,6 @@ using DataLayout = platform::DataLayout;
 template <typename T>
 using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
 
-static constexpr char kCUDNNFwdAlgoCache[] = "kCUDNNFwdAlgoCache";
-static constexpr char kCUDNNBwdDataAlgoCache[] = "kCUDNNBwdDataAlgoCache";
-static constexpr char kCUDNNBwdFilterAlgoCache[] = "kCUDNNBwdFilterAlgoCache";
-
-static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
-    static_cast<size_t>(1024) * 1024 * 1024;
-
-#if CUDNN_VERSION_MIN(6, 0, 5)
-static constexpr size_t kNUM_CUDNN_FWD_ALGS = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
-static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS =
-    CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
-static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS =
-    CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
-#else
-// cuDNN v5 has no CUDNN_CONVOLUTION_FWD_ALGO_COUNT etc.
-static constexpr size_t kNUM_CUDNN_FWD_ALGS = 7;
-static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 4;
-static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5;
-#endif
-
 template <typename T>
 class CUDNNConvOpKernel : public framework::OpKernel<T> {
  public:

paddle/fluid/operators/conv_cudnn_op_cache.h

Lines changed: 21 additions & 0 deletions

@@ -17,10 +17,31 @@ limitations under the License. */
 #include <functional>
 #include <unordered_map>
 #include <vector>
+#include "paddle/fluid/platform/cudnn_helper.h"
 
 namespace paddle {
 namespace operators {
 
+static constexpr char kCUDNNFwdAlgoCache[] = "kCUDNNFwdAlgoCache";
+static constexpr char kCUDNNBwdDataAlgoCache[] = "kCUDNNBwdDataAlgoCache";
+static constexpr char kCUDNNBwdFilterAlgoCache[] = "kCUDNNBwdFilterAlgoCache";
+
+static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
+    static_cast<size_t>(1024) * 1024 * 1024;
+
+#if CUDNN_VERSION_MIN(6, 0, 5)
+static constexpr size_t kNUM_CUDNN_FWD_ALGS = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
+static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS =
+    CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT;
+static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS =
+    CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
+#else
+// cuDNN v5 has no CUDNN_CONVOLUTION_FWD_ALGO_COUNT etc.
+static constexpr size_t kNUM_CUDNN_FWD_ALGS = 7;
+static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 4;
+static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5;
+#endif
+
 template <typename TAlgorithm>
 class AlgorithmsCache {
  public:
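
The AlgorithmsCache declared here is what the fused kernel later calls as GetAlgorithm(x_dims, f_dims, strides, paddings, dilations, flag, search_fn). A minimal self-contained sketch of that contract follows; the hashing scheme and member names are illustrative assumptions, not the exact Paddle implementation.

#include <cstdint>
#include <cstdio>
#include <functional>
#include <unordered_map>
#include <vector>

template <typename TAlgorithm>
class AlgorithmsCacheSketch {
 public:
  TAlgorithm GetAlgorithm(const std::vector<int64_t>& x_dims,
                          const std::vector<int64_t>& f_dims,
                          const std::vector<int>& strides,
                          const std::vector<int>& paddings,
                          const std::vector<int>& dilations, int flag,
                          const std::function<TAlgorithm()>& search) {
    // Fold every shape/stride/padding/dilation value plus the flag into one
    // key with a simple polynomial hash (illustrative only).
    int64_t key = flag;
    auto mix = [&key](int64_t v) { key = key * 131 + v + 1; };
    for (auto d : x_dims) mix(d);
    for (auto d : f_dims) mix(d);
    for (auto s : strides) mix(s);
    for (auto p : paddings) mix(p);
    for (auto d : dilations) mix(d);
    auto it = hash_.find(key);
    if (it != hash_.end()) return it->second;  // hit: reuse cached algorithm
    TAlgorithm algo = search();                // miss: run the expensive search
    hash_[key] = algo;
    return algo;
  }

 private:
  std::unordered_map<int64_t, TAlgorithm> hash_;
};

int main() {
  AlgorithmsCacheSketch<int> cache;
  auto search = []() { std::puts("searching..."); return 42; };
  std::vector<int64_t> x{1, 3, 224, 224}, f{64, 3, 7, 7};
  std::vector<int> s{2, 2}, p{3, 3}, d{1, 1};
  int a1 = cache.GetAlgorithm(x, f, s, p, d, 0, search);  // prints once
  int a2 = cache.GetAlgorithm(x, f, s, p, d, 0, search);  // cache hit, silent
  return a1 == a2 ? 0 : 1;
}

The point of keying on shapes is that cudnnFindConvolutionForwardAlgorithmEx is expensive, so the exhaustive search should run once per distinct input/filter configuration, not once per batch.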
paddle/fluid/operators/conv_fusion_op.cc (new file)

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/operators/conv_op.h"
+#ifdef PADDLE_WITH_CUDA
+#include "paddle/fluid/platform/cudnn_helper.h"
+#endif
+
+namespace paddle {
+namespace operators {
+
+// This fused conv follows the equation:
+//   y = act ( alpha1 * conv(x) + alpha2 * z + bias ).
+// here, y is Output,
+//       x is Input,
+//       z is ResidualData,
+//       bias is Bias
+class Conv2DFusionOpMaker : public Conv2DOpMaker {
+ protected:
+  void Apply() override {
+    AddAttr<std::string>(
+        "activation",
+        "The activation type can be 'identity', 'sigmoid', 'relu', 'relu6' "
+        "'relux' , 'tanh', 'band_pass'")
+        .SetDefault("relu");
+  }
+};
+// TODO(qingqing): add gradient operator for conv2d_fusion
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(conv2d_fusion, ops::ConvOp, ops::Conv2DFusionOpMaker,
+                  ops::ConvOpInferVarType, paddle::framework::EmptyGradOpMaker);
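
To make the documented equation concrete, here is a toy scalar walkthrough of y = act(alpha1 * conv(x) + alpha2 * z + bias) for a single output element. This is a numerical illustration only; the real kernel evaluates it per element on the GPU through a single cuDNN call.

#include <algorithm>
#include <cstdio>

int main() {
  float conv_x = 2.5f;  // one element of conv(x), already computed
  float z = -1.0f;      // matching element of ResidualData
  float bias = 0.25f;
  float alpha1 = 1.0f;
  float alpha2 = 1.0f;  // becomes 0.0f when no ResidualData is given
  // activation = "relu" (the attribute's default)
  float y = std::max(0.0f, alpha1 * conv_x + alpha2 * z + bias);
  std::printf("fused output element: %f\n", y);  // 1.75
  return 0;
}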
paddle/fluid/operators/conv_fusion_op.cu.cc (new file)

Lines changed: 187 additions & 0 deletions

@@ -0,0 +1,187 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
+#include "paddle/fluid/platform/cudnn_helper.h"
+
+DECLARE_uint64(conv_workspace_size_limit);
+DECLARE_bool(cudnn_exhaustive_search);
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
+using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
+using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
+using ScopedActivationDescriptor = platform::ScopedActivationDescriptor;
+using DataLayout = platform::DataLayout;
+template <typename T>
+using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
+
+template <typename T>
+class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto* input = ctx.Input<Tensor>("Input");
+    auto* filter = ctx.Input<Tensor>("Filter");
+    auto* bias = ctx.Input<Tensor>("Bias");
+    PADDLE_ENFORCE(bias, "The bias should not be null.");
+    auto* residual = ctx.Input<Tensor>("ResidualData");
+    auto* output = ctx.Output<Tensor>("Output");
+
+    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
+    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
+    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+    const std::string activation = ctx.Attr<std::string>("activation");
+    int groups = ctx.Attr<int>("groups");
+    int64_t user_workspace_size =
+        static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
+    bool exhaustive_search =
+        FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
+
+    const T* input_data = input->data<T>();
+    const T* filter_data = filter->data<T>();
+    const T* bias_data = bias->data<T>();
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+    const T* residual_data = residual ? residual->data<T>() : output_data;
+
+    // ------------------- cudnn descriptors ---------------------
+    ScopedTensorDescriptor input_desc;
+    ScopedTensorDescriptor output_desc;
+    ScopedFilterDescriptor filter_desc;
+    ScopedTensorDescriptor bias_desc;
+    ScopedConvolutionDescriptor conv_desc;
+    ScopedActivationDescriptor act_desc;
+    DataLayout layout = DataLayout::kNCHW;
+    if (input->dims().size() == 5) {
+      layout = DataLayout::kNCDHW;
+    }
+
+    cudnnConvolutionDescriptor_t cudnn_conv_desc =
+        conv_desc.descriptor<T>(paddings, strides, dilations);
+    CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
+        cudnn_conv_desc, groups));
+
+    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
+        layout, framework::vectorize2int(input->dims()));
+    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
+        layout, framework::vectorize2int(output->dims()));
+    cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
+        layout, framework::vectorize2int(filter->dims()));
+    // Now only support NCHW
+    std::vector<int> bias_dim = {1, static_cast<int>(output->dims()[1]), 1, 1};
+    cudnnTensorDescriptor_t cudnn_bias_desc =
+        bias_desc.descriptor<T>(layout, bias_dim);
+    cudnnActivationDescriptor_t cudnn_act_desc =
+        act_desc.descriptor<T>(activation);
+
+    // ------------------- cudnn conv workspace ---------------------
+    size_t workspace_size_in_bytes;  // final workspace to allocate.
+    size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES;
+    if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) {
+      int64_t max_user_size =
+          std::max(static_cast<int64_t>(FLAGS_conv_workspace_size_limit),
+                   user_workspace_size);
+      workspace_size_limit = max_user_size * 1024 * 1024;
+    }
+
+    // ------------------- cudnn conv algorithm ---------------------
+    cudnnConvolutionFwdAlgo_t algo;
+    auto handle = dev_ctx.cudnn_handle();
+    auto workspace_handle = dev_ctx.cudnn_workspace_handle();
+
+    CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
+        cudnn_conv_desc, CUDNN_DEFAULT_MATH));
+
+    auto x_dims = framework::vectorize(input->dims());
+    auto f_dims = framework::vectorize(filter->dims());
+    if (activation == "identity") {
+      // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
+      // enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib.
+      algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
+    } else if (!exhaustive_search) {
+      CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
+          handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
+          cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+          workspace_size_limit, &algo));
+      VLOG(3) << "cuDNN forward algo " << algo;
+    } else {
+      AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* algo_cache = nullptr;
+      if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) {
+        algo_cache =
+            ctx.scope()
+                .FindVar(kCUDNNFwdAlgoCache)
+                ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
+      } else {
+        algo_cache =
+            const_cast<framework::Scope&>(ctx.scope())
+                .Var(kCUDNNFwdAlgoCache)
+                ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>();
+      }
+      algo = algo_cache->GetAlgorithm(
+          x_dims, f_dims, strides, paddings, dilations, 0, [&]() {
+            int returned_algo_count;
+            std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
+                fwd_perf_stat;
+            auto cudnn_find_func = [&](void* cudnn_workspace) {
+              CUDNN_ENFORCE(
+                  platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
+                      handle, cudnn_input_desc, input_data, cudnn_filter_desc,
+                      filter_data, cudnn_conv_desc, cudnn_output_desc,
+                      output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
+                      fwd_perf_stat.data(), cudnn_workspace,
+                      workspace_size_limit));
+            };
+            workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit);
+            VLOG(3) << "Perf result: (algo: stat, time, memory)";
+            for (int i = 0; i < returned_algo_count; ++i) {
+              const auto& stat = fwd_perf_stat[i];
+              VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time
+                      << " " << stat.memory;
+            }
+            return fwd_perf_stat[0].algo;
+          });
+      VLOG(3) << "choose algo " << algo;
+    }
+
+    CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
+        handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
+        cudnn_output_desc, algo, &workspace_size_in_bytes));
+    PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
+                      "workspace_size to be allocated exceeds the limit");
+
+    // ------------------- cudnn conv+bias+act forward --------------------
+    ScalingParamType<T> alpha1 = 1.0f;
+    ScalingParamType<T> alpha2 = residual ? 1.0f : 0.0f;
+    auto cudnn_func = [&](void* cudnn_workspace) {
+      CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward(
+          handle, &alpha1, cudnn_input_desc, input_data, cudnn_filter_desc,
+          filter_data, cudnn_conv_desc, algo, cudnn_workspace,
+          workspace_size_in_bytes, &alpha2, cudnn_output_desc, residual_data,
+          cudnn_bias_desc, bias_data, cudnn_act_desc, cudnn_output_desc,
+          output_data));
+    };
+    workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>,
+                        ops::CUDNNConvFusionOpKernel<double>);
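
The workspace-limit rule in the kernel is easy to misread because it mixes two units. A standalone restatement follows, useful for sanity-checking: both the conv_workspace_size_limit gflag and the workspace_size_MB attribute are in megabytes, the default cap is 1 GiB, and the larger user value wins. The names mirror the kernel, but this is an illustrative sketch, not shared library code.

#include <algorithm>
#include <cstdint>
#include <cstdio>

static constexpr size_t kDefaultLimitBytes =
    static_cast<size_t>(1024) * 1024 * 1024;  // 1 GiB default cap

size_t WorkspaceLimitBytes(int64_t flag_limit_mb, int64_t attr_limit_mb) {
  size_t limit = kDefaultLimitBytes;
  if (flag_limit_mb > 0 || attr_limit_mb > 0) {
    // Take the larger of the two user-supplied limits, converting MB -> bytes.
    limit = static_cast<size_t>(std::max(flag_limit_mb, attr_limit_mb)) *
            1024 * 1024;
  }
  return limit;
}

int main() {
  std::printf("%zu\n", WorkspaceLimitBytes(0, 0));      // 1073741824 (default)
  std::printf("%zu\n", WorkspaceLimitBytes(256, 512));  // 536870912 (512 MB)
  return 0;
}

This limit then bounds both the algorithm search (CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT) and the final PADDLE_ENFORCE_LE check on the workspace actually requested by the chosen algorithm.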

paddle/fluid/operators/conv_op.cc

Lines changed: 2 additions & 9 deletions

@@ -225,17 +225,9 @@ The input(X) size and output(Out) size may be different.
   W_{out}= \frac{(W_{in} + 2 * paddings[1] - (dilations[1] * (W_f - 1) + 1))}{strides[1]}+ 1
 $$
 )DOC");
+  Apply();
 }
 
-class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
- protected:
-  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
-      const override {
-    return std::unordered_map<std::string, std::string>{
-        {"Input", /*->*/ "Output"}};
-  }
-};
-
 void Conv3DOpMaker::Make() {
   AddInput(
       "Input",
@@ -334,6 +326,7 @@ The input(X) size and output(Out) size may be different.
   W_{out}= \frac{(W_{in} + 2 * paddings[2] - (dilations[2] * (W_f - 1) + 1))}{ strides[2]}+ 1
 $$
 )DOC");
+  Apply();
 }
 
 void ConvOpGrad::InferShape(framework::InferShapeContext* ctx) const {

paddle/fluid/operators/conv_op.h

Lines changed: 18 additions & 2 deletions

@@ -14,6 +14,7 @@ limitations under the License. */
 
 #pragma once
 
+#include <string>
 #include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
@@ -60,12 +61,27 @@ inline bool IsExpand(const std::vector<int64_t>& filter_dim,
 // operator implementations can reuse the code.
 class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  void Make() override;
+  void Make() final;
+
+ protected:
+  virtual void Apply() {}
 };
 
 class Conv3DOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
-  void Make() override;
+  void Make() final;
+
+ protected:
+  virtual void Apply() {}
+};
+
+class ConvOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{
+        {"Input", /*->*/ "Output"}};
+  }
 };
 
 class ConvOp : public framework::OperatorWithKernel {
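
The Make()-becomes-final change above is the template method pattern that the whole commit hinges on: the base maker builds the shared conv proto, then hands control to a protected Apply() hook that subclasses override. A minimal self-contained sketch with illustrative names (not the Paddle classes) follows.

#include <iostream>
#include <string>
#include <vector>

class Conv2DMakerSketch {
 public:
  virtual ~Conv2DMakerSketch() = default;
  void Make() {                    // declared 'final' in the real header
    attrs_.push_back("strides");   // shared conv attributes built first
    attrs_.push_back("paddings");
    Apply();                       // subclass hook runs last
  }
  const std::vector<std::string>& attrs() const { return attrs_; }

 protected:
  virtual void Apply() {}          // default: add nothing

  std::vector<std::string> attrs_;
};

class FusionMakerSketch : public Conv2DMakerSketch {
 protected:
  void Apply() override { attrs_.push_back("activation"); }
};

int main() {
  FusionMakerSketch maker;
  maker.Make();
  for (const auto& a : maker.attrs()) std::cout << a << "\n";
  // prints: strides, paddings, activation
}

This is why Conv2DFusionOpMaker in conv_fusion_op.cc only has to add the "activation" attribute: everything else comes from the base Make(), and marking Make() final keeps subclasses from accidentally bypassing the shared definition.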
