#include "pad_reflect_1d.cuh"

static __global__ void pad_reflect_1d_kernel_f32(
    const void * __restrict__ src0,
    void * __restrict__ dst,
    const int64_t ne0,
    const int64_t ne00,
    const int64_t ne01,
    const int64_t ne02,
    const int64_t ne03,
    const int64_t nb00,
    const int64_t nb01,
    const int64_t nb02,
    const int64_t nb03,
    const int64_t nb0,
    const int64_t nb1,
    const int64_t nb2,
    const int64_t nb3,
    const int p0,
    const int p1) {

    const int64_t i3 = blockIdx.z;
    const int64_t i2 = blockIdx.y;
    const int64_t i1 = blockIdx.x;

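    // Defensive: the launch grid below is sized exactly (ne01, ne02, ne03),
    // so this bound check should never trigger.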
    if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
        return;
    }

    const char * src0_ptr = (const char *)src0 + i3*nb03 + i2*nb02 + i1*nb01;
    char * dst_ptr = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1;

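    // One block per row: the threads of the block stride together across the
    // ne0 padded elements of the output row.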
    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
        float value;

        if (i0 < p0) {
            // Left padding - reflect
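            // For i0 in [0, p0), mirror about element 0 without repeating it:
            // dst[i0] = src[p0 - i0].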
            value = *(const float *)(src0_ptr + (p0 - i0) * nb00);
        } else if (i0 < ne0 - p1) {
            // Middle - copy
            value = *(const float *)(src0_ptr + (i0 - p0) * nb00);
        } else {
            // Right padding - reflect
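            // For j = i0 - (ne0 - p1) in [0, p1), mirror about the last
            // element without repeating it: dst[i0] = src[ne00 - 2 - j].
            // The expression below reduces to exactly that index.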
            int64_t src_idx = (ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1;
            value = *(const float *)(src0_ptr + src_idx * nb00);
        }

        *(float *)(dst_ptr + i0 * nb0) = value;
    }
}

void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int32_t * opts = (const int32_t *) dst->op_params;
    const int p0 = opts[0];
    const int p1 = opts[1];

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    const int64_t ne0 = dst->ne[0];

    GGML_ASSERT(ne0 == ne00 + p0 + p1);

    const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1);
    const dim3 grid_dims(ne01, ne02, ne03);

    pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
        src0->data, dst->data,
        ne0, ne00, ne01, ne02, ne03,
        src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
        dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
        p0, p1
    );
}
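
// ---------------------------------------------------------------------------
// Illustrative reference (not part of the commit): a host-side scalar sketch
// of the same indexing, handy for validating the kernel on one contiguous
// f32 row. It assumes ne0 == ne00 + p0 + p1, matching the assert above, and
// p0, p1 <= ne00 - 1 so every reflected index stays in bounds; the function
// name is hypothetical. This backend op is presumably reached via the public
// ggml_pad_reflect_1d(ctx, a, p0, p1) graph call.
// ---------------------------------------------------------------------------
static void pad_reflect_1d_row_ref(const float * src, float * dst,
                                   int64_t ne00, int p0, int p1) {
    const int64_t ne0 = ne00 + p0 + p1;
    for (int64_t i0 = 0; i0 < ne0; ++i0) {
        if (i0 < p0) {
            dst[i0] = src[p0 - i0];            // left reflection
        } else if (i0 < ne0 - p1) {
            dst[i0] = src[i0 - p0];            // middle copy
        } else {
            const int64_t j = i0 - (ne0 - p1); // offset into the right pad
            dst[i0] = src[ne00 - 2 - j];       // right reflection
        }
    }
}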