|
| 1 | +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +
|
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | +
|
| 7 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +
|
| 9 | +Unless required by applicable law or agreed to in writing, software |
| 10 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +See the License for the specific language governing permissions and |
| 13 | +limitations under the License. */ |
| 14 | + |
| 15 | +#include "paddle/fluid/framework/op_registry.h" |
| 16 | +#include "paddle/fluid/operators/conv_cudnn_op_cache.h" |
| 17 | +#include "paddle/fluid/platform/cudnn_helper.h" |
| 18 | + |
| 19 | +DECLARE_uint64(conv_workspace_size_limit); |
| 20 | +DECLARE_bool(cudnn_exhaustive_search); |
| 21 | + |
| 22 | +namespace paddle { |
| 23 | +namespace operators { |
| 24 | + |
| 25 | +using Tensor = framework::Tensor; |
| 26 | +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; |
| 27 | +using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; |
| 28 | +using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; |
| 29 | +using ScopedActivationDescriptor = platform::ScopedActivationDescriptor; |
| 30 | +using DataLayout = platform::DataLayout; |
| 31 | +template <typename T> |
| 32 | +using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType; |
| 33 | + |
| 34 | +template <typename T> |
| 35 | +class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { |
| 36 | + public: |
| 37 | + void Compute(const framework::ExecutionContext& ctx) const override { |
| 38 | + auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); |
| 39 | + auto* input = ctx.Input<Tensor>("Input"); |
| 40 | + auto* filter = ctx.Input<Tensor>("Filter"); |
| 41 | + auto* bias = ctx.Input<Tensor>("Bias"); |
| 42 | + PADDLE_ENFORCE(bias, "The bias should not be null."); |
| 43 | + auto* residual = ctx.Input<Tensor>("ResidualData"); |
| 44 | + auto* output = ctx.Output<Tensor>("Output"); |
| 45 | + |
| 46 | + std::vector<int> strides = ctx.Attr<std::vector<int>>("strides"); |
| 47 | + std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings"); |
| 48 | + std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations"); |
| 49 | + const std::string activation = ctx.Attr<std::string>("activation"); |
| 50 | + int groups = ctx.Attr<int>("groups"); |
| 51 | + int64_t user_workspace_size = |
| 52 | + static_cast<size_t>(ctx.Attr<int>("workspace_size_MB")); |
| 53 | + bool exhaustive_search = |
| 54 | + FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search"); |
| 55 | + |
| 56 | + const T* input_data = input->data<T>(); |
| 57 | + const T* filter_data = filter->data<T>(); |
| 58 | + const T* bias_data = bias->data<T>(); |
| 59 | + T* output_data = output->mutable_data<T>(ctx.GetPlace()); |
| 60 | + const T* residual_data = residual ? residual->data<T>() : output_data; |
| 61 | + |
| 62 | + // ------------------- cudnn descriptors --------------------- |
| 63 | + ScopedTensorDescriptor input_desc; |
| 64 | + ScopedTensorDescriptor output_desc; |
| 65 | + ScopedFilterDescriptor filter_desc; |
| 66 | + ScopedTensorDescriptor bias_desc; |
| 67 | + ScopedConvolutionDescriptor conv_desc; |
| 68 | + ScopedActivationDescriptor act_desc; |
| 69 | + DataLayout layout = DataLayout::kNCHW; |
| 70 | + if (input->dims().size() == 5) { |
| 71 | + layout = DataLayout::kNCDHW; |
| 72 | + } |
| 73 | + |
| 74 | + cudnnConvolutionDescriptor_t cudnn_conv_desc = |
| 75 | + conv_desc.descriptor<T>(paddings, strides, dilations); |
| 76 | + CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount( |
| 77 | + cudnn_conv_desc, groups)); |
| 78 | + |
| 79 | + cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>( |
| 80 | + layout, framework::vectorize2int(input->dims())); |
| 81 | + cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>( |
| 82 | + layout, framework::vectorize2int(output->dims())); |
| 83 | + cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>( |
| 84 | + layout, framework::vectorize2int(filter->dims())); |
| 85 | + // Now only support NCHW |
| 86 | + std::vector<int> bias_dim = {1, static_cast<int>(output->dims()[1]), 1, 1}; |
| 87 | + cudnnTensorDescriptor_t cudnn_bias_desc = |
| 88 | + bias_desc.descriptor<T>(layout, bias_dim); |
| 89 | + cudnnActivationDescriptor_t cudnn_act_desc = |
| 90 | + act_desc.descriptor<T>(activation); |
| 91 | + |
| 92 | + // ------------------- cudnn conv workspace --------------------- |
| 93 | + size_t workspace_size_in_bytes; // final workspace to allocate. |
| 94 | + size_t workspace_size_limit = kCONV_CUDNN_WORKSPACE_LIMIT_BYTES; |
| 95 | + if (FLAGS_conv_workspace_size_limit > 0 || user_workspace_size > 0) { |
| 96 | + int64_t max_user_size = |
| 97 | + std::max(static_cast<int64_t>(FLAGS_conv_workspace_size_limit), |
| 98 | + user_workspace_size); |
| 99 | + workspace_size_limit = max_user_size * 1024 * 1024; |
| 100 | + } |
| 101 | + |
| 102 | + // ------------------- cudnn conv algorithm --------------------- |
| 103 | + cudnnConvolutionFwdAlgo_t algo; |
| 104 | + auto handle = dev_ctx.cudnn_handle(); |
| 105 | + auto workspace_handle = dev_ctx.cudnn_workspace_handle(); |
| 106 | + |
| 107 | + CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( |
| 108 | + cudnn_conv_desc, CUDNN_DEFAULT_MATH)); |
| 109 | + |
| 110 | + auto x_dims = framework::vectorize(input->dims()); |
| 111 | + auto f_dims = framework::vectorize(filter->dims()); |
| 112 | + if (activation == "identity") { |
| 113 | + // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is |
| 114 | + // enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib. |
| 115 | + algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; |
| 116 | + } else if (!exhaustive_search) { |
| 117 | + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm( |
| 118 | + handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, |
| 119 | + cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, |
| 120 | + workspace_size_limit, &algo)); |
| 121 | + VLOG(3) << "cuDNN forward algo " << algo; |
| 122 | + } else { |
| 123 | + AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* algo_cache = nullptr; |
| 124 | + if (ctx.scope().FindVar(kCUDNNFwdAlgoCache)) { |
| 125 | + algo_cache = |
| 126 | + ctx.scope() |
| 127 | + .FindVar(kCUDNNFwdAlgoCache) |
| 128 | + ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(); |
| 129 | + } else { |
| 130 | + algo_cache = |
| 131 | + const_cast<framework::Scope&>(ctx.scope()) |
| 132 | + .Var(kCUDNNFwdAlgoCache) |
| 133 | + ->GetMutable<AlgorithmsCache<cudnnConvolutionFwdAlgo_t>>(); |
| 134 | + } |
| 135 | + algo = algo_cache->GetAlgorithm( |
| 136 | + x_dims, f_dims, strides, paddings, dilations, 0, [&]() { |
| 137 | + int returned_algo_count; |
| 138 | + std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS> |
| 139 | + fwd_perf_stat; |
| 140 | + auto cudnn_find_func = [&](void* cudnn_workspace) { |
| 141 | + CUDNN_ENFORCE( |
| 142 | + platform::dynload::cudnnFindConvolutionForwardAlgorithmEx( |
| 143 | + handle, cudnn_input_desc, input_data, cudnn_filter_desc, |
| 144 | + filter_data, cudnn_conv_desc, cudnn_output_desc, |
| 145 | + output_data, kNUM_CUDNN_FWD_ALGS, &returned_algo_count, |
| 146 | + fwd_perf_stat.data(), cudnn_workspace, |
| 147 | + workspace_size_limit)); |
| 148 | + }; |
| 149 | + workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit); |
| 150 | + VLOG(3) << "Perf result: (algo: stat, time, memory)"; |
| 151 | + for (int i = 0; i < returned_algo_count; ++i) { |
| 152 | + const auto& stat = fwd_perf_stat[i]; |
| 153 | + VLOG(3) << stat.algo << ": " << stat.status << " " << stat.time |
| 154 | + << " " << stat.memory; |
| 155 | + } |
| 156 | + return fwd_perf_stat[0].algo; |
| 157 | + }); |
| 158 | + VLOG(3) << "choose algo " << algo; |
| 159 | + } |
| 160 | + |
| 161 | + CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( |
| 162 | + handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, |
| 163 | + cudnn_output_desc, algo, &workspace_size_in_bytes)); |
| 164 | + PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit, |
| 165 | + "workspace_size to be allocated exceeds the limit"); |
| 166 | + |
| 167 | + // ------------------- cudnn conv+bias+act forward -------------------- |
| 168 | + ScalingParamType<T> alpha1 = 1.0f; |
| 169 | + ScalingParamType<T> alpha2 = residual ? 1.0f : 0.0f; |
| 170 | + auto cudnn_func = [&](void* cudnn_workspace) { |
| 171 | + CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward( |
| 172 | + handle, &alpha1, cudnn_input_desc, input_data, cudnn_filter_desc, |
| 173 | + filter_data, cudnn_conv_desc, algo, cudnn_workspace, |
| 174 | + workspace_size_in_bytes, &alpha2, cudnn_output_desc, residual_data, |
| 175 | + cudnn_bias_desc, bias_data, cudnn_act_desc, cudnn_output_desc, |
| 176 | + output_data)); |
| 177 | + }; |
| 178 | + workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); |
| 179 | + } |
| 180 | +}; |
| 181 | + |
| 182 | +} // namespace operators |
| 183 | +} // namespace paddle |
| 184 | + |
| 185 | +namespace ops = paddle::operators; |
| 186 | +REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>, |
| 187 | + ops::CUDNNConvFusionOpKernel<double>); |
0 commit comments