-
Notifications
You must be signed in to change notification settings - Fork 13.5k
CUDA: fuse rope + set_rows #16884
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
CUDA: fuse rope + set_rows #16884
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,5 @@ | ||
| #include "ggml-cuda/common.cuh" | ||
| #include "ggml.h" | ||
| #include "rope.cuh" | ||
|
|
||
| struct rope_corr_dims { | ||
|
|
@@ -37,11 +39,23 @@ static __device__ void rope_yarn( | |
| } | ||
| } | ||
|
|
||
| template<bool forward, bool has_ff, typename T> | ||
| static __global__ void rope_norm( | ||
| const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, | ||
| const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, | ||
| const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) { | ||
| template <bool forward, bool has_ff, typename T, typename D> | ||
| static __global__ void rope_norm(const T * x, | ||
| D * dst, | ||
| const int ne0, | ||
| const int ne1, | ||
| const int s1, | ||
| const int s2, | ||
| const int n_dims, | ||
| const int32_t * pos, | ||
| const float freq_scale, | ||
| const float ext_factor, | ||
| const float attn_factor, | ||
| const rope_corr_dims corr_dims, | ||
| const float theta_scale, | ||
| const float * freq_factors, | ||
| const int64_t * row_indices, | ||
| const int set_rows_stride) { | ||
| const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); | ||
|
|
||
| if (i0 >= ne0) { | ||
|
|
@@ -53,12 +67,19 @@ static __global__ void rope_norm( | |
| const int row_x = row_dst % ne1; | ||
| const int channel_x = row_dst / ne1; | ||
|
|
||
| const int idst = row_dst*ne0 + i0; | ||
| int idst = row_dst * ne0 + i0; | ||
| const int ix = channel_x*s2 + row_x*s1 + i0; | ||
|
|
||
| // Fusion optimization: ROPE + VIEW + SET_ROWS. | ||
| // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. | ||
| if (set_rows_stride != 0) { | ||
| idst = row_x * ne0 + i0; | ||
| idst += row_indices[channel_x] * set_rows_stride; | ||
| } | ||
|
|
||
| if (i0 >= n_dims) { | ||
| dst[idst + 0] = x[ix + 0]; | ||
| dst[idst + 1] = x[ix + 1]; | ||
| dst[idst + 0] = D(x[ix + 0]); | ||
| dst[idst + 1] = D(x[ix + 1]); | ||
|
|
||
| return; | ||
| } | ||
|
|
@@ -75,15 +96,27 @@ static __global__ void rope_norm( | |
| const float x0 = x[ix + 0]; | ||
| const float x1 = x[ix + 1]; | ||
|
|
||
| dst[idst + 0] = x0*cos_theta - x1*sin_theta; | ||
| dst[idst + 1] = x0*sin_theta + x1*cos_theta; | ||
| dst[idst + 0] = D(x0 * cos_theta - x1 * sin_theta); | ||
| dst[idst + 1] = D(x0 * sin_theta + x1 * cos_theta); | ||
|
Comment on lines
+99
to
+100
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When you're already working on optimizing RoPE: I think the memory access pattern here is suboptimal because there are gaps between each thread, and I don't know whether the compiler is smart enough to combine the first and second write into a single one. I would suggest grouping the values as |
||
| } | ||
|
|
||
| template<bool forward, bool has_ff, typename T> | ||
| static __global__ void rope_neox( | ||
| const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, | ||
| const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, | ||
| const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) { | ||
| template <bool forward, bool has_ff, typename T, typename D> | ||
| static __global__ void rope_neox(const T * x, | ||
| D * dst, | ||
| const int ne0, | ||
| const int ne1, | ||
| const int s1, | ||
| const int s2, | ||
| const int n_dims, | ||
| const int32_t * pos, | ||
| const float freq_scale, | ||
| const float ext_factor, | ||
| const float attn_factor, | ||
| const rope_corr_dims corr_dims, | ||
| const float theta_scale, | ||
| const float * freq_factors, | ||
| const int64_t * row_indices, | ||
| const int set_rows_stride) { | ||
| const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); | ||
|
|
||
| if (i0 >= ne0) { | ||
|
|
@@ -95,12 +128,19 @@ static __global__ void rope_neox( | |
| const int row_x = row_dst % ne1; | ||
| const int channel_x = row_dst / ne1; | ||
|
|
||
| const int idst = row_dst*ne0 + i0/2; | ||
| int idst = row_dst * ne0 + i0 / 2; | ||
| const int ix = channel_x*s2 + row_x*s1 + i0/2; | ||
|
|
||
| // Fusion optimization: ROPE + VIEW + SET_ROWS. | ||
| // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. | ||
| if (set_rows_stride != 0) { | ||
| idst = row_x * ne0 + i0 / 2; | ||
| idst += row_indices[channel_x] * set_rows_stride; | ||
| } | ||
|
|
||
| if (i0 >= n_dims) { | ||
| dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; | ||
| dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; | ||
| dst[idst + i0 / 2 + 0] = D(x[ix + i0 / 2 + 0]); | ||
| dst[idst + i0 / 2 + 1] = D(x[ix + i0 / 2 + 1]); | ||
|
|
||
| return; | ||
| } | ||
|
|
@@ -117,8 +157,8 @@ static __global__ void rope_neox( | |
| const float x0 = x[ix + 0]; | ||
| const float x1 = x[ix + n_dims/2]; | ||
|
|
||
| dst[idst + 0] = x0*cos_theta - x1*sin_theta; | ||
| dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; | ||
| dst[idst + 0] = D(x0 * cos_theta - x1 * sin_theta); | ||
| dst[idst + n_dims / 2] = D(x0 * sin_theta + x1 * cos_theta); | ||
| } | ||
|
|
||
| template<bool forward, bool has_ff, typename T> | ||
|
|
@@ -238,11 +278,25 @@ static __global__ void rope_vision( | |
| dst[idst + n_dims] = x0*sin_theta + x1*cos_theta; | ||
| } | ||
|
|
||
| template<bool forward, typename T> | ||
| static void rope_norm_cuda( | ||
| const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, | ||
| const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, | ||
| const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { | ||
| template <bool forward, typename T, typename D> | ||
| static void rope_norm_cuda(const T * x, | ||
| D * dst, | ||
| const int ne0, | ||
| const int ne1, | ||
| const int s1, | ||
| const int s2, | ||
| const int n_dims, | ||
| const int nr, | ||
| const int32_t * pos, | ||
| const float freq_scale, | ||
| const float freq_base, | ||
| const float ext_factor, | ||
| const float attn_factor, | ||
| const rope_corr_dims corr_dims, | ||
| const float * freq_factors, | ||
| const int64_t * row_indices, | ||
| const int set_rows_stride, | ||
| cudaStream_t stream) { | ||
| GGML_ASSERT(ne0 % 2 == 0); | ||
| const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); | ||
| const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); | ||
|
|
@@ -252,20 +306,34 @@ static void rope_norm_cuda( | |
|
|
||
| if (freq_factors == nullptr) { | ||
| rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>( | ||
| x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, | ||
| attn_factor, corr_dims, theta_scale, freq_factors); | ||
| x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, | ||
| freq_factors, row_indices, set_rows_stride); | ||
| } else { | ||
| rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>( | ||
| x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, | ||
| attn_factor, corr_dims, theta_scale, freq_factors); | ||
| x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, | ||
| freq_factors, row_indices, set_rows_stride); | ||
| } | ||
| } | ||
|
|
||
| template<bool forward, typename T> | ||
| static void rope_neox_cuda( | ||
| const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr, | ||
| const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, | ||
| const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) { | ||
| template <bool forward, typename T, typename D> | ||
| static void rope_neox_cuda(const T * x, | ||
| D * dst, | ||
| const int ne0, | ||
| const int ne1, | ||
| const int s1, | ||
| const int s2, | ||
| const int n_dims, | ||
| const int nr, | ||
| const int32_t * pos, | ||
| const float freq_scale, | ||
| const float freq_base, | ||
| const float ext_factor, | ||
| const float attn_factor, | ||
| const rope_corr_dims corr_dims, | ||
| const float * freq_factors, | ||
| const int64_t * row_indices, | ||
| const int set_rows_stride, | ||
| cudaStream_t stream) { | ||
| GGML_ASSERT(ne0 % 2 == 0); | ||
| const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); | ||
| const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); | ||
|
|
@@ -274,13 +342,13 @@ static void rope_neox_cuda( | |
| const float theta_scale = powf(freq_base, -2.0f/n_dims); | ||
|
|
||
| if (freq_factors == nullptr) { | ||
| rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>( | ||
| x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, | ||
| attn_factor, corr_dims, theta_scale, freq_factors); | ||
| rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>( | ||
| x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, | ||
| freq_factors, row_indices, set_rows_stride); | ||
| } else { | ||
| rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>( | ||
| x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, | ||
| attn_factor, corr_dims, theta_scale, freq_factors); | ||
| rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>( | ||
| x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, | ||
| freq_factors, row_indices, set_rows_stride); | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -333,20 +401,35 @@ static void rope_vision_cuda( | |
| } | ||
|
|
||
| template <bool forward> | ||
| void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
| void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, | ||
| ggml_tensor * dst, | ||
| const ggml_tensor * set_rows = nullptr) { | ||
| const ggml_tensor * src0 = dst->src[0]; | ||
| const ggml_tensor * src1 = dst->src[1]; | ||
| const ggml_tensor * src2 = dst->src[2]; | ||
|
|
||
| const float * src0_d = (const float *)src0->data; | ||
| const float * src1_d = (const float *)src1->data; | ||
|
|
||
| float * dst_d = (float *)dst->data; | ||
| void * dst_d = dst->data; | ||
| const int64_t * row_indices = nullptr; | ||
| ggml_type dst_type = dst->type; | ||
| int set_rows_stride = 0; | ||
|
|
||
| if (set_rows != nullptr) { | ||
| GGML_ASSERT(forward); | ||
| dst_d = set_rows->data; | ||
| row_indices = (const int64_t *) set_rows->src[1]->data; | ||
| dst_type = set_rows->type; | ||
| set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type); | ||
| } | ||
| cudaStream_t stream = ctx.stream(); | ||
|
|
||
| GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); | ||
| GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); | ||
| GGML_ASSERT(src0->type == dst->type); | ||
| // When not fused, src0 and dst types must match | ||
| // When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and dst may be F16 | ||
| GGML_ASSERT(src0->type == dst->type || dst->type == GGML_TYPE_F16); | ||
am17an marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| const int64_t ne00 = src0->ne[0]; // head dims | ||
| const int64_t ne01 = src0->ne[1]; // num heads | ||
|
|
@@ -404,14 +487,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) | |
|
|
||
| // compute | ||
| if (is_neox) { | ||
| if (src0->type == GGML_TYPE_F32) { | ||
| rope_neox_cuda<forward>( | ||
| (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, | ||
| freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); | ||
| } else if (src0->type == GGML_TYPE_F16) { | ||
| rope_neox_cuda<forward>( | ||
| (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, | ||
| freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); | ||
| if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { | ||
| rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, | ||
| nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, | ||
| freq_factors, row_indices, set_rows_stride, stream); | ||
| } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { | ||
| rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, | ||
| nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, | ||
| freq_factors, row_indices, set_rows_stride, stream); | ||
| } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { | ||
| rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, | ||
| pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, | ||
| freq_factors, row_indices, set_rows_stride, stream); | ||
| } else { | ||
| GGML_ABORT("fatal error"); | ||
| } | ||
|
|
@@ -440,14 +527,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) | |
| GGML_ABORT("fatal error"); | ||
| } | ||
| } else { | ||
| if (src0->type == GGML_TYPE_F32) { | ||
| rope_norm_cuda<forward>( | ||
| (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, | ||
| freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); | ||
| } else if (src0->type == GGML_TYPE_F16) { | ||
| rope_norm_cuda<forward>( | ||
| (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale, | ||
| freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream); | ||
| if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { | ||
| rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, | ||
| nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, | ||
| freq_factors, row_indices, set_rows_stride, stream); | ||
| } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { | ||
| rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, | ||
| nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, | ||
| freq_factors, row_indices, set_rows_stride, stream); | ||
| } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { | ||
| rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, | ||
| pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, | ||
| freq_factors, row_indices, set_rows_stride, stream); | ||
| } else { | ||
| GGML_ABORT("fatal error"); | ||
| } | ||
|
|
@@ -461,3 +552,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | |
| void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||
| ggml_cuda_op_rope_impl<false>(ctx, dst); | ||
| } | ||
|
|
||
| void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) { | ||
| ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows); | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would suggest you use
`ggml_cuda_cast` defined in `convert.cuh`. Otherwise there will potentially be issues with FP16 <-> BF16 conversions.