 #include "pad_reflect_1d.cuh"

-static __global__ void pad_reflect_1d_kernel_f32(
-    const void * __restrict__ src0,
-    void * __restrict__ dst,
-    const int64_t ne0,
-    const int64_t ne00,
-    const int64_t ne01,
-    const uint3 ne01_packed,
-    const int64_t ne02,
-    const int64_t ne03,
-    const int64_t nb00,
-    const int64_t nb01,
-    const int64_t nb02,
-    const int64_t nb03,
-    const int64_t nb0,
-    const int64_t nb1,
-    const int64_t nb2,
-    const int64_t nb3,
-    const int p0,
-    const int p1) {
-
+static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void
+    pad_reflect_1d_kernel_f32(
+        const void * __restrict__ src0,
+        void * __restrict__       dst,
+        const int64_t             ne0,
+        const int64_t             ne00,
+        const uint3               ne01,
+        const int64_t             ne02,
+        const int64_t             ne03,
+        const int64_t             nb00,
+        const int64_t             nb01,
+        const int64_t             nb02,
+        const int64_t             nb03,
+        const int64_t             nb0,
+        const int64_t             nb1,
+        const int64_t             nb2,
+        const int64_t             nb3,
+        const int                 p0,
+        const int                 p1) {
     const int64_t i3 = blockIdx.z;
     const int64_t i2 = blockIdx.y;

-    const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01_packed);
-    const int64_t tile1 = div_mod_packed.y; // i1
-    const int64_t tile0 = div_mod_packed.x; // nth i0 tile
-    const int64_t i1 = tile1;
-    const int64_t i0 = threadIdx.x + tile0 * blockDim.x;
-    if ( i0 >= ne0 || i1 >= ne01 || i2 >= ne02 || i3 >= ne03 ) {
+    const uint2   div_mod_packed = fast_div_modulo(blockIdx.x, ne01);
+    const int64_t tile1          = div_mod_packed.y;  // i1
+    const int64_t tile0          = div_mod_packed.x;  // nth i0 tile
+    const int64_t i1             = tile1;
+    const int64_t i0             = threadIdx.x + tile0 * blockDim.x;
+
+    // ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh)
+    if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) {
         return;
     }

-    const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
-    char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;
+    const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
+    char *       dst_ptr  = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1;

-    float value;
+    float         value;
     const int64_t j = i0 - p0;

-    if ( j<0 ) {// i0<p0
+    if (j < 0) {  // i0<p0
         // Left padding - reflect
-        value = *(const float *)(src0_ptr - j * nb00);
-    } else if ( j < ne00 ) { //i0 < ne0 - p1
+        value = *(const float *) (src0_ptr - j * nb00);
+    } else if (j < ne00) {  //i0 < ne0 - p1
         // Middle - copy
-        value = *(const float *)(src0_ptr + j * nb00);
-    } else {
+        value = *(const float *) (src0_ptr + j * nb00);
+    } else {
         // Right padding - reflect
         const int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
-        value = *(const float *)(src0_ptr + src_idx * nb00);
+        value = *(const float *) (src0_ptr + src_idx * nb00);
     }
-    *(float *)(dst_ptr + i0 * nb0) = value;
+    *(float *) (dst_ptr + i0 * nb0) = value;
 }

 void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    cudaStream_t stream = ctx.stream();
+    const ggml_tensor * src0   = dst->src[0];
+    cudaStream_t        stream = ctx.stream();

     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type == GGML_TYPE_F32);

     const int32_t * opts = (const int32_t *) dst->op_params;
-    const int p0 = opts[0];
-    const int p1 = opts[1];
+    const int       p0   = opts[0];
+    const int       p1   = opts[1];

-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
+    const int64_t ne00        = src0->ne[0];
+    const int64_t ne01        = src0->ne[1];
     const uint3   ne01_packed = init_fastdiv_values(ne01);
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
+    const int64_t ne02        = src0->ne[2];
+    const int64_t ne03        = src0->ne[3];

     const int64_t ne0 = dst->ne[0];

     // sanity: padded length matches
     GGML_ASSERT(ne0 == ne00 + p0 + p1);

-    constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x)
-    const int64_t tiles0 = (ne0 + bx - 1) / bx; // number of tiles along i0
+    constexpr int64_t bx     = CUDA_PAD_REFLECT_1D_BLOCK_SIZE;  // threads per block (x)
+    const int64_t     tiles0 = (ne0 + bx - 1) / bx;             // number of tiles along i0
     // grid.x covers i1 and all tiles of i0: [ne01 * tiles0]
     // grid.y covers i2: [ne02]
     // grid.z covers i3: [ne03]
-    const dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02, (unsigned)ne03);
-    const dim3 block_dims((unsigned)bx, 1, 1);
+    const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03);
+    const dim3 block_dims((unsigned) bx, 1, 1);

     pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
-        src0->data, dst->data,
-        ne0, ne00, ne01, ne01_packed, ne02, ne03,
-        src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
-        dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
-        p0, p1
-    );
+        src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+        dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1);
 }
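The launch flattens the (i0-tile, i1) pair into grid.x, and the kernel recovers it with ggml's fastdiv helpers, which replace a slow runtime integer division by ne01 with a precomputed multiply-and-shift. Below is a simplified reference for what that decomposition computes, assuming, per the comment in the kernel, that fast_div_modulo(n, packed) returns the quotient in .x and the remainder in .y; the plain division here is illustrative and is not ggml's actual magic-number implementation.

#include <cassert>
#include <cstdint>

// Reference behavior of the fastdiv pair used by the kernel (illustration only):
// fast_div_modulo(n, packed) ~ { n / ne01, n % ne01 }, with ne01 itself carried in packed.z.
struct div_mod {
    uint32_t div, mod;
};

static div_mod reference_div_modulo(uint32_t n, uint32_t ne01) {
    return { n / ne01, n % ne01 };  // kernel: .x -> tile0 (nth i0 tile), .y -> i1
}

int main() {
    // grid.x = ne01 * tiles0, so every block index maps to a unique (tile0, i1) pair
    const uint32_t ne01 = 7, tiles0 = 3;
    for (uint32_t bx = 0; bx < ne01 * tiles0; ++bx) {
        const div_mod dm = reference_div_modulo(bx, ne01);
        assert(dm.div < tiles0 && dm.mod < ne01);
    }
    return 0;
}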