 #include "ggml.h"
 #include "ggml-cuda/rel-pos.cuh"
 
-
 template <typename T>
-__global__ static void get_rel_pos_kernel(const void * src, void * dst, int C) {
-    int kh = gridDim.x;
-    int qh = gridDim.y;
-    float k_scale = MAX((float)qh / kh, 1.0f);
-    float q_scale = MAX((float)kh / qh, 1.0f);
+__global__ static void get_rel_pos_kernel(const void * src, void * dst,
+                                          int C, int kh, int qh,
+                                          int nb00, int nb01,
+                                          int nb0, int nb1, int nb2) {
     int ki = blockIdx.x;
     int qi = blockIdx.y;
-    int pos = int(qi*q_scale - ki*k_scale + (kh - 1)*k_scale);
 
-    int s0 = C;
-    int s1 = C * kh;
+    if (ki >= kh || qi >= qh) {
+        return;
+    }
+
+    float k_scale = MAX((float) qh / kh, 1.0f);
+    float q_scale = MAX((float) kh / qh, 1.0f);
+
+    int pos = int(qi * q_scale - ki * k_scale + (kh - 1) * k_scale);
+
+    const char * src_d = (const char *) src;
+    char * dst_d = (char *) dst;
 
     for (int ci = threadIdx.x; ci < C; ci += blockDim.x) {
-        ((T *) dst)[qi*s1 + ki*s0 + ci] = ((const T *) src)[pos*C + ci];
+        const int src_offset = pos * nb01 + ci * nb00;
+        const int dst_offset = qi * nb2 + ki * nb1 + ci * nb0;
+        *(T *) (dst_d + dst_offset) = *(const T *) (src_d + src_offset);
     }
 }
 
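For context, the kernel performs the relative-position gather used by SAM-style decomposed attention: each (query, key) index pair selects one row of the `src` table. A minimal CPU sketch of the same index math, assuming contiguous float data and a `src` table of `2*MAX(qh, kh) - 1` rows of `C` channels (the helper `get_rel_pos_ref` is illustrative, not part of this patch):

    // Hypothetical reference loop, not from the patch: same lookup as the
    // kernel, on contiguous float data.
    static void get_rel_pos_ref(const float * src, float * dst, int C, int kh, int qh) {
        const float k_scale = MAX((float) qh / kh, 1.0f);
        const float q_scale = MAX((float) kh / qh, 1.0f);
        for (int qi = 0; qi < qh; ++qi) {
            for (int ki = 0; ki < kh; ++ki) {
                // relative offset of query qi from key ki, shifted to start at 0
                const int pos = int(qi * q_scale - ki * k_scale + (kh - 1) * k_scale);
                for (int ci = 0; ci < C; ++ci) {
                    dst[(qi * kh + ki) * C + ci] = src[pos * C + ci];
                }
            }
        }
    }
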
@@ -44,26 +52,28 @@ void ggml_cuda_op_get_rel_pos(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     int kh = ne1;
     int qh = ne2;
 
-    int num_threads = MIN(CUDA_GET_REL_POS_BLOCK_SIZE, MAX(32, round_to_pow2(C)));
-    dim3 grid { (unsigned int)kh, (unsigned int)qh, 1 };
+    int num_threads = MIN(CUDA_GET_REL_POS_BLOCK_SIZE, MAX(32, round_to_pow2(C)));
+    dim3 grid{ (unsigned int) kh, (unsigned int) qh };
 
-    const void * src0_d = (const void *)src0->data;
-    void * dst_d = (void *)dst->data;
+    const void * src0_d = (const void *) src0->data;
+    void * dst_d = (void *) dst->data;
     cudaStream_t stream = ctx.stream();
 
-    switch (src0->type)
-    {
-    case GGML_TYPE_F32:
-        get_rel_pos_kernel<float><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
-        break;
-    case GGML_TYPE_F16:
-        get_rel_pos_kernel<half><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
-        break;
-    case GGML_TYPE_BF16:
-        get_rel_pos_kernel<nv_bfloat16><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
-        break;
-    default:
-        GGML_ABORT("%s: unsupported type (%s)\n", __func__, ggml_type_name(src0->type));
-        break;
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            get_rel_pos_kernel<float>
+                <<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C, kh, qh, src0->nb[0], src0->nb[1], nb0, nb1, nb2);
+            break;
+        case GGML_TYPE_F16:
+            get_rel_pos_kernel<half>
+                <<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C, kh, qh, src0->nb[0], src0->nb[1], nb0, nb1, nb2);
+            break;
+        case GGML_TYPE_BF16:
+            get_rel_pos_kernel<nv_bfloat16>
+                <<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C, kh, qh, src0->nb[0], src0->nb[1], nb0, nb1, nb2);
+            break;
+        default:
+            GGML_ABORT("%s: unsupported type (%s)\n", __func__, ggml_type_name(src0->type));
+            break;
     }
-}
+}
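
Note on the new stride parameters: in ggml, `nb[i]` are byte strides (for a contiguous non-quantized tensor, `nb[0] == ggml_type_size(type)` and `nb[i] = nb[i-1] * ne[i-1]`), so passing `src0->nb[0..1]` and the destination's `nb0..nb2` lets the kernel address non-contiguous views instead of assuming packed element indexing. A small standalone sketch of that byte-stride addressing, under those assumptions:

    // Illustrative only, not from the patch: locate element (i0, i1) of a
    // 2-D view via byte strides, as the kernel does with (ci, ki, qi) and
    // nb0/nb1/nb2 on the destination tensor.
    template <typename T>
    static T load_strided(const void * base, int i0, int i1, size_t nb0, size_t nb1) {
        const char * p = (const char *) base;
        return *(const T *) (p + i0 * nb0 + i1 * nb1);
    }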