diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index c68ace982a3..2c5a07a641b 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -9346,14 +9346,15 @@ static void ggml_compute_forward_get_rel_pos_f32(
     const float k_scale = MAX((float)qh / kh, 1.0f);
     const float q_scale = MAX((float)kh / qh, 1.0f);
 
-    float * src0_data = (float *) src0->data;
-    float * dst_data = (float *) dst->data;
+    const char * src0_d = (const char *) src0->data;
+    char * dst_d = (char *) dst->data;
 
     for (int64_t i2 = 0; i2 < ne2; ++i2) {
         for (int64_t i1 = 0; i1 < ne1; ++i1) {
             const int pos = int(i2*q_scale - i1*k_scale + (kh - 1)*k_scale);
             for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
+                const float val = *(const float *) (src0_d + pos*nb01 + i0*nb00);
+                *(float *) (dst_d + i2*nb2 + i1*nb1 + i0*nb0) = val;
             }
         }
     }
@@ -9375,14 +9376,15 @@ static void ggml_compute_forward_get_rel_pos_f16(
     const float k_scale = MAX((float)qh / kh, 1.0f);
     const float q_scale = MAX((float)kh / qh, 1.0f);
 
-    ggml_fp16_t * src0_data = (ggml_fp16_t *) src0->data;
-    ggml_fp16_t * dst_data = (ggml_fp16_t *) dst->data;
+    const char * src0_d = (const char *) src0->data;
+    char * dst_d = (char *) dst->data;
 
     for (int64_t i2 = 0; i2 < ne2; ++i2) {
         for (int64_t i1 = 0; i1 < ne1; ++i1) {
             const int pos = int(i2*q_scale - i1*k_scale + (kh - 1)*k_scale);
             for (int64_t i0 = 0; i0 < ne0; ++i0) {
-                dst_data[i2*ne1*ne0 + i1*ne0 + i0] = src0_data[pos*ne00 + i0];
+                const ggml_fp16_t val = *(const ggml_fp16_t *) (src0_d + pos*nb01 + i0*nb00);
+                *(ggml_fp16_t *) (dst_d + i2*nb2 + i1*nb1 + i0*nb0) = val;
             }
         }
     }
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index e64a7cca57f..e8e3aa7416a 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2729,6 +2729,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             break;
         case GGML_OP_GET_REL_POS:
             ggml_cuda_op_get_rel_pos(ctx, dst);
+            break;
         case GGML_OP_SOLVE_TRI:
            ggml_cuda_op_solve_tri(ctx, dst);
            break;
diff --git a/ggml/src/ggml-cuda/rel-pos.cu b/ggml/src/ggml-cuda/rel-pos.cu
index 1d1aba4c737..a8700ae1820 100644
--- a/ggml/src/ggml-cuda/rel-pos.cu
+++ b/ggml/src/ggml-cuda/rel-pos.cu
@@ -2,22 +2,30 @@
 #include "ggml.h"
 #include "ggml-cuda/rel-pos.cuh"
-
 template <typename T>
-__global__ static void get_rel_pos_kernel(const void * src, void * dst, int C) {
-    int kh = gridDim.x;
-    int qh = gridDim.y;
-    float k_scale = MAX((float)qh / kh, 1.0f);
-    float q_scale = MAX((float)kh / qh, 1.0f);
+__global__ static void get_rel_pos_kernel(const void * src, void * dst,
+                                          int C, int kh, int qh,
+                                          int nb00, int nb01,
+                                          int nb0, int nb1, int nb2) {
     int ki = blockIdx.x;
     int qi = blockIdx.y;
-    int pos = int(qi*q_scale - ki*k_scale + (kh - 1)*k_scale);
-    int s0 = C;
-    int s1 = C * kh;
+    if (ki >= kh || qi >= qh) {
+        return;
+    }
+
+    float k_scale = MAX((float) qh / kh, 1.0f);
+    float q_scale = MAX((float) kh / qh, 1.0f);
+
+    int pos = int(qi * q_scale - ki * k_scale + (kh - 1) * k_scale);
+
+    const char * src_d = (const char *) src;
+    char * dst_d = (char *) dst;
 
     for (int ci = threadIdx.x; ci < C; ci += blockDim.x) {
-        ((T *) dst)[qi*s1 + ki*s0 + ci] = ((const T *) src)[pos*C + ci];
+        const int src_offset = pos * nb01 + ci * nb00;
+        const int dst_offset = qi * nb2 + ki * nb1 + ci * nb0;
+        *(T *) (dst_d + dst_offset) = *(const T *) (src_d + src_offset);
     }
 }
 
@@ -44,26 +52,28 @@ void ggml_cuda_op_get_rel_pos(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     int kh = ne1;
     int qh = ne2;
 
-    int num_threads = MIN(CUDA_GET_REL_POS_BLOCK_SIZE, MAX(32, round_to_pow2(C)));
-    dim3 grid { (unsigned int)kh, (unsigned int)qh, 1 };
+    int num_threads = MIN(CUDA_GET_REL_POS_BLOCK_SIZE, MAX(32, round_to_pow2(C)));
+    dim3 grid{ (unsigned int) kh, (unsigned int) qh };
 
-    const void * src0_d = (const void *)src0->data;
-    void * dst_d = (void *)dst->data;
+    const void * src0_d = (const void *) src0->data;
+    void * dst_d = (void *) dst->data;
 
     cudaStream_t stream = ctx.stream();
 
-    switch (src0->type)
-    {
-    case GGML_TYPE_F32:
-        get_rel_pos_kernel<float><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
-        break;
-    case GGML_TYPE_F16:
-        get_rel_pos_kernel<half><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
-        break;
-    case GGML_TYPE_BF16:
-        get_rel_pos_kernel<nv_bfloat16><<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C);
-        break;
-    default:
-        GGML_ABORT("%s: unsupported type (%s)\n", __func__, ggml_type_name(src0->type));
-        break;
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            get_rel_pos_kernel<float>
+                <<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C, kh, qh, src0->nb[0], src0->nb[1], nb0, nb1, nb2);
+            break;
+        case GGML_TYPE_F16:
+            get_rel_pos_kernel<half>
+                <<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C, kh, qh, src0->nb[0], src0->nb[1], nb0, nb1, nb2);
+            break;
+        case GGML_TYPE_BF16:
+            get_rel_pos_kernel<nv_bfloat16>
+                <<<grid, num_threads, 0, stream>>>(src0_d, dst_d, C, kh, qh, src0->nb[0], src0->nb[1], nb0, nb1, nb2);
+            break;
+        default:
+            GGML_ABORT("%s: unsupported type (%s)\n", __func__, ggml_type_name(src0->type));
+            break;
     }
-}
\ No newline at end of file
+}
diff --git a/ggml/src/ggml-cuda/win.cu b/ggml/src/ggml-cuda/win.cu
index c9e6793ae5f..83a5cbf708f 100644
--- a/ggml/src/ggml-cuda/win.cu
+++ b/ggml/src/ggml-cuda/win.cu
@@ -1,6 +1,7 @@
 #include "common.cuh"
-#include "ggml.h"
+#include "convert.cuh"
 #include "ggml-cuda/win.cuh"
+#include "ggml.h"
 
 /*
@@ -28,7 +29,7 @@ static void ggml_compute_forward_win_part_f16(
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         int px = i3 % nep0;
         int py = (i3 / nep0) % nep1;
-        int b = i3 / (nep0 * nep1);
+        int b = i3 / (nep0 * nep1);
         for (int64_t i2 = 0; i2 < ne2; ++i2) {
             for (int64_t i1 = 0; i1 < ne1; ++i1) {
                 for (int64_t i0 = 0; i0 < ne0; ++i0) {
@@ -38,7 +39,7 @@ static void ggml_compute_forward_win_part_f16(
                     const int64_t i00 = i0;
 
                     void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
-                    void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
+                    void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
 
                     if (py*w + i2 >= ne02 || px*w + i1 >= ne01) {
                         *((ggml_fp16_t *) dp) = 0;
@@ -138,7 +139,7 @@ __global__ static void win_part_kernel(
             if (py*p.w + i2 >= p.ne2 || px*p.w + i1 >= p.ne1) {
                 for (int i0 = threadIdx.x; i0 < p.C; i0 += blockDim.x) {
                     char * dp = (char *)dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
-                    *((T *) dp) = 0;
+                    *((T *) dp) = ggml_cuda_cast<T>(0.0f);
                 }
                 return;
             }
@@ -210,7 +211,7 @@ static unsigned int round_to_pow2(unsigned int v) {
     v++;
 
     return v;
-}
+}
 
 void ggml_cuda_op_win_part(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
@@ -297,12 +298,12 @@ static void ggml_compute_forward_win_unpart_f16(
                 for (int64_t i0 = 0; i0 < ne0; ++i0) {
                     const int ip2 = i2/w;
                     const int ip1 = i1/w;
-                    
+
                     const int64_t i03 = i3*npx*npy + ip2*npx + ip1;
                     const int64_t i02 = i2%w;
                     const int64_t i01 = i1%w;
                     const int64_t i00 = i0;
-                    
+
                     void * sp = ((void *) src0->data) + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00;
                     void * dp = ((void *) dst->data) + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0;
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index f94daacc119..bf565ac24a3 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7871,8 +7871,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_get_rel_pos(type, 13, 7, 7, v));
         // Square large: 14x14 attention
         test_cases.emplace_back(new test_get_rel_pos(type, 27, 14, 14, v));
+        // Square large: 16x16 attention
+        test_cases.emplace_back(new test_get_rel_pos(type, 31, 16, 16, v));
         // Rectangular: 14x7 attention
         test_cases.emplace_back(new test_get_rel_pos(type, 27, 14, 7, v));
+        // Rectangular: 7x14 attention
+        test_cases.emplace_back(new test_get_rel_pos(type, 27, 7, 14, v));
+        // Rectangular: 16x8 attention
+        test_cases.emplace_back(new test_get_rel_pos(type, 31, 16, 8, v));
+        // Rectangular: 8x16 attention
+        test_cases.emplace_back(new test_get_rel_pos(type, 31, 8, 16, v));
+        // Rectangular: 28x14 attention
+        test_cases.emplace_back(new test_get_rel_pos(type, 55, 28, 14, v));
+        // Rectangular: 14x28 attention
+        test_cases.emplace_back(new test_get_rel_pos(type, 55, 14, 28, v));
         // Edge case: 1x1 attention (minimum)
         test_cases.emplace_back(new test_get_rel_pos(type, 1, 1, 1, v));
     }
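
For reference, the indexing that the patched CPU and CUDA paths now share can be written as one small host-side routine. This is a minimal sketch, not code from the patch: the name get_rel_pos_ref and its parameter list are hypothetical, and it assumes f32 data with ggml-style byte strides (nb00/nb01 for the source channel/row, nb0/nb1/nb2 for the destination channel/key/query), mirroring the f32 CPU path above.

// Reference gather for GET_REL_POS over possibly non-contiguous tensors:
// dst[qi, ki, ci] = src[pos(qi, ki), ci], with pos computed as in the kernels.
#include <algorithm>
#include <cstdint>

static void get_rel_pos_ref(const char * src, char * dst,
                            int64_t C, int64_t kh, int64_t qh,
                            size_t nb00, size_t nb01,             // src byte strides: channel, row
                            size_t nb0, size_t nb1, size_t nb2) { // dst byte strides: channel, key, query
    const float k_scale = std::max((float) qh / kh, 1.0f);
    const float q_scale = std::max((float) kh / qh, 1.0f);

    for (int64_t qi = 0; qi < qh; ++qi) {
        for (int64_t ki = 0; ki < kh; ++ki) {
            const int pos = int(qi*q_scale - ki*k_scale + (kh - 1)*k_scale);
            for (int64_t ci = 0; ci < C; ++ci) {
                const float val = *(const float *) (src + pos*nb01 + ci*nb00);
                *(float *) (dst + qi*nb2 + ki*nb1 + ci*nb0) = val;
            }
        }
    }
}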