Skip to content

Commit 5f2e837

Browse files
cjldqingqing01
authored and committed
optimize depthwise conv by register memory (#13778)
* optimize depthwise conv by register memory * test=develop
1 parent 5428cb9 commit 5f2e837

File tree

1 file changed

+210
-65
lines changed

1 file changed

+210
-65
lines changed

paddle/fluid/operators/math/depthwise_conv.cu

Lines changed: 210 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,20 @@ __forceinline__ __device__ unsigned warp_id() {
4646
return ret;
4747
}
4848

49+
// Shared parameter list for the forward depthwise-conv kernels, kept in a
// macro so the generic kernel, the register-cached specialization, and the
// dispatching __global__ wrapper all declare identical signatures.
#define ARG_DEFINE_KernelDepthwiseConv \
50+
const T *const input_data, const T *const filter_data, const int batch_size, \
51+
const int output_channels, const int output_height, \
52+
const int output_width, const int input_channels, \
53+
const int input_height, const int input_width, \
54+
const int filter_multiplier, const int filter_height, \
55+
const int filter_width, const int stride_height, const int stride_width, \
56+
const int padding_height, const int padding_width, \
57+
const int dilate_height, const int dilate_width, T *const output_data
58+
4959
// A Cuda kernel to compute the depthwise convolution forward pass
5060
// in NCHW format.
5161
template <typename T>
52-
__device__ __inline__ void KernelDepthwiseConv(
53-
const T* const input_data, const T* const filter_data, const int batch_size,
54-
const int output_channels, const int output_height, const int output_width,
55-
const int input_channels, const int input_height, const int input_width,
56-
const int filter_multiplier, const int filter_height,
57-
const int filter_width, const int stride_height, const int stride_width,
58-
const int padding_height, const int padding_width, const int dilate_height,
59-
const int dilate_width, T* const output_data) {
62+
__device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
6063
for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
6164
for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
6265
const int batch = blockIdx.y;
@@ -97,42 +100,105 @@ __device__ __inline__ void KernelDepthwiseConv(
97100
}
98101
}
99102

100-
template <typename T, int c_filter_multiplier, int c_stride>
101-
__global__ void KernelDepthwiseConvSp(
102-
const T* const input_data, const T* const filter_data, const int batch_size,
103-
const int output_channels, const int output_height, const int output_width,
104-
const int input_channels, const int input_height, const int input_width,
105-
const int filter_multiplier, const int filter_height,
106-
const int filter_width, const int stride_height, const int stride_width,
107-
const int padding_height, const int padding_width, const int dilate_height,
108-
const int dilate_width, T* const output_data) {
109-
if (c_filter_multiplier == 0)
110-
KernelDepthwiseConv<T>(input_data, filter_data, batch_size, output_channels,
111-
output_height, output_width, input_channels,
112-
input_height, input_width, filter_multiplier,
113-
filter_height, filter_width, stride_height,
114-
stride_width, padding_height, padding_width,
115-
dilate_height, dilate_width, output_data);
103+
// Forward depthwise convolution specialized for a compile-time square
// filter (c_filter x c_filter). The output channel's weights are preloaded
// into a per-thread register array (r_weight) so the inner loops avoid
// repeated global-memory reads of filter_data.
//
// Launch layout (from the visible dispatch code): blockIdx.x = output
// channel, blockIdx.y = batch index; threadIdx.x / threadIdx.y tile the
// output width / height.
template <typename T, int c_filter>
__device__ __inline__ void KernelDepthwiseConvCFilter(
    ARG_DEFINE_KernelDepthwiseConv) {
  // Register cache for this output channel's c_filter x c_filter weights.
  const int kWeightSize = c_filter * c_filter;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_out = blockIdx.x;
  const T* weight = filter_data + c_out * c_filter * c_filter;
  for (int i = 0; i < c_filter * c_filter; i++) r_weight[i] = weight[i];

  for (int w_out = threadIdx.x; w_out < output_width; w_out += blockDim.x) {
    for (int h_out = threadIdx.y; h_out < output_height; h_out += blockDim.y) {
      const int c_in = c_out / filter_multiplier;
      T value = 0;
      // Top-left corner of the receptive field in input coordinates
      // (may be negative because of padding; bounds-checked per tap below).
      const int h_in_start = -padding_height + h_out * stride_height;
      const int w_in_start = -padding_width + w_out * stride_width;

      const int in_offset =
          ((batch * input_channels + c_in) * input_height) * input_width;

      // Compile-time trip count (c_filter) lets the compiler fully unroll
      // the tap loops and keep r_weight in registers.
      for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
           h_in += dilate_height, h_f++) {
        for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
             w_in += dilate_width, w_f++) {
          if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
              w_in < input_width) {
            const int offset = in_offset + h_in * input_width + w_in;
            value += r_weight[h_f * c_filter + w_f] * input_data[offset];
          }
        }
      }
      // gridDim.x == output_channels per the launch configuration.
      const int index =
          ((batch * gridDim.x + c_out) * output_height + h_out) * output_width +
          w_out;
      output_data[index] = value;
    }
  }
}
151+
152+
// Dispatch wrapper for the forward pass: substitutes compile-time template
// constants for their runtime counterparts where available, and routes to
// either the generic kernel or the register-cached specialization.
//
// Conventions: c_filter_multiplier == 0 means "fully generic" (runtime
// values used); c_filter == -1 means the filter size is not a compile-time
// constant.
//
// Fix: the specialized branches previously passed filter_height for the
// filter_width argument; for c_filter == -1 the host-side check_case does
// not require a square kernel, so non-square filters used the wrong width.
template <typename T, int c_filter_multiplier, int c_stride, int c_filter>
__global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
  if (c_filter_multiplier == 0) {
    if (c_filter == -1)
      KernelDepthwiseConv<T>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          filter_multiplier, filter_height, filter_width, stride_height,
          stride_width, padding_height, padding_width, dilate_height,
          dilate_width, output_data);
    else
      KernelDepthwiseConvCFilter<T, c_filter>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          filter_multiplier, filter_height, filter_width, stride_height,
          stride_width, padding_height, padding_width, dilate_height,
          dilate_width, output_data);
  } else {
    if (c_filter == -1)
      KernelDepthwiseConv<T>(input_data, filter_data, batch_size,
                             output_channels, output_height, output_width,
                             input_channels, input_height, input_width,
                             c_filter_multiplier, filter_height, filter_width,
                             c_stride, c_stride, padding_height, padding_width,
                             dilate_height, dilate_width, output_data);
    else
      KernelDepthwiseConvCFilter<T, c_filter>(
          input_data, filter_data, batch_size, output_channels, output_height,
          output_width, input_channels, input_height, input_width,
          c_filter_multiplier, filter_height, filter_width, c_stride, c_stride,
          padding_height, padding_width, dilate_height, dilate_width,
          output_data);
  }
}
125186

126187
// CUDA kernel to compute the depthwise convolution backprop w.r.t input.
188+
// Shared parameter list for the backprop-to-input depthwise-conv kernels;
// mirrors ARG_DEFINE_KernelDepthwiseConv but takes output gradients in and
// writes input gradients out.
#define ARG_DEFINE_KernelDepthwiseConvInputGrad \
189+
const T *const output_grad_data, const T *const filter_data, \
190+
const int batch_size, const int output_channels, \
191+
const int output_height, const int output_width, \
192+
const int input_channels, const int input_height, const int input_width, \
193+
const int filter_multiplier, const int filter_height, \
194+
const int filter_width, const int stride_height, const int stride_width, \
195+
const int padding_height, const int padding_width, \
196+
const int dilate_height, const int dilate_width, \
197+
T *const input_grad_data
198+
127199
template <typename T>
128200
__device__ __inline__ void KernelDepthwiseConvInputGrad(
129-
const T* const output_grad_data, const T* const filter_data,
130-
const int batch_size, const int output_channels, const int output_height,
131-
const int output_width, const int input_channels, const int input_height,
132-
const int input_width, const int filter_multiplier, const int filter_height,
133-
const int filter_width, const int stride_height, const int stride_width,
134-
const int padding_height, const int padding_width, const int dilate_height,
135-
const int dilate_width, T* const input_grad_data) {
201+
ARG_DEFINE_KernelDepthwiseConvInputGrad) {
136202
for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
137203
for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
138204
const int batch = blockIdx.y;
@@ -184,29 +250,88 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
184250
}
185251
}
186252

187-
template <typename T, int c_filter_multiplier, int c_stride>
253+
// Backprop-to-input depthwise convolution specialized for a compile-time
// filter size (c_filter) and channel multiplier (c_filter_multiplier).
// The weights of every output channel fed by this input channel are
// preloaded, in reversed tap order, into a per-thread register array.
//
// Launch layout (from the visible dispatch code): blockIdx.x = input
// channel, blockIdx.y = batch index; threadIdx.x / threadIdx.y tile the
// input width / height.
template <typename T, int c_filter, int c_filter_multiplier>
__device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  // "+ 1" keeps the array size positive for the dummy instantiation with
  // c_filter_multiplier == 0, which is never executed but must compile.
  const int kWeightSize = c_filter * c_filter * c_filter_multiplier + 1;
  T r_weight[kWeightSize];
  const int batch = blockIdx.y;
  const int c_in = blockIdx.x;

  // Cache reversed weights. Loop bounds use the compile-time multiplier so
  // indexing can never exceed kWeightSize and the loops fully unroll; the
  // dispatcher only calls this path when
  // filter_multiplier == c_filter_multiplier.
  for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
    int c_out = c_in * c_filter_multiplier + c_i;
    const T* weight = filter_data + c_out * c_filter * c_filter;
    for (int i = 0; i < c_filter * c_filter; i++)
      r_weight[i + c_i * c_filter * c_filter] =
          weight[c_filter * c_filter - i - 1];
  }

  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
      // First (possibly out-of-range) output position whose receptive field
      // covers this input element; bounds and stride alignment are checked
      // per tap below.
      int h_out_start = h_in - (c_filter - 1) * dilate_height + padding_height;
      int w_out_start = w_in - (c_filter - 1) * dilate_width + padding_width;

      T value = 0;

      for (int c_i = 0; c_i < c_filter_multiplier; c_i++) {
        int c_out = c_in * c_filter_multiplier + c_i;
        for (int h_out = h_out_start, h_f = 0; h_f < c_filter;
             h_out += dilate_height, h_f++) {
          for (int w_out = w_out_start, w_f = 0; w_f < c_filter;
               w_out += dilate_width, w_f++) {
            int s_h_out = h_out / stride_height;
            int s_w_out = w_out / stride_width;
            // Only stride-aligned, in-bounds output positions contribute.
            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
                s_w_out < output_width) {
              const int output_grad_offset =
                  ((batch * output_channels + c_out) * output_height +
                   s_h_out) *
                      output_width +
                  s_w_out;
              value +=
                  output_grad_data[output_grad_offset] *
                  r_weight[h_f * c_filter + w_f + c_i * c_filter * c_filter];
            }
          }
        }
      }
      // gridDim.x == input_channels per the launch configuration.
      const int index =
          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
          w_in;
      input_grad_data[index] = value;
    }
  }
}
310+
311+
// Dispatch wrapper for the backprop-to-input pass. Template parameters act
// as compile-time hints: c_filter_multiplier == 0 selects the fully generic
// path, c_filter == -1 means the filter size is only known at runtime.
template <typename T, int c_filter_multiplier, int c_stride, int c_filter>
__global__ void KernelDepthwiseConvInputGradSp(
    ARG_DEFINE_KernelDepthwiseConvInputGrad) {
  if (c_filter_multiplier == 0) {
    // No specialization available: forward every runtime argument as-is.
    KernelDepthwiseConvInputGrad<T>(
        output_grad_data, filter_data, batch_size, output_channels,
        output_height, output_width, input_channels, input_height, input_width,
        filter_multiplier, filter_height, filter_width, stride_height,
        stride_width, padding_height, padding_width, dilate_height,
        dilate_width, input_grad_data);
  } else {
    if (c_filter == -1) {
      // Compile-time multiplier and stride, runtime filter size.
      KernelDepthwiseConvInputGrad<T>(
          output_grad_data, filter_data, batch_size, output_channels,
          output_height, output_width, input_channels, input_height,
          input_width, c_filter_multiplier, filter_height, filter_width,
          c_stride, c_stride, padding_height, padding_width, dilate_height,
          dilate_width, input_grad_data);
    } else {
      // Everything is a compile-time constant: take the register-cached path.
      KernelDepthwiseConvInputGradCFilter<T, c_filter, c_filter_multiplier>(
          output_grad_data, filter_data, batch_size, output_channels,
          output_height, output_width, input_channels, input_height,
          input_width, c_filter_multiplier, filter_height, filter_width,
          c_stride, c_stride, padding_height, padding_width, dilate_height,
          dilate_width, input_grad_data);
    }
  }
}
211336

212337
// Cuda kernel to compute the depthwise convolution backprop w.r.t. filter.
@@ -325,24 +450,32 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
325450
dim3 threads(std::min(output_width, thread), blocks, 1);
326451
dim3 grid(output_channels, batch_size, 1);
327452
int filter_multiplier = output_channels / input_channels;
328-
#define check_case(c_filter_multiplier, c_stride) \
453+
#define check_case(c_filter_multiplier, c_stride, c_filter) \
329454
if (c_filter_multiplier == 0 || \
330455
filter_multiplier == c_filter_multiplier && \
331-
stride_height == stride_width && stride_height == c_stride) { \
332-
KernelDepthwiseConvSp<T, c_filter_multiplier, \
333-
c_stride><<<grid, threads, 0, context.stream()>>>( \
456+
stride_height == stride_width && stride_height == c_stride && \
457+
(ksize_height == ksize_width && ksize_height == c_filter || \
458+
c_filter == -1)) { \
459+
KernelDepthwiseConvSp<T, c_filter_multiplier, c_stride, \
460+
c_filter><<<grid, threads, 0, context.stream()>>>( \
334461
input_data, filter_data, batch_size, output_channels, output_height, \
335462
output_width, input_channels, input_height, input_width, \
336463
filter_multiplier, ksize_height, ksize_width, stride_height, \
337464
stride_width, padding_height, padding_width, dilate_height, \
338465
dilate_width, output_data); \
339466
return; \
340467
}
341-
check_case(1, 1);
342-
check_case(1, 2);
343-
// NOTE(liangdun): 0,0 for other case
344-
// add other case if needed, e.g. check_case(2^n,1)
345-
check_case(0, 0);
468+
check_case(1, 1, 3);
469+
check_case(1, 1, 5);
470+
check_case(1, 1, -1);
471+
check_case(1, 2, 3);
472+
check_case(1, 2, 5);
473+
check_case(1, 2, -1);
474+
check_case(0, 0, 3);
475+
check_case(0, 0, 5);
476+
check_case(0, 0, -1);
477+
// NOTE(liangdun): 0,0 for other case
478+
// add other case if needed, e.g. check_case(2^n,1)
346479
#undef check_case
347480
}
348481
};
@@ -384,25 +517,37 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T> {
384517
dim3 grid(input_channels, batch_size, 1);
385518
int filter_multiplier = output_channels / input_channels;
386519

387-
#define check_case(c_filter_multiplier, c_stride) \
520+
#define check_case(c_filter_multiplier, c_stride, c_filter) \
388521
if (c_filter_multiplier == 0 || \
389522
filter_multiplier == c_filter_multiplier && \
390-
stride_height == stride_width && stride_height == c_stride) { \
523+
stride_height == stride_width && stride_height == c_stride && \
524+
(ksize_height == ksize_width && ksize_height == c_filter || \
525+
c_filter == -1)) { \
391526
KernelDepthwiseConvInputGradSp< \
392-
T, c_filter_multiplier, \
393-
c_stride><<<grid, threads, 0, context.stream()>>>( \
527+
T, c_filter_multiplier, c_stride, \
528+
c_filter><<<grid, threads, 0, context.stream()>>>( \
394529
output_grad_data, filter_data, batch_size, output_channels, \
395530
output_height, output_width, input_channels, input_height, \
396531
input_width, filter_multiplier, ksize_height, ksize_width, \
397532
stride_height, stride_width, padding_height, padding_width, \
398533
dilate_height, dilate_width, input_grad_data); \
399534
return; \
400535
}
401-
check_case(1, 1);
402-
check_case(1, 2);
403-
// NOTE(liangdun): 0,0 for other case
404-
// add other case if needed, e.g. check_case(2^n,1)
405-
check_case(0, 0);
536+
check_case(1, 1, 3);
537+
check_case(1, 1, 5);
538+
check_case(1, 1, -1);
539+
check_case(1, 2, 3);
540+
check_case(1, 2, 5);
541+
check_case(1, 2, -1);
542+
check_case(2, 1, 3);
543+
check_case(2, 1, 5);
544+
check_case(2, 1, -1);
545+
check_case(2, 2, 3);
546+
check_case(2, 2, 5);
547+
check_case(2, 2, -1);
548+
check_case(0, 0, -1);
549+
// NOTE(liangdun): 0,0 for other case
550+
// add other case if needed, e.g. check_case(2^n,1)
406551
#undef check_case
407552
}
408553
};

0 commit comments

Comments
 (0)