diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 61a8f1df87de1..b53705ffb7933 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2964,6 +2964,36 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
 }
 #endif
 
+static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
+                                                const ggml_tensor * view,
+                                                const ggml_tensor * set_rows) {
+    // ne3 not tested
+    if (rope->src[0]->ne[3] != 1) {
+        return false;
+    }
+
+    if (set_rows->type != GGML_TYPE_F32 && set_rows->type != GGML_TYPE_F16) {
+        return false;
+    }
+
+    if (set_rows->src[1]->type != GGML_TYPE_I64) {
+        return false;
+    }
+
+    // The view should flatten two dims of rope into one dim
+    if (!ggml_is_contiguous(view) || view->ne[0] != rope->ne[0] * rope->ne[1]) {
+        return false;
+    }
+
+    // Only norm/neox shaders have the fusion code
+    const int mode = ((const int32_t *) rope->op_params)[2];
+    if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX) {
+        return false;
+    }
+
+    return true;
+}
+
 static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) {
 #ifndef NDEBUG
     const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
@@ -3039,6 +3069,16 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         }
     }
 
+    if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
+        const ggml_tensor * rope     = cgraph->nodes[node_idx];
+        const ggml_tensor * view     = cgraph->nodes[node_idx + 1];
+        const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];
+
+        if (ggml_cuda_should_fuse_rope_set_rows(rope, view, set_rows)) {
+            return true;
+        }
+    }
+
     if (!ggml_can_fuse(cgraph, node_idx, ops)) {
         return false;
     }
@@ -3170,6 +3210,15 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 continue;
             }
 
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
+                ggml_tensor * rope     = cgraph->nodes[i];
+                ggml_tensor * set_rows = cgraph->nodes[i + 2];
+
+                ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows);
+                i += 2;
+                continue;
+            }
+
             if (node->op == GGML_OP_MUL) {
                 int current_node = i + 1;
                 int num_views    = 0;
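The pattern this check targets is the attention K path, where freshly roped K rows are scattered directly into the KV cache. Below is a rough sketch of that ROPE -> VIEW -> SET_ROWS chain built with the public ggml API; the shapes, rope parameters and variable names are illustrative only and the snippet is not part of this PR.

```cpp
// Illustrative only: build the ROPE -> VIEW -> SET_ROWS chain that
// ggml_cuda_should_fuse_rope_set_rows() is meant to recognize.
#include "ggml.h"

int main() {
    struct ggml_init_params ip = { /*.mem_size   =*/ 64*1024*1024,
                                   /*.mem_buffer =*/ NULL,
                                   /*.no_alloc   =*/ false };
    struct ggml_context * ctx = ggml_init(ip);

    const int64_t head_dim = 128, n_head = 8, n_tokens = 4, cache_size = 512;

    // current K projection, token positions, and the destination cache row of each token
    struct ggml_tensor * k_cur = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_dim, n_head, n_tokens);
    struct ggml_tensor * pos   = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);
    struct ggml_tensor * idxs  = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens);
    struct ggml_tensor * cache = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, head_dim*n_head, cache_size);

    // ROPE (neox), a contiguous VIEW flattening [head_dim, n_head] into one row per token,
    // then SET_ROWS scattering token t into cache row idxs[t]
    struct ggml_tensor * k_rope = ggml_rope_ext(ctx, k_cur, pos, NULL, (int) head_dim, GGML_ROPE_TYPE_NEOX,
                                                0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f);
    struct ggml_tensor * k_flat = ggml_view_2d(ctx, k_rope, head_dim*n_head, n_tokens,
                                               head_dim*n_head*ggml_element_size(k_rope), 0);
    struct ggml_tensor * out    = ggml_set_rows(ctx, cache, k_flat, idxs);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);  // nodes end up as ROPE, VIEW, SET_ROWS back to back

    ggml_free(ctx);
    return 0;
}
```

The VIEW is the key eligibility condition: it must be contiguous and its ne[0] must equal rope->ne[0]*rope->ne[1], i.e. the per-token head block is flattened into a single cache row.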
diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu
index 78ed7f519abb9..ab14cc808a26d 100644
--- a/ggml/src/ggml-cuda/rope.cu
+++ b/ggml/src/ggml-cuda/rope.cu
@@ -1,3 +1,5 @@
+#include "ggml-cuda/common.cuh"
+#include "ggml.h"
 #include "rope.cuh"
 
 struct rope_corr_dims {
@@ -37,11 +39,23 @@ static __device__ void rope_yarn(
     }
 }
 
-template<bool forward, bool has_ff, typename T>
-static __global__ void rope_norm(
-    const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-    const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
+template <bool forward, bool has_ff, typename T, typename D>
+static __global__ void rope_norm(const T * x,
+                                 D * dst,
+                                 const int ne0,
+                                 const int ne1,
+                                 const int s1,
+                                 const int s2,
+                                 const int n_dims,
+                                 const int32_t * pos,
+                                 const float freq_scale,
+                                 const float ext_factor,
+                                 const float attn_factor,
+                                 const rope_corr_dims corr_dims,
+                                 const float theta_scale,
+                                 const float * freq_factors,
+                                 const int64_t * row_indices,
+                                 const int set_rows_stride) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
@@ -53,12 +67,19 @@ static __global__ void rope_norm(
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;
 
-    const int idst = row_dst*ne0 + i0;
+    int idst = row_dst * ne0 + i0;
     const int ix = channel_x*s2 + row_x*s1 + i0;
 
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
+    if (set_rows_stride != 0) {
+        idst  = row_x * ne0 + i0;
+        idst += row_indices[channel_x] * set_rows_stride;
+    }
+
     if (i0 >= n_dims) {
-        dst[idst + 0] = x[ix + 0];
-        dst[idst + 1] = x[ix + 1];
+        dst[idst + 0] = D(x[ix + 0]);
+        dst[idst + 1] = D(x[ix + 1]);
         return;
     }
@@ -75,15 +96,27 @@ static __global__ void rope_norm(
     const float x0 = x[ix + 0];
     const float x1 = x[ix + 1];
 
-    dst[idst + 0] = x0*cos_theta - x1*sin_theta;
-    dst[idst + 1] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0] = D(x0 * cos_theta - x1 * sin_theta);
+    dst[idst + 1] = D(x0 * sin_theta + x1 * cos_theta);
 }
 
-template<bool forward, bool has_ff, typename T>
-static __global__ void rope_neox(
-    const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims,
-    const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors) {
+template <bool forward, bool has_ff, typename T, typename D>
+static __global__ void rope_neox(const T * x,
+                                 D * dst,
+                                 const int ne0,
+                                 const int ne1,
+                                 const int s1,
+                                 const int s2,
+                                 const int n_dims,
+                                 const int32_t * pos,
+                                 const float freq_scale,
+                                 const float ext_factor,
+                                 const float attn_factor,
+                                 const rope_corr_dims corr_dims,
+                                 const float theta_scale,
+                                 const float * freq_factors,
+                                 const int64_t * row_indices,
+                                 const int set_rows_stride) {
     const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
 
     if (i0 >= ne0) {
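For reference, the destination index used by the kernels above can be restated on the host as follows (illustrative helper, not PR code); rope_neox uses the same formula with i0/2 in place of i0.

```cpp
#include <cstdint>
#include <cstdio>

// unfused: dst is the rope output itself, laid out as [ne0, ne1, ne2]
static int64_t idst_unfused(int64_t i0, int64_t row_x, int64_t channel_x, int64_t ne0, int64_t ne1) {
    const int64_t row_dst = channel_x * ne1 + row_x;  // same flattening as row_dst in the kernel
    return row_dst * ne0 + i0;
}

// fused: dst is the SET_ROWS target; token channel_x lands in cache row row_indices[channel_x],
// and its [ne1, ne0] head block is flattened inside that row
static int64_t idst_fused(int64_t i0, int64_t row_x, int64_t channel_x, int64_t ne0,
                          const int64_t * row_indices, int64_t set_rows_stride) {
    return row_indices[channel_x] * set_rows_stride + row_x * ne0 + i0;
}

int main() {
    const int64_t ne0 = 4, ne1 = 2;             // head size 4, 2 heads
    const int64_t row_indices[] = { 7, 3 };     // cache rows for tokens 0 and 1
    const int64_t set_rows_stride = ne0 * ne1;  // contiguous cache rows of ne0*ne1 elements
    // element i0=2 of head row_x=1 for token channel_x=1:
    printf("unfused idst = %lld\n", (long long) idst_unfused(2, 1, 1, ne0, ne1));                         // 14
    printf("fused   idst = %lld\n", (long long) idst_fused(2, 1, 1, ne0, row_indices, set_rows_stride));  // 30
}
```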
@@ -95,12 +128,19 @@ static __global__ void rope_neox(
     const int row_x     = row_dst % ne1;
     const int channel_x = row_dst / ne1;
 
-    const int idst = row_dst*ne0 + i0/2;
+    int idst = row_dst * ne0 + i0 / 2;
     const int ix = channel_x*s2 + row_x*s1 + i0/2;
 
+    // Fusion optimization: ROPE + VIEW + SET_ROWS.
+    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
+    if (set_rows_stride != 0) {
+        idst  = row_x * ne0 + i0 / 2;
+        idst += row_indices[channel_x] * set_rows_stride;
+    }
+
     if (i0 >= n_dims) {
-        dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
-        dst[idst + i0/2 + 1] = x[ix + i0/2 + 1];
+        dst[idst + i0 / 2 + 0] = D(x[ix + i0 / 2 + 0]);
+        dst[idst + i0 / 2 + 1] = D(x[ix + i0 / 2 + 1]);
         return;
     }
@@ -117,8 +157,8 @@ static __global__ void rope_neox(
     const float x0 = x[ix + 0];
     const float x1 = x[ix + n_dims/2];
 
-    dst[idst + 0]        = x0*cos_theta - x1*sin_theta;
-    dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
+    dst[idst + 0]          = D(x0 * cos_theta - x1 * sin_theta);
+    dst[idst + n_dims / 2] = D(x0 * sin_theta + x1 * cos_theta);
 }
 
 template <bool forward, bool has_ff, typename T>
@@ -238,11 +278,25 @@ static __global__ void rope_vision(
     dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
 }
 
-template<bool forward, typename T>
-static void rope_norm_cuda(
-    const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
-    const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+template <bool forward, typename T, typename D>
+static void rope_norm_cuda(const T * x,
+                           D * dst,
+                           const int ne0,
+                           const int ne1,
+                           const int s1,
+                           const int s2,
+                           const int n_dims,
+                           const int nr,
+                           const int32_t * pos,
+                           const float freq_scale,
+                           const float freq_base,
+                           const float ext_factor,
+                           const float attn_factor,
+                           const rope_corr_dims corr_dims,
+                           const float * freq_factors,
+                           const int64_t * row_indices,
+                           const int set_rows_stride,
+                           cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -252,20 +306,34 @@ static void rope_norm_cuda(
 
     if (freq_factors == nullptr) {
         rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-                attn_factor, corr_dims, theta_scale, freq_factors);
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     } else {
         rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-                attn_factor, corr_dims, theta_scale, freq_factors);
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     }
 }
 
-template<bool forward, typename T>
-static void rope_neox_cuda(
-    const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, const int nr,
-    const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
-    const rope_corr_dims corr_dims, const float * freq_factors, cudaStream_t stream) {
+template <bool forward, typename T, typename D>
+static void rope_neox_cuda(const T * x,
+                           D * dst,
+                           const int ne0,
+                           const int ne1,
+                           const int s1,
+                           const int s2,
+                           const int n_dims,
+                           const int nr,
+                           const int32_t * pos,
+                           const float freq_scale,
+                           const float freq_base,
+                           const float ext_factor,
+                           const float attn_factor,
+                           const rope_corr_dims corr_dims,
+                           const float * freq_factors,
+                           const int64_t * row_indices,
+                           const int set_rows_stride,
+                           cudaStream_t stream) {
     GGML_ASSERT(ne0 % 2 == 0);
     const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
     const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -274,13 +342,13 @@ static void rope_neox_cuda(
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
 
     if (freq_factors == nullptr) {
-        rope_neox<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-                attn_factor, corr_dims, theta_scale, freq_factors);
+        rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     } else {
-        rope_neox<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
-                x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor,
-                attn_factor, corr_dims, theta_scale, freq_factors);
+        rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
+            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
+            freq_factors, row_indices, set_rows_stride);
     }
 }
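A worked example of the launch geometry used by these launchers, assuming CUDA_ROPE_BLOCK_SIZE is 256 as defined in rope.cuh (the numbers are illustrative):

```cpp
#include <cstdio>

int main() {
    const int CUDA_ROPE_BLOCK_SIZE = 256;
    const int ne0 = 128;        // head size: elements per row
    const int nr  = 8 * 4;      // rows = n_heads * n_tokens
    // each thread rotates one (x0, x1) pair, so a block of 256 threads covers 512 elements of a row
    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
    printf("block = (1, %d, 1), grid = (%d, %d, 1)\n", CUDA_ROPE_BLOCK_SIZE, nr, n_blocks_x);
    // -> block = (1, 256, 1), grid = (32, 1, 1); threads with i0 >= ne0 simply return
}
```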
@@ -333,7 +401,9 @@ static void rope_vision_cuda(
 }
 
 template <bool forward>
-void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
+                            ggml_tensor * dst,
+                            const ggml_tensor * set_rows = nullptr) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * src2 = dst->src[2];
@@ -341,12 +411,25 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const float * src0_d = (const float *)src0->data;
     const float * src1_d = (const float *)src1->data;
-    float * dst_d = (float *)dst->data;
+    void * dst_d = dst->data;
+    const int64_t * row_indices = nullptr;
+    ggml_type dst_type = dst->type;
+    int set_rows_stride = 0;
+
+    if (set_rows != nullptr) {
+        GGML_ASSERT(forward);
+        dst_d           = set_rows->data;
+        row_indices     = (const int64_t *) set_rows->src[1]->data;
+        dst_type        = set_rows->type;
+        set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
+    }
 
     cudaStream_t stream = ctx.stream();
 
     GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
     GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
-    GGML_ASSERT(src0->type == dst->type);
+    // When not fused, src0 and dst types must match
+    // When fused (ROPE+VIEW+SET_ROWS), src0 may be F32 and dst may be F16
+    GGML_ASSERT(src0->type == dst->type || (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
 
     const int64_t ne00 = src0->ne[0]; // head dims
     const int64_t ne01 = src0->ne[1]; // num heads
@@ -404,14 +487,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 
     // compute
     if (is_neox) {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_neox_cuda<forward>(
-                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_neox_cuda<forward>(
-                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
+            rope_neox_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                    nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                    freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
+            rope_neox_cuda<forward>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                    nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                    freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
+            rope_neox_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
+                                    pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                    freq_factors, row_indices, set_rows_stride, stream);
         } else {
             GGML_ABORT("fatal error");
         }
@@ -440,14 +527,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
             GGML_ABORT("fatal error");
         }
     } else {
-        if (src0->type == GGML_TYPE_F32) {
-            rope_norm_cuda<forward>(
-                (const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
-        } else if (src0->type == GGML_TYPE_F16) {
-            rope_norm_cuda<forward>(
-                (const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, pos, freq_scale,
-                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, stream);
+        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
+            rope_norm_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                    nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                    freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
+            rope_norm_cuda<forward>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
+                                    nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                    freq_factors, row_indices, set_rows_stride, stream);
+        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
+            rope_norm_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
+                                    pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
+                                    freq_factors, row_indices, set_rows_stride, stream);
         } else {
             GGML_ABORT("fatal error");
         }
@@ -461,3 +552,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_rope_impl<false>(ctx, dst);
 }
+
+void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
+    ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);
+}
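To make the set_rows_stride computation above concrete, here is the arithmetic for a hypothetical F16 K cache whose rows hold 1024 values (illustrative numbers, not taken from the PR):

```cpp
#include <cstdio>
#include <cstdint>

int main() {
    // hypothetical F16 K cache: each row holds head_dim * n_head = 1024 values
    const int64_t row_elems = 1024;
    const int64_t type_size = 2;                       // ggml_type_size(GGML_TYPE_F16)
    const int64_t nb1       = row_elems * type_size;   // set_rows->nb[1] for a contiguous cache
    const int     stride    = (int) (nb1 / type_size); // = set_rows_stride, in destination elements

    // a token scattered to cache row 37 starts at element offset 37 * stride;
    // each rotated value is converted on store via D(...), here float -> half
    printf("set_rows_stride = %d, row 37 starts at element %d\n", stride, 37 * stride);
}
```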
diff --git a/ggml/src/ggml-cuda/rope.cuh b/ggml/src/ggml-cuda/rope.cuh
index 9139f3b220df7..72af086cd1b42 100644
--- a/ggml/src/ggml-cuda/rope.cuh
+++ b/ggml/src/ggml-cuda/rope.cuh
@@ -5,3 +5,5 @@
 void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index f9751b3183694..90e4feadbf5b7 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -1592,9 +1592,10 @@ ggml_tensor * llm_graph_context::build_attn(
                   int   il) const {
     // these nodes are added to the graph together so that they are not reordered
     // by doing so, the number of splits in the graph is reduced
+    // expand k later to enable rope fusion which directly writes into k-v cache
     ggml_build_forward_expand(gf, q_cur);
-    ggml_build_forward_expand(gf, k_cur);
     ggml_build_forward_expand(gf, v_cur);
+    ggml_build_forward_expand(gf, k_cur);
 
     const auto * mctx_cur = inp->mctx;
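The reordering above matters because the CUDA backend (see evaluate_and_capture_cuda_graph earlier in this diff) only tests a window of consecutive graph nodes for the {ROPE, VIEW, SET_ROWS} pattern; expanding k_cur last keeps its chain contiguous with the cache write. A toy illustration of that window check (not llama.cpp code):

```cpp
#include <cstdio>
#include <string>
#include <vector>

static bool has_fusable_window(const std::vector<std::string> & nodes) {
    for (size_t i = 0; i + 2 < nodes.size(); ++i) {
        if (nodes[i] == "ROPE" && nodes[i + 1] == "VIEW" && nodes[i + 2] == "SET_ROWS") {
            return true;
        }
    }
    return false;
}

int main() {
    // k expanded early: another node ends up between ROPE(k) and its VIEW/SET_ROWS
    std::vector<std::string> before = { "ROPE", "MUL_MAT", "VIEW", "SET_ROWS" };
    // k expanded last: the chain stays contiguous and the window check succeeds
    std::vector<std::string> after  = { "MUL_MAT", "ROPE", "VIEW", "SET_ROWS" };
    printf("before: %s, after: %s\n",
           has_fusable_window(before) ? "fused" : "not fused",
           has_fusable_window(after)  ? "fused" : "not fused");
}
```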