Skip to content

Commit 398ebba

Browse files
committed
Merge pull request opencv#10795 from pengli:dnn
2 parents 2a1f46c + c43498c commit 398ebba

File tree

3 files changed

+36
-41
lines changed

3 files changed

+36
-41
lines changed

modules/dnn/src/dnn.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,8 +1386,11 @@ struct Net::Impl
13861386

13871387
if ( preferableTarget == DNN_TARGET_OPENCL )
13881388
{
1389-
nextData = &layers[activData->consumers[0].lid];
1390-
lpNext = LayerPin(activData->consumers[0].lid, 0);
1389+
if ( !activData->consumers.empty() )
1390+
{
1391+
nextData = &layers[activData->consumers[0].lid];
1392+
lpNext = LayerPin(activData->consumers[0].lid, 0);
1393+
}
13911394
}
13921395
}
13931396
}

modules/dnn/src/layers/slice_layer.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ class SliceLayerImpl : public SliceLayer
181181
inputs_.getUMatVector(inputs);
182182
outputs_.getUMatVector(outputs);
183183

184-
if (inputs[0].dims < 4)
184+
if (inputs[0].dims < 4 || (total(shape(outputs[0]), 0, 2) % 4 != 0) ||
185+
(total(shape(outputs[0]), 2) % 4 != 0))
185186
return false;
186187

187188
const UMat& inpMat = inputs[0];
@@ -192,22 +193,19 @@ class SliceLayerImpl : public SliceLayer
192193
int rows = outputs[i].size[2];
193194
int cols = outputs[i].size[3];
194195

195-
int number = (cols % 8 == 0) ? 8 : ((cols % 4 == 0) ? 4 : 1);
196-
String buildopt = format("-DNUM=%d ", number);
197-
String kname = format("slice%d", number);
198-
ocl::Kernel kernel(kname.c_str(), ocl::dnn::slice_oclsrc, buildopt);
199-
size_t global[] = { (size_t)groups * channels, (size_t)rows * cols / number };
196+
ocl::Kernel kernel("slice", ocl::dnn::slice_oclsrc);
197+
size_t local[] = { 128 };
198+
size_t global[] = { (size_t)groups * channels / 4 * local[0] };
200199
int idx = 0;
201200
kernel.set(idx++, ocl::KernelArg::PtrReadOnly(inpMat));
202201
kernel.set(idx++, (int)(inpMat.size[2] * inpMat.size[3]));
203-
kernel.set(idx++, (int)inpMat.size[3]);
204-
kernel.set(idx++, (int)global[0]);
205202
kernel.set(idx++, (int)(rows * cols));
203+
kernel.set(idx++, (int)inpMat.size[3]);
206204
kernel.set(idx++, (int)cols);
207205
kernel.set(idx++, (int)sliceRanges[i][2].start);
208206
kernel.set(idx++, (int)sliceRanges[i][3].start);
209207
kernel.set(idx++, ocl::KernelArg::PtrWriteOnly(outputs[i]));
210-
bool ret = kernel.run(2, global, NULL, false);
208+
bool ret = kernel.run(1, global, local, false);
211209
if (!ret)
212210
return false;
213211
}

modules/dnn/src/opencl/slice.cl

Lines changed: 24 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -44,44 +44,38 @@
4444
#define Dtype4 float4
4545
#define Dtype8 float8
4646

47-
#if NUM == 8
48-
#define load(src, index) vload8(0, src + index)
49-
#define store(vec, dst, index) vstore8(vec, 0, dst + index)
50-
#define vec_type Dtype8
51-
#define SLICE slice8
52-
#elif NUM == 4
53-
#define load(src, index) vload4(0, src + index)
54-
#define store(vec, dst, index) vstore4(vec, 0, dst + index)
55-
#define vec_type Dtype4
56-
#define SLICE slice4
57-
#elif NUM == 1
58-
#define load(src, index) src[index]
59-
#define store(vec, dst, index) dst[index] = vec
60-
#define vec_type Dtype
61-
#define SLICE slice1
62-
#endif
63-
64-
__kernel void SLICE(__global const Dtype* src,
47+
__kernel void slice(__global const Dtype* src,
6548
const int src_plane_size,
66-
const int src_cols,
67-
const int channels,
6849
const int dst_plane_size,
50+
const int src_cols,
6951
const int dst_cols,
7052
const int row_offset,
7153
const int col_offset,
7254
__global Dtype* dst)
7355
{
74-
int x = get_global_id(0);
75-
int y = get_global_id(1) * NUM;
56+
unsigned int row_gid = get_group_id(0);
57+
unsigned int lid = get_local_id(0);
58+
const __global Dtype *src_read = src + row_gid * 4 * src_plane_size;
59+
__global Dtype *dst_read = dst + row_gid * 4 * dst_plane_size;
60+
Dtype4 a0, a1, a2, a3;
61+
62+
int i = lid;
63+
while( i < dst_plane_size / 4)
64+
{
65+
int row = (4 * i) / dst_cols + row_offset;
66+
int col = (4 * i) % dst_cols + col_offset;
67+
int src_index = row * src_cols + col;
7668

77-
if ((x >= channels) || (y >= dst_plane_size))
78-
return;
69+
a0 = vload4(0, src_read + src_index);
70+
a1 = vload4(0, src_read + src_index + src_plane_size);
71+
a2 = vload4(0, src_read + src_index + 2 * src_plane_size);
72+
a3 = vload4(0, src_read + src_index + 3 * src_plane_size);
7973

80-
int row = y / dst_cols + row_offset;
81-
int col = y % dst_cols + col_offset;
74+
vstore4(a0, i, dst_read);
75+
vstore4(a1, i, dst_read + dst_plane_size);
76+
vstore4(a2, i, dst_read + 2 * dst_plane_size);
77+
vstore4(a3, i, dst_read + 3 * dst_plane_size);
8278

83-
int src_index = x * src_plane_size + row * src_cols + col;
84-
int dst_index = x * dst_plane_size + y;
85-
vec_type val = load(src, src_index);
86-
store(val, dst, dst_index);
79+
i += get_local_size(0);
80+
}
8781
}

0 commit comments

Comments
 (0)