@@ -132,7 +132,7 @@ void PoolingLayerImpl::maxPooling(Blob &src, Blob &dst, Blob &mask)
132
132
133
133
// OpenCL path for max pooling: forwards to pooling_ocl() with the
// "MaxPoolForward" kernel name and returns that call's bool result unchanged.
bool PoolingLayerImpl::maxPooling_ocl (Blob &src, Blob &dst, Blob &mask)
134
134
{
135
- return pooling_ocl (" MaxPoolForward" , src, dst);
135
// The added line passes the address of `mask`, so pooling_ocl() receives a
// non-null mask pointer; the removed line above left `mask` unused.
+ return pooling_ocl (" MaxPoolForward" , src, dst, &mask );
136
136
}
137
137
138
138
void PoolingLayerImpl::avePooling (Blob &src, Blob &dst)
@@ -201,22 +201,36 @@ bool PoolingLayerImpl::pooling_ocl(const char *kname, const Blob &src, Blob &dst
201
201
{
202
202
const UMat &srcMat = src.umatRefConst ();
203
203
UMat &dstMat = dst.umatRef ();
204
- UMat* indexesMat = mask == NULL ? NULL : &dst.umatRef ();
204
+ UMat *maskUMat = mask == NULL ? NULL : &mask->umatRef ();
205
+ CV_Assert (maskUMat == NULL || maskUMat->type () == CV_32FC1); // FIXIT CV_32SC1
206
+ CV_Assert (maskUMat == NULL || maskUMat->offset == 0 );
205
207
206
208
CV_Assert (srcMat.offset == 0 && dstMat.offset == 0 );
207
209
208
- ocl::Kernel ker (kname, ocl::dnn::pooling_oclsrc, String (" -DT=" ) + ocl::typeToStr (src.type ()));
210
+ ocl::Kernel ker (kname, ocl::dnn::pooling_oclsrc,
211
+ cv::format (" -DT=%s%s" , ocl::typeToStr (src.type ()), maskUMat ? " -DMASK=1" : " " ));
209
212
if (ker.empty ())
210
213
return false ;
211
214
212
215
BlobShape s = src.shape ();
213
216
size_t nthreads = dst.total ();
214
- ker.args ((int )nthreads,
217
+ if (maskUMat)
218
+ {
219
+ ker.args ((int )nthreads,
215
220
ocl::KernelArg::PtrReadOnly (srcMat), s[0 ], s[1 ], s[2 ], s[3 ],
216
221
out.height , out.width , kernel.height , kernel.width ,
217
222
stride.height , stride.width , pad.height , pad.width ,
218
223
ocl::KernelArg::PtrWriteOnly (dstMat),
219
- ocl::KernelArg (ocl::KernelArg::PTR_ONLY + ocl::KernelArg::WRITE_ONLY, indexesMat));
224
+ ocl::KernelArg::PtrWriteOnly (*maskUMat));
225
+ }
226
+ else
227
+ {
228
+ ker.args ((int )nthreads,
229
+ ocl::KernelArg::PtrReadOnly (srcMat), s[0 ], s[1 ], s[2 ], s[3 ],
230
+ out.height , out.width , kernel.height , kernel.width ,
231
+ stride.height , stride.width , pad.height , pad.width ,
232
+ ocl::KernelArg::PtrWriteOnly (dstMat));
233
+ }
220
234
221
235
size_t wgSize = ocl::Device::getDefault ().maxWorkGroupSize ();
222
236
if (!ker.run (1 , &nthreads, &wgSize, true ))
0 commit comments