Skip to content

Commit b1ab918

Browse files
cuda : add Pad Reflect 1D support (ggml-org#14659)
* Add Pad Reflect 1D CUDA support

* Update ggml/src/ggml-cuda/pad_reflect_1d.cu

Co-authored-by: Johannes Gäßler <[email protected]>

---------

Co-authored-by: Johannes Gäßler <[email protected]>
1 parent 9ebebef commit b1ab918

File tree

3 files changed

+92
-0
lines changed

3 files changed

+92
-0
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
#include "ggml-cuda/wkv.cuh"
5050
#include "ggml-cuda/gla.cuh"
5151
#include "ggml-cuda/set-rows.cuh"
52+
#include "ggml-cuda/pad_reflect_1d.cuh"
5253
#include "ggml.h"
5354

5455
#include <algorithm>
@@ -2352,6 +2353,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
23522353
case GGML_OP_PAD:
23532354
ggml_cuda_op_pad(ctx, dst);
23542355
break;
2356+
case GGML_OP_PAD_REFLECT_1D:
2357+
ggml_cuda_op_pad_reflect_1d(ctx, dst);
2358+
break;
23552359
case GGML_OP_ARANGE:
23562360
ggml_cuda_op_arange(ctx, dst);
23572361
break;
@@ -3490,6 +3494,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
34903494
return ggml_is_contiguous(op->src[0]);
34913495
case GGML_OP_UPSCALE:
34923496
case GGML_OP_PAD:
3497+
case GGML_OP_PAD_REFLECT_1D:
34933498
case GGML_OP_ARANGE:
34943499
case GGML_OP_TIMESTEP_EMBEDDING:
34953500
case GGML_OP_LEAKY_RELU:

ggml/src/ggml-cuda/pad_reflect_1d.cu

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#include "pad_reflect_1d.cuh"
2+
3+
// Reflect-pads an F32 tensor along dim 0 (e.g. "abc" with p0=2, p1=1 -> "cbabcb";
// the edge element itself is not repeated).
// Launch layout: one block per source row, grid = (ne01, ne02, ne03); the
// threads of a block stride over the ne0 output elements of the row.
// All nb* strides are in bytes, so non-contiguous layouts are supported.
static __global__ void pad_reflect_1d_kernel_f32(
    const void * __restrict__ src0,
    void * __restrict__ dst,
    const int64_t ne0,
    const int64_t ne00,
    const int64_t ne01,
    const int64_t ne02,
    const int64_t ne03,
    const int64_t nb00,
    const int64_t nb01,
    const int64_t nb02,
    const int64_t nb03,
    const int64_t nb0,
    const int64_t nb1,
    const int64_t nb2,
    const int64_t nb3,
    const int p0,
    const int p1) {

    const int64_t i3 = blockIdx.z;
    const int64_t i2 = blockIdx.y;
    const int64_t i1 = blockIdx.x;

    // defensive guard; the launch grid already matches (ne01, ne02, ne03)
    if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
        return;
    }

    const char * src0_row = (const char *) src0 + i3*nb03 + i2*nb02 + i1*nb01;
    char       * dst_row  = (char       *) dst  + i3*nb3  + i2*nb2  + i1*nb1;

    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
        // map the output index onto a source index, then do a single load/store
        int64_t j;
        if (i0 < p0) {
            j = p0 - i0;                                // left padding: mirror
        } else if (i0 < ne0 - p1) {
            j = i0 - p0;                                // middle: straight copy
        } else {
            j = 2*(ne0 - p1 - p0) - (i0 - p0) - 2;      // right padding: mirror
        }

        *(float *) (dst_row + i0*nb0) = *(const float *) (src0_row + j*nb00);
    }
}
51+
52+
// Launches the F32 reflect-pad kernel for GGML_OP_PAD_REFLECT_1D.
// dst->op_params holds [p0, p1], the left/right padding widths along dim 0.
// Grid: one block per source row (ne01, ne02, ne03); see
// pad_reflect_1d_kernel_f32 for the per-row indexing.
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);

    const int32_t * opts = (const int32_t *) dst->op_params;
    const int p0 = opts[0];
    const int p1 = opts[1];

    const int64_t ne00 = src0->ne[0];
    const int64_t ne01 = src0->ne[1];
    const int64_t ne02 = src0->ne[2];
    const int64_t ne03 = src0->ne[3];

    const int64_t ne0 = dst->ne[0];

    GGML_ASSERT(ne0 == ne00 + p0 + p1);
    // reflection is only defined for pads strictly smaller than the input
    // length; otherwise the kernel would compute a negative source index
    // and read out of bounds
    GGML_ASSERT(p0 >= 0 && p0 < ne00);
    GGML_ASSERT(p1 >= 0 && p1 < ne00);

    const dim3 block_dims(CUDA_PAD_REFLECT_1D_BLOCK_SIZE, 1, 1);
    const dim3 grid_dims(ne01, ne02, ne03);

    pad_reflect_1d_kernel_f32<<<grid_dims, block_dims, 0, stream>>>(
        src0->data, dst->data,
        ne0, ne00, ne01, ne02, ne03,
        src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
        dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
        p0, p1
    );
}

ggml/src/ggml-cuda/pad_reflect_1d.cuh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once

#include "common.cuh"

// threads per block for the reflect-pad kernel
#define CUDA_PAD_REFLECT_1D_BLOCK_SIZE 256

// CUDA implementation of GGML_OP_PAD_REFLECT_1D (F32 only): pads dim 0 of
// dst->src[0] by op_params[0] (left) and op_params[1] (right) elements
// using edge reflection.
void ggml_cuda_op_pad_reflect_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

0 commit comments

Comments
 (0)