|
44 | 44 | #define Dtype4 float4
|
45 | 45 | #define Dtype8 float8
|
46 | 46 |
|
47 |
| -#if NUM == 8 |
48 |
| - #define load(src, index) vload8(0, src + index) |
49 |
| - #define store(vec, dst, index) vstore8(vec, 0, dst + index) |
50 |
| - #define vec_type Dtype8 |
51 |
| - #define SLICE slice8 |
52 |
| -#elif NUM == 4 |
53 |
| - #define load(src, index) vload4(0, src + index) |
54 |
| - #define store(vec, dst, index) vstore4(vec, 0, dst + index) |
55 |
| - #define vec_type Dtype4 |
56 |
| - #define SLICE slice4 |
57 |
| -#elif NUM == 1 |
58 |
| - #define load(src, index) src[index] |
59 |
| - #define store(vec, dst, index) dst[index] = vec |
60 |
| - #define vec_type Dtype |
61 |
| - #define SLICE slice1 |
62 |
| -#endif |
63 |
| - |
64 |
| -__kernel void SLICE(__global const Dtype* src, |
| 47 | +__kernel void slice(__global const Dtype* src, |
65 | 48 | const int src_plane_size,
|
66 |
| - const int src_cols, |
67 |
| - const int channels, |
68 | 49 | const int dst_plane_size,
|
| 50 | + const int src_cols, |
69 | 51 | const int dst_cols,
|
70 | 52 | const int row_offset,
|
71 | 53 | const int col_offset,
|
72 | 54 | __global Dtype* dst)
|
73 | 55 | {
|
74 |
| - int x = get_global_id(0); |
75 |
| - int y = get_global_id(1) * NUM; |
| 56 | + unsigned int row_gid = get_group_id(0); |
| 57 | + unsigned int lid = get_local_id(0); |
| 58 | + const __global Dtype *src_read = src + row_gid * 4 * src_plane_size; |
| 59 | + __global Dtype *dst_read = dst + row_gid * 4 * dst_plane_size; |
| 60 | + Dtype4 a0, a1, a2, a3; |
| 61 | + |
| 62 | + int i = lid; |
| 63 | + while( i < dst_plane_size / 4) |
| 64 | + { |
| 65 | + int row = (4 * i) / dst_cols + row_offset; |
| 66 | + int col = (4 * i) % dst_cols + col_offset; |
| 67 | + int src_index = row * src_cols + col; |
76 | 68 |
|
77 |
| - if ((x >= channels) || (y >= dst_plane_size)) |
78 |
| - return; |
| 69 | + a0 = vload4(0, src_read + src_index); |
| 70 | + a1 = vload4(0, src_read + src_index + src_plane_size); |
| 71 | + a2 = vload4(0, src_read + src_index + 2 * src_plane_size); |
| 72 | + a3 = vload4(0, src_read + src_index + 3 * src_plane_size); |
79 | 73 |
|
80 |
| - int row = y / dst_cols + row_offset; |
81 |
| - int col = y % dst_cols + col_offset; |
| 74 | + vstore4(a0, i, dst_read); |
| 75 | + vstore4(a1, i, dst_read + dst_plane_size); |
| 76 | + vstore4(a2, i, dst_read + 2 * dst_plane_size); |
| 77 | + vstore4(a3, i, dst_read + 3 * dst_plane_size); |
82 | 78 |
|
83 |
| - int src_index = x * src_plane_size + row * src_cols + col; |
84 |
| - int dst_index = x * dst_plane_size + y; |
85 |
| - vec_type val = load(src, src_index); |
86 |
| - store(val, dst, dst_index); |
| 79 | + i += get_local_size(0); |
| 80 | + } |
87 | 81 | }
|
0 commit comments