|
1 | 1 | #include "pad_reflect_1d.cuh" |
2 | 2 |
|
3 | | -static __global__ void pad_reflect_1d_kernel_f32( |
4 | | - const void * __restrict__ src0, |
5 | | - void * __restrict__ dst, |
6 | | - const int64_t ne0, |
7 | | - const int64_t ne00, |
8 | | - const int64_t ne01, |
9 | | - const int64_t ne02, |
10 | | - const int64_t ne03, |
11 | | - const int64_t nb00, |
12 | | - const int64_t nb01, |
13 | | - const int64_t nb02, |
14 | | - const int64_t nb03, |
15 | | - const int64_t nb0, |
16 | | - const int64_t nb1, |
17 | | - const int64_t nb2, |
18 | | - const int64_t nb3, |
19 | | - const int p0, |
20 | | - const int p1) { |
21 | | - |
| 3 | +static __global__ __launch_bounds__(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1) void |
| 4 | + pad_reflect_1d_kernel_f32( |
| 5 | + const void * __restrict__ src0, |
| 6 | + void * __restrict__ dst, |
| 7 | + const int64_t ne0, |
| 8 | + const int64_t ne00, |
| 9 | + const uint3 ne01, |
| 10 | + const int64_t ne02, |
| 11 | + const int64_t ne03, |
| 12 | + const int64_t nb00, |
| 13 | + const int64_t nb01, |
| 14 | + const int64_t nb02, |
| 15 | + const int64_t nb03, |
| 16 | + const int64_t nb0, |
| 17 | + const int64_t nb1, |
| 18 | + const int64_t nb2, |
| 19 | + const int64_t nb3, |
| 20 | + const int p0, |
| 21 | + const int p1) { |
22 | 22 | const int64_t i3 = blockIdx.z; |
23 | 23 | const int64_t i2 = blockIdx.y; |
24 | | - const int64_t i1 = blockIdx.x; |
25 | 24 |
|
26 | | - if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) { |
| 25 | + const uint2 div_mod_packed = fast_div_modulo(blockIdx.x, ne01); |
| 26 | + const int64_t tile1 = div_mod_packed.y; // i1 |
| 27 | + const int64_t tile0 = div_mod_packed.x; // nth i0 tile |
| 28 | + const int64_t i1 = tile1; |
| 29 | + const int64_t i0 = threadIdx.x + tile0 * blockDim.x; |
| 30 | + |
| 31 | + // ne01.z is original value of unpacked ne01 (see init_fastdiv_values in common.cuh) |
| 32 | + if (i0 >= ne0 || i1 >= ne01.z || i2 >= ne02 || i3 >= ne03) { |
27 | 33 | return; |
28 | 34 | } |
29 | 35 |
|
30 | | - const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01; |
31 | | - char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1; |
32 | | - |
33 | | - for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { |
34 | | - float value; |
| 36 | + const char * src0_ptr = (const char *) src0 + i3 * nb03 + i2 * nb02 + i1 * nb01; |
| 37 | + char * dst_ptr = (char *) dst + i3 * nb3 + i2 * nb2 + i1 * nb1; |
35 | 38 |
|
36 | | - if (i0 < p0) { |
37 | | - // Left padding - reflect |
38 | | - value = *(const float *)(src0_ptr + (p0 - i0) * nb00); |
39 | | - } else if (i0 < ne0 - p1) { |
40 | | - // Middle - copy |
41 | | - value = *(const float *)(src0_ptr + (i0 - p0) * nb00); |
42 | | - } else { |
43 | | - // Right padding - reflect |
44 | | - int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1; |
45 | | - value = *(const float *)(src0_ptr + src_idx * nb00); |
46 | | - } |
| 39 | + const int64_t rel_i0 = i0 - p0; // relative i0 in src0 |
| 40 | + int64_t src_idx; |
47 | 41 |
|
48 | | - *(float *)(dst_ptr + i0 * nb0) = value; |
| 42 | + if (rel_i0 < 0) { |
| 43 | + // Left padding - reflect |
| 44 | + src_idx = -rel_i0; |
| 45 | + } else if (rel_i0 < ne00) { |
| 46 | + // Middle - copy |
| 47 | + src_idx = rel_i0; |
| 48 | + } else { |
| 49 | + // Right padding - reflect |
| 50 | + src_idx = 2 * ne00 - 2 - rel_i0; |
49 | 51 | } |
| 52 | + const float value = *(const float *) (src0_ptr + src_idx * nb00); |
| 53 | + *(float *) (dst_ptr + i0 * nb0) = value; |
50 | 54 | } |
51 | 55 |
|
52 | 56 | void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |
53 | | - const ggml_tensor * src0 = dst->src[0]; |
54 | | - cudaStream_t stream = ctx.stream(); |
| 57 | + const ggml_tensor * src0 = dst->src[0]; |
| 58 | + cudaStream_t stream = ctx.stream(); |
55 | 59 |
|
56 | 60 | GGML_ASSERT(src0->type == GGML_TYPE_F32); |
57 | 61 | GGML_ASSERT(dst->type == GGML_TYPE_F32); |
58 | 62 |
|
59 | 63 | const int32_t * opts = (const int32_t *) dst->op_params; |
60 | | - const int p0 = opts[0]; |
61 | | - const int p1 = opts[1]; |
| 64 | + const int p0 = opts[0]; |
| 65 | + const int p1 = opts[1]; |
62 | 66 |
|
63 | | - const int64_t ne00 = src0->ne[0]; |
64 | | - const int64_t ne01 = src0->ne[1]; |
65 | | - const int64_t ne02 = src0->ne[2]; |
66 | | - const int64_t ne03 = src0->ne[3]; |
| 67 | + const int64_t ne00 = src0->ne[0]; |
| 68 | + const int64_t ne01 = src0->ne[1]; |
| 69 | + const uint3 ne01_packed = init_fastdiv_values(ne01); |
| 70 | + const int64_t ne02 = src0->ne[2]; |
| 71 | + const int64_t ne03 = src0->ne[3]; |
67 | 72 |
|
68 | 73 | const int64_t ne0 = dst->ne[0]; |
69 | 74 |
|
| 75 | + // sanity: padded length matches |
70 | 76 | GGML_ASSERT(ne0 == ne00 + p0 + p1); |
71 | 77 |
|
72 | | - const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1); |
73 | | - const dim3 grid_dims(ne01, ne02, ne03); |
| 78 | + constexpr int64_t bx = CUDA_PAD_REFLECT_1D_BLOCK_SIZE; // threads per block (x) |
| 79 | + const int64_t tiles0 = (ne0 + bx - 1) / bx; // number of tiles along i0 |
| 80 | + // grid.x covers i1 and all tiles of i0: [ne01 * tiles0] |
| 81 | + // grid.y covers i2: [ne02] |
| 82 | + // grid.z covers i3: [ne03] |
| 83 | + const dim3 grid_dims((unsigned) (ne01 * tiles0), (unsigned) ne02, (unsigned) ne03); |
| 84 | + const dim3 block_dims((unsigned) bx, 1, 1); |
74 | 85 |
|
75 | 86 | pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>( |
76 | | - src0->data, dst->data, |
77 | | - ne0, ne00, ne01, ne02, ne03, |
78 | | - src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], |
79 | | - dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], |
80 | | - p0, p1 |
81 | | - ); |
| 87 | + src0->data, dst->data, ne0, ne00, ne01_packed, ne02, ne03, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], |
| 88 | + dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], p0, p1); |
82 | 89 | } |
0 commit comments