Skip to content

Commit 31fd7e1

Browse files
committed
Merge pull request #998 from alalek:dnn_fix_ocl_pooling
2 parents 5d9808b + e6550fc commit 31fd7e1

File tree

2 files changed

+43
-11
lines changed

2 files changed

+43
-11
lines changed

modules/dnn/src/layers/pooling_layer.cpp

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ void PoolingLayerImpl::maxPooling(Blob &src, Blob &dst, Blob &mask)
132132

133133
bool PoolingLayerImpl::maxPooling_ocl(Blob &src, Blob &dst, Blob &mask)
134134
{
135-
return pooling_ocl("MaxPoolForward", src, dst);
135+
return pooling_ocl("MaxPoolForward", src, dst, &mask);
136136
}
137137

138138
void PoolingLayerImpl::avePooling(Blob &src, Blob &dst)
@@ -201,22 +201,36 @@ bool PoolingLayerImpl::pooling_ocl(const char *kname, const Blob &src, Blob &dst
201201
{
202202
const UMat &srcMat = src.umatRefConst();
203203
UMat &dstMat = dst.umatRef();
204-
UMat* indexesMat = mask == NULL ? NULL : &dst.umatRef();
204+
UMat *maskUMat = mask == NULL ? NULL : &mask->umatRef();
205+
CV_Assert(maskUMat == NULL || maskUMat->type() == CV_32FC1); // FIXIT CV_32SC1
206+
CV_Assert(maskUMat == NULL || maskUMat->offset == 0);
205207

206208
CV_Assert(srcMat.offset == 0 && dstMat.offset == 0);
207209

208-
ocl::Kernel ker(kname, ocl::dnn::pooling_oclsrc, String("-DT=") + ocl::typeToStr(src.type()));
210+
ocl::Kernel ker(kname, ocl::dnn::pooling_oclsrc,
211+
cv::format("-DT=%s%s", ocl::typeToStr(src.type()), maskUMat ? " -DMASK=1" : ""));
209212
if (ker.empty())
210213
return false;
211214

212215
BlobShape s = src.shape();
213216
size_t nthreads = dst.total();
214-
ker.args((int)nthreads,
217+
if (maskUMat)
218+
{
219+
ker.args((int)nthreads,
215220
ocl::KernelArg::PtrReadOnly(srcMat), s[0], s[1], s[2], s[3],
216221
out.height, out.width, kernel.height, kernel.width,
217222
stride.height, stride.width, pad.height, pad.width,
218223
ocl::KernelArg::PtrWriteOnly(dstMat),
219-
ocl::KernelArg(ocl::KernelArg::PTR_ONLY + ocl::KernelArg::WRITE_ONLY, indexesMat));
224+
ocl::KernelArg::PtrWriteOnly(*maskUMat));
225+
}
226+
else
227+
{
228+
ker.args((int)nthreads,
229+
ocl::KernelArg::PtrReadOnly(srcMat), s[0], s[1], s[2], s[3],
230+
out.height, out.width, kernel.height, kernel.width,
231+
stride.height, stride.width, pad.height, pad.width,
232+
ocl::KernelArg::PtrWriteOnly(dstMat));
233+
}
220234

221235
size_t wgSize = ocl::Device::getDefault().maxWorkGroupSize();
222236
if (!ker.run(1, &nthreads, &wgSize, true))

modules/dnn/src/opencl/pooling.cl

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,16 @@
2424
* POSSIBILITY OF SUCH DAMAGE.
2525
**************************************************************************************/
2626

27-
__kernel void MaxPoolForward(const int nthreads, __global T* bottom_data, const int num, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w, __global T* top_data, __global int* mask
28-
) {
27+
__kernel void MaxPoolForward(const int nthreads,
28+
__global T* bottom_data, const int num, const int channels, const int height, const int width,
29+
const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w,
30+
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
31+
__global T* top_data
32+
#ifdef MASK
33+
, __global float* mask
34+
#endif
35+
)
36+
{
2937
int index = get_global_id(0);
3038
int tmp = get_global_size(0);
3139
for(index; index < nthreads; index += tmp) {
@@ -51,15 +59,25 @@ __kernel void MaxPoolForward(const int nthreads, __global T* bottom_data, const
5159
}
5260
}
5361
}
62+
5463
top_data[index] = maxval;
5564

56-
if (mask) {
57-
mask[index] = maxidx;
58-
}
65+
#ifdef MASK
66+
mask[index] = maxidx;
67+
#endif
5968
}
6069
}
6170

62-
__kernel void AvePoolForward(const int nthreads, __global T* bottom_data, const int num, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_h, const int pad_w,__global T* top_data) {
71+
__kernel void AvePoolForward(const int nthreads,
72+
__global T* bottom_data, const int num, const int channels, const int height, const int width,
73+
const int pooled_height, const int pooled_width, const int kernel_h, const int kernel_w,
74+
const int stride_h, const int stride_w, const int pad_h, const int pad_w,
75+
__global T* top_data
76+
#ifdef MASK
77+
, __global float* mask // NOT USED
78+
#endif
79+
)
80+
{
6381
int index = get_global_id(0);
6482
int tmp = get_global_size(0);
6583
for(index; index < nthreads; index+=tmp) {

0 commit comments

Comments
 (0)