@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/conv_cudnn_helper.h"
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
+#include "paddle/fluid/operators/conv_op.h"
+#include "paddle/fluid/operators/math/padding.h"
 #include "paddle/fluid/platform/cudnn_helper.h"
 
 DECLARE_int64(cudnn_exhaustive_search_times);
@@ -41,9 +44,10 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
     auto* input = ctx.Input<Tensor>("Input");
     auto* filter = ctx.Input<Tensor>("Filter");
     auto* bias = ctx.Input<Tensor>("Bias");
-    PADDLE_ENFORCE(bias, "The bias should not be null.");
+    PADDLE_ENFORCE_NOT_NULL(bias, "The bias should not be null.");
     auto* residual = ctx.Input<Tensor>("ResidualData");
     auto* output = ctx.Output<Tensor>("Output");
+    output->mutable_data<T>(ctx.GetPlace());
 
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
@@ -55,11 +59,96 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
     bool exhaustive_search =
         FLAGS_cudnn_exhaustive_search || ctx.Attr<bool>("exhaustive_search");
 
-    const T* input_data = input->data<T>();
+    // const T* input_data = input->data<T>();
     const T* filter_data = filter->data<T>();
     const T* bias_data = bias->data<T>();
-    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+    // T* output_data = output->mutable_data<T>(ctx.GetPlace());
+
+    const std::string padding_algorithm =
+        ctx.Attr<std::string>("padding_algorithm");
+    const std::string data_format = ctx.Attr<std::string>("data_format");
+
+    Tensor transformed_input_channel(input->type());
+    Tensor transformed_output(output->type());
+    T* output_data = nullptr;
+
+    transformed_input_channel = *input;
+    transformed_output = *output;
+    output_data = transformed_output.data<T>();
     const T* residual_data = residual ? residual->data<T>() : output_data;
+    // update padding and dilation
+    auto in_dims = transformed_input_channel.dims();
+    auto filter_dims = filter->dims();
+    framework::DDim in_data_dims;
+    in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
+
+    framework::DDim filter_data_dims =
+        framework::slice_ddim(filter_dims, 2, filter_dims.size());
+    std::vector<int> ksize = framework::vectorize<int>(filter_data_dims);
+    UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
+                             in_data_dims, strides, ksize);
+
+    int data_dim = strides.size();  // 2d or 3d
+    bool is_sys_pad = math::IsSymmetricPadding(paddings, data_dim);
+
+    Tensor transformed_input;
+    std::vector<int> padding_common(data_dim, 0);
+    if (!is_sys_pad) {
+      std::vector<int> padding_diff(data_dim);
+      std::vector<int> new_input_shape_vec(data_dim + 2);
+      new_input_shape_vec[0] = transformed_input_channel.dims()[0];
+      new_input_shape_vec[1] = transformed_input_channel.dims()[1];
+
+      std::vector<int> input_pad(transformed_input_channel.dims().size() * 2,
+                                 0);
+      for (size_t i = 0; i < data_dim; ++i) {
+        padding_diff[i] = std::abs(paddings[2 * i] - paddings[2 * i + 1]);
+        padding_common[i] = std::min(paddings[2 * i], paddings[2 * i + 1]);
+        new_input_shape_vec[i + 2] =
+            transformed_input_channel.dims()[i + 2] + padding_diff[i];
+        input_pad[2 * i + 4] = paddings[2 * i] - padding_common[i];
+        input_pad[2 * i + 4 + 1] = paddings[2 * i + 1] - padding_common[i];
+      }
+      framework::DDim new_input_shape(
+          framework::make_ddim(new_input_shape_vec));
+      transformed_input.Resize(new_input_shape);
+      auto& dev_ctx =
+          ctx.template device_context<paddle::platform::CUDADeviceContext>();
+
+      transformed_input =
+          ctx.AllocateTmpTensor<T, paddle::platform::CUDADeviceContext>(
+              new_input_shape, dev_ctx);
+      const int rank = transformed_input_channel.dims().size();
+      T pad_value(0.0);
+      switch (rank) {
+        case 4: {
+          math::PadFunction<paddle::platform::CUDADeviceContext, T, 4>(
+              ctx, input_pad, transformed_input_channel, pad_value,
+              &transformed_input);
+        } break;
+        case 5: {
+          math::PadFunction<paddle::platform::CUDADeviceContext, T, 5>(
+              ctx, input_pad, transformed_input_channel, pad_value,
+              &transformed_input);
+        } break;
+        default:
+          PADDLE_THROW("ConvOp only support tensors with 4 or 5 dimensions.");
+      }
+
+    } else {
+      transformed_input = transformed_input_channel;
+      if (paddings.size() == data_dim) {
+        for (size_t i = 0; i < data_dim; ++i) {
+          padding_common[i] = paddings[i];
+        }
+      } else {
+        for (size_t i = 0; i < data_dim; ++i) {
+          padding_common[i] = paddings[2 * i];
+        }
+      }
+    }
+
+    const T* input_data = transformed_input.data<T>();
 
     // ------------------- cudnn descriptors ---------------------
     ScopedTensorDescriptor input_desc;
@@ -74,18 +163,19 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
     }
 
     cudnnConvolutionDescriptor_t cudnn_conv_desc =
-        conv_desc.descriptor<T>(paddings, strides, dilations);
+        conv_desc.descriptor<T>(padding_common, strides, dilations);
     CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionGroupCount(
         cudnn_conv_desc, groups));
 
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
-        layout, framework::vectorize<int>(input->dims()));
+        layout, framework::vectorize<int>(transformed_input.dims()));
     cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
-        layout, framework::vectorize<int>(output->dims()));
+        layout, framework::vectorize<int>(transformed_output.dims()));
     cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
         layout, framework::vectorize<int>(filter->dims()));
     // Now only support NCHW
-    std::vector<int> bias_dim = {1, static_cast<int>(output->dims()[1]), 1, 1};
+    std::vector<int> bias_dim = {
+        1, static_cast<int>(transformed_output.dims()[1]), 1, 1};
     cudnnTensorDescriptor_t cudnn_bias_desc =
         bias_desc.descriptor<T>(layout, bias_dim);
     cudnnActivationDescriptor_t cudnn_act_desc =
@@ -109,7 +199,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
     CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
         cudnn_conv_desc, CUDNN_DEFAULT_MATH));
 
-    auto x_dims = framework::vectorize(input->dims());
+    auto x_dims = framework::vectorize(transformed_input.dims());
     auto f_dims = framework::vectorize(filter->dims());
     if (!exhaustive_search) {
       CUDNN_ENFORCE(platform::dynload::cudnnGetConvolutionForwardAlgorithm(
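
Note on the asymmetric-padding handling this patch adds: the cuDNN convolution descriptor only takes one padding value per spatial dimension, so the new code splits each (before, after) pair into a symmetric part passed to conv_desc.descriptor (padding_common) and a leftover that is zero-padded onto the input explicitly via math::PadFunction (input_pad). The following is a minimal standalone sketch of that decomposition for illustration only; it is not part of the patch, and the helper name SplitAsymmetricPadding is made up here.

#include <algorithm>
#include <vector>

// paddings is laid out as {before_0, after_0, before_1, after_1, ...}.
// For each spatial dim, the shared part min(before, after) can be handled by
// cuDNN directly; the remainder must be applied as an explicit zero-pad of
// the input, mirroring the padding_common / input_pad computation above.
struct SplitPadding {
  std::vector<int> common;   // symmetric padding for the conv descriptor
  std::vector<int> pre_pad;  // extra {before, after} pad per spatial dim
};

inline SplitPadding SplitAsymmetricPadding(const std::vector<int>& paddings,
                                           int data_dim) {
  SplitPadding out;
  out.common.resize(data_dim);
  out.pre_pad.resize(data_dim * 2);
  for (int i = 0; i < data_dim; ++i) {
    const int before = paddings[2 * i];
    const int after = paddings[2 * i + 1];
    out.common[i] = std::min(before, after);          // goes to cuDNN
    out.pre_pad[2 * i] = before - out.common[i];      // explicit front pad
    out.pre_pad[2 * i + 1] = after - out.common[i];   // explicit back pad
  }
  return out;
}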
0 commit comments