
Commit fea493a

cuda: support incontiguous inputs for get_rel_pos
1 parent d1b1757 commit fea493a
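
This change makes the get_rel_pos CUDA kernel work with non-contiguous inputs. Previously the kernel derived kh and qh from gridDim and wrote through flat element offsets (row stride C for the source, C and C*kh for the destination), which is only valid for contiguous tensors. The host now passes the dimensions explicitly along with byte strides (src0->nb[0] and src0->nb[1] for the source, nb0/nb1/nb2 for the destination), and the kernel addresses both tensors through char pointers, so viewed or permuted sources are read correctly. The kernel also gains an explicit bounds check on the block indices.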

File tree: 1 file changed (+39, −29 lines)


ggml/src/ggml-cuda/rel-pos.cu

Lines changed: 39 additions & 29 deletions
@@ -2,22 +2,30 @@
 #include "ggml.h"
 #include "ggml-cuda/rel-pos.cuh"
 
-
 template <typename T>
-__global__ static void get_rel_pos_kernel(const void * src, void * dst, int C) {
-    int kh = gridDim.x;
-    int qh = gridDim.y;
-    float k_scale = MAX((float)qh / kh, 1.0f);
-    float q_scale = MAX((float)kh / qh, 1.0f);
+__global__ static void get_rel_pos_kernel(const void * src, void * dst,
+                                          int C, int kh, int qh,
+                                          int nb00, int nb01,
+                                          int nb0, int nb1, int nb2) {
     int ki = blockIdx.x;
     int qi = blockIdx.y;
-    int pos = int(qi*q_scale - ki*k_scale + (kh - 1)*k_scale);
 
-    int s0 = C;
-    int s1 = C * kh;
+    if (ki >= kh || qi >= qh) {
+        return;
+    }
+
+    float k_scale = MAX((float) qh / kh, 1.0f);
+    float q_scale = MAX((float) kh / qh, 1.0f);
+
+    int pos = int(qi * q_scale - ki * k_scale + (kh - 1) * k_scale);
+
+    const char * src_d = (const char *) src;
+    char * dst_d = (char *) dst;
 
     for (int ci = threadIdx.x; ci < C; ci += blockDim.x) {
-        ((T *) dst)[qi*s1 + ki*s0 + ci] = ((const T *) src)[pos*C + ci];
+        const int src_offset = pos * nb01 + ci * nb00;
+        const int dst_offset = qi * nb2 + ki * nb1 + ci * nb0;
+        *(T *) (dst_d + dst_offset) = *(const T *) (src_d + src_offset);
     }
 }
 
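The kernel keeps the same relative-position formula, pos = int(qi*q_scale - ki*k_scale + (kh - 1)*k_scale), and only changes how elements are addressed: every offset is now computed in bytes from the strides instead of in elements from C. Below is a minimal CPU sketch of that arithmetic for the F32 case; it is not part of the commit, and the function name get_rel_pos_ref, the toy sizes, and the source row count 2*max(kh, qh) - 1 are assumptions made for illustration. For a contiguous tensor the byte strides are nb00 = sizeof(float), nb01 = C*sizeof(float), nb0 = sizeof(float), nb1 = C*sizeof(float), nb2 = C*kh*sizeof(float), and the addressing reduces to the old dst[qi*C*kh + ki*C + ci] = src[pos*C + ci].

// CPU sketch (illustrative, not part of the commit) of the byte-stride
// addressing used by get_rel_pos_kernel. Strides follow ggml's convention
// of being given in bytes.
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>

static void get_rel_pos_ref(const float * src, float * dst, int C, int kh, int qh,
                            int nb00, int nb01,          // source strides (bytes)
                            int nb0, int nb1, int nb2) { // destination strides (bytes)
    const char * src_d = (const char *) src;
    char *       dst_d = (char *) dst;

    const float k_scale = std::max((float) qh / kh, 1.0f);
    const float q_scale = std::max((float) kh / qh, 1.0f);

    for (int qi = 0; qi < qh; ++qi) {
        for (int ki = 0; ki < kh; ++ki) {
            // same relative-position formula as the kernel
            const int pos = int(qi * q_scale - ki * k_scale + (kh - 1) * k_scale);
            for (int ci = 0; ci < C; ++ci) {
                const int src_offset = pos * nb01 + ci * nb00;
                const int dst_offset = qi * nb2 + ki * nb1 + ci * nb0;
                memcpy(dst_d + dst_offset, src_d + src_offset, sizeof(float));
            }
        }
    }
}

int main() {
    const int C = 4, kh = 3, qh = 3;

    std::vector<float> src((2 * std::max(kh, qh) - 1) * C); // assumed source row count
    for (size_t i = 0; i < src.size(); ++i) {
        src[i] = (float) i;
    }
    std::vector<float> dst(qh * kh * C);

    // contiguous case: byte strides that reproduce the old element indexing
    get_rel_pos_ref(src.data(), dst.data(), C, kh, qh,
                    sizeof(float), C * sizeof(float),
                    sizeof(float), C * sizeof(float), C * kh * sizeof(float));

    printf("dst[0..3] = %g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]);
    return 0;
}

A non-contiguous source only changes the stride arguments: for example, a view that reads every other row of a larger buffer would pass nb01 = 2*C*sizeof(float) while the kernel code stays the same.
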
@@ -44,26 +52,28 @@ void ggml_cuda_op_get_rel_pos(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     int kh = ne1;
     int qh = ne2;
 
-    int num_threads = MIN(CUDA_GET_REL_POS_BLOCK_SIZE, MAX(32, round_to_pow2(C)));
-    dim3 grid { (unsigned int)kh, (unsigned int)qh, 1 };
+    int num_threads = MIN(CUDA_GET_REL_POS_BLOCK_SIZE, MAX(32, round_to_pow2(C)));
+    dim3 grid{ (unsigned int) kh, (unsigned int) qh };
 
-    const void * src0_d = (const void *)src0->data;
-    void * dst_d = (void *)dst->data;
+    const void * src0_d = (const void *) src0->data;
+    void * dst_d = (void *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    switch (src0->type)
-    {
-    case GGML_TYPE_F32:
-        get_rel_pos_kernel<float><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
-        break;
-    case GGML_TYPE_F16:
-        get_rel_pos_kernel<half><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
-        break;
-    case GGML_TYPE_BF16:
-        get_rel_pos_kernel<nv_bfloat16><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
-        break;
-    default:
-        GGML_ABORT("%s: unsupported type (%s)\n", __func__, ggml_type_name(src0->type));
-        break;
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            get_rel_pos_kernel<float>
+                <<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C, kh, qh, src0->nb[0], src0->nb[1], nb0, nb1, nb2);
+            break;
+        case GGML_TYPE_F16:
+            get_rel_pos_kernel<half>
+                <<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C, kh, qh, src0->nb[0], src0->nb[1], nb0, nb1, nb2);
+            break;
+        case GGML_TYPE_BF16:
+            get_rel_pos_kernel<nv_bfloat16>
+                <<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C, kh, qh, src0->nb[0], src0->nb[1], nb0, nb1, nb2);
+            break;
+        default:
+            GGML_ABORT("%s: unsupported type (%s)\n", __func__, ggml_type_name(src0->type));
+            break;
     }
-}
+}
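
On the host side the launch configuration is unchanged apart from the argument list: the grid is still (kh, qh) with num_threads capped at the block size, and each type case now forwards the source byte strides src0->nb[0] / src0->nb[1] together with the destination strides nb0 / nb1 / nb2. Since the three cases repeat the same argument list, one way to keep it in a single place would be a small templated launcher next to the kernel; this is a sketch only, launch_get_rel_pos is a hypothetical helper that is not part of the commit, and it assumes the get_rel_pos_kernel definition from the hunk above.

// Hypothetical helper (not in the commit): centralizes the shared launch
// arguments so each switch case only selects the element type.
template <typename T>
static void launch_get_rel_pos(const void * src0_d, void * dst_d,
                               int C, int kh, int qh,
                               int nb00, int nb01,
                               int nb0, int nb1, int nb2,
                               dim3 grid, int num_threads, cudaStream_t stream) {
    get_rel_pos_kernel<T><<<grid, num_threads, 0, stream>>>(
        src0_d, dst_d, C, kh, qh, nb00, nb01, nb0, nb1, nb2);
}

With such a helper each case would collapse to a single call, e.g. launch_get_rel_pos<float>(src0_d, dst_d, C, kh, qh, src0->nb[0], src0->nb[1], nb0, nb1, nb2, grid, num_threads, stream).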
