From 4651f5e2f29b32e24c69c511d0bacb14d29e6008 Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Sat, 2 Nov 2024 01:45:27 +1100 Subject: [PATCH 1/6] rwkv_wkv6 vulkan shader --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 165 +++++++++++++++++- .../ggml-vulkan/vulkan-shaders/rwkv_wkv6.comp | 96 ++++++++++ 2 files changed, 260 insertions(+), 1 deletion(-) create mode 100644 ggml/src/ggml-vulkan/vulkan-shaders/rwkv_wkv6.comp diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a8ae58ee2ce85..e103e67f76abf 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -240,6 +240,7 @@ struct vk_device_struct { vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_pool2d_f32; + vk_pipeline pipeline_rwkv_wkv6_f32; // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned} vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2]; @@ -523,6 +524,15 @@ struct vk_op_pool2d_push_constants { int32_t p0; int32_t p1; }; + +struct vk_op_rwkv_wkv6_push_constants { + uint32_t B; // Batch size (原n_seqs) + uint32_t T; // Sequence length + uint32_t C; // Total channels + uint32_t H; // Number of heads (原HEADS) +}; + + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -1942,6 +1952,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline( + device, + device->pipeline_rwkv_wkv6_f32, + "rwkv_wkv6_f32", + rwkv_wkv6_f32_len, + rwkv_wkv6_f32_data, + "main", + 7, + sizeof(vk_op_rwkv_wkv6_push_constants), + {64, 1, 1}, // work group + {device->subgroup_size}, + 1 + ); + for (auto &c : compiles) { c.wait(); } @@ -4917,6 +4941,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_pool2d_f32; } return nullptr; + case GGML_OP_RWKV_WKV6: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rwkv_wkv6_f32; + } + return nullptr; case GGML_OP_LEAKY_RELU: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_leaky_relu_f32; @@ -5319,6 +5348,127 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } + + +template +static void ggml_vk_op_f32_rwkv6( + ggml_backend_vk_context * ctx, + vk_context& subctx, + ggml_tensor * dst, + const PC&& pc, + bool dryrun = false) { + + // Get source tensors + const ggml_tensor * k = dst->src[0]; // keys + const ggml_tensor * v = dst->src[1]; // values + const ggml_tensor * r = dst->src[2]; // reset gates + const ggml_tensor * tf = dst->src[3]; // time first + const ggml_tensor * td = dst->src[4]; // time decay + const ggml_tensor * state = dst->src[5]; // states + + VK_LOG_DEBUG("ggml_vk_op_f32_rwkv6(" << k << ", " << v << ", " << r << ", " + << tf << ", " << td << ", " << state << ", " << dst << ")"); + + // Verify input types + GGML_ASSERT(!ggml_is_quantized(k->type)); + GGML_ASSERT(!ggml_is_quantized(v->type)); + GGML_ASSERT(!ggml_is_quantized(r->type)); + GGML_ASSERT(!ggml_is_quantized(tf->type)); + GGML_ASSERT(!ggml_is_quantized(td->type)); + GGML_ASSERT(!ggml_is_quantized(state->type)); + GGML_ASSERT(dst->buffer != nullptr); + + // Get pipeline + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1); + return; + } + + // Get buffer contexts + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; + ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; + ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context; + ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context; + ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context; + ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context; + + // Get device buffers + vk_buffer d_D = dst_buf_ctx->dev_buffer; + vk_buffer d_K = k_buf_ctx->dev_buffer; + vk_buffer d_V = v_buf_ctx->dev_buffer; + vk_buffer d_R = r_buf_ctx->dev_buffer; + vk_buffer d_TF = tf_buf_ctx->dev_buffer; + vk_buffer d_TD = td_buf_ctx->dev_buffer; + vk_buffer d_State = state_buf_ctx->dev_buffer; + + // Calculate buffer offsets + const uint64_t k_offset = vk_tensor_offset(k); + const uint64_t v_offset = vk_tensor_offset(v); + const uint64_t r_offset = vk_tensor_offset(r); + const uint64_t tf_offset = vk_tensor_offset(tf); + const uint64_t td_offset = vk_tensor_offset(td); + const uint64_t state_offset = vk_tensor_offset(state); + const uint64_t dst_offset = vk_tensor_offset(dst); + + // Calculate buffer sizes + const uint64_t k_size = ggml_nbytes(k); + const uint64_t v_size = ggml_nbytes(v); + const uint64_t r_size = ggml_nbytes(r); + const uint64_t tf_size = ggml_nbytes(tf); + const uint64_t td_size = ggml_nbytes(td); + const uint64_t state_size = ggml_nbytes(state); + const uint64_t dst_size = ggml_nbytes(dst); + + // Set work elements based on tensor dimensions + std::array elements = { + (uint32_t)(pc.B*pc.H), // B * H workgroups + 1, // 每个workgroup 64个线程 + 1 + }; + + // Synchronize buffers and dispatch compute pipeline + ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_K, k_offset, k_size }, + vk_subbuffer{ d_V, v_offset, v_size }, + vk_subbuffer{ d_R, r_offset, r_size }, + vk_subbuffer{ d_TF, tf_offset, tf_size }, + vk_subbuffer{ d_TD, td_offset, td_size }, + vk_subbuffer{ d_State, state_offset, state_size }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, sizeof(PC), &pc, elements); +} + +static void ggml_vk_rwkv_wkv6( + ggml_backend_vk_context * ctx, + vk_context& subctx, + ggml_tensor * dst, + bool dryrun = false) { + + // Extract dimensions from tensors + const size_t T = dst->src[0]->ne[3]; // Sequence length + const size_t C = dst->ne[0]; // Channel dimension + const size_t HEADS = dst->src[0]->ne[2]; // Number of heads + const size_t n_seqs = dst->src[5]->ne[1]; // Batch size + + // Call implementation with push constants + ggml_vk_op_f32_rwkv6( + ctx, subctx, dst, + { + (uint32_t)n_seqs, // B + (uint32_t)T, // T + (uint32_t)C, // C + (uint32_t)HEADS, // H + }, + dryrun + ); +} + + static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { int * op_params = (int *)dst->op_params; @@ -6464,6 +6614,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: case GGML_OP_FLASH_ATTN_EXT: break; @@ -6663,6 +6814,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod case GGML_OP_FLASH_ATTN_EXT: ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun); + break; + + case GGML_OP_RWKV_WKV6: + ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun); + break; default: return false; @@ -6743,6 +6899,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: case GGML_OP_REPEAT: buf = tensor->buffer; @@ -7610,6 +7767,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_IM2COL: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: case GGML_OP_LEAKY_RELU: return true; default: @@ -8186,7 +8344,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_LEAKY_RELU) { const float * op_params = (const float *)tensor->op_params; tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false); - } else { + } + // else if (tensor->op == GGML_OP_RWKV_WKV6) { + // tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], + // tensor->src[4], tensor->src[5]); + // } + else { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rwkv_wkv6.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rwkv_wkv6.comp new file mode 100644 index 0000000000000..6465f2da92449 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rwkv_wkv6.comp @@ -0,0 +1,96 @@ +#version 450 + + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; // Batch size + uint T; // Sequence length + uint C; // Total number of channels + uint H; // Number of heads +}; + +layout(set = 0, binding = 0) readonly buffer KBuf { float k[]; }; +layout(set = 0, binding = 1) readonly buffer VBuf { float v[]; }; +layout(set = 0, binding = 2) readonly buffer RBuf { float r[]; }; +layout(set = 0, binding = 3) readonly buffer TimeFBuf { float tf[]; }; +layout(set = 0, binding = 4) readonly buffer TimeDBuf { float td[]; }; +layout(set = 0, binding = 5) readonly buffer StateBuf { float state_in[]; }; +layout(set = 0, binding = 6) buffer DstBuf { float dst[]; }; + +shared float _k[64], _r[64], _tf[64], _td[64]; + +void main() { + const uint head_size = 64; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (tid >= head_size || batch_id >= B || head_id >= H) { + return; + } + + // Load state + float state[64]; // Use fixed size matching head_size + for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid]; + } + + _k[tid] = 0.0; + _r[tid] = 0.0; + _td[tid] = 0.0; + barrier(); + _tf[tid] = tf[head_id * head_size + tid]; + barrier(); + + + // Main loop + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + barrier(); + + const float v_val = v[t]; + float y = 0.0; + + for (uint j = 0; j < head_size; j += 4) { + // Load values in blocks of 4 + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + // Compute kv products + vec4 kv = k_vec * v_val; + + // Accumulate results + vec4 temp = tf_vec * kv + s_vec; + y += dot(r_vec, temp); + + // Update state + s_vec = s_vec * td_vec + kv; + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + // Write back state + for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid] = state[i]; + } +} \ No newline at end of file From 77fe4fd982776afa3909e53857d502e557f8416a Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Fri, 13 Dec 2024 17:19:23 +0800 Subject: [PATCH 2/6] RWKV_WKV6 Vulkan op tests passed Signed-off-by: Molly Sophia --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 9 ++++----- .../ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp | 2 ++ .../vulkan-shaders/{rwkv_wkv6.comp => wkv6.comp} | 0 3 files changed, 6 insertions(+), 5 deletions(-) rename ggml/src/ggml-vulkan/vulkan-shaders/{rwkv_wkv6.comp => wkv6.comp} (100%) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index e103e67f76abf..da11e88cde17c 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1961,7 +1961,7 @@ static void ggml_vk_load_shaders(vk_device& device) { "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), - {64, 1, 1}, // work group + {1, 1, 1}, // work group {device->subgroup_size}, 1 ); @@ -8344,11 +8344,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_LEAKY_RELU) { const float * op_params = (const float *)tensor->op_params; tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false); + } else if (tensor->op == GGML_OP_RWKV_WKV6) { + tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], + tensor->src[4], tensor->src[5]); } - // else if (tensor->op == GGML_OP_RWKV_WKV6) { - // tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], - // tensor->src[4], tensor->src[5]); - // } else { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index c48a228aef65d..eff60f3c3badb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -479,6 +479,8 @@ void process_shaders() { string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"C_TYPE", "float"}, {"D_TYPE", "float"}, {"E_TYPE", "float"}, {"F_TYPE", "float"}, {"S_TYPE", "float"}})); + for (auto &c : compiles) { c.wait(); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rwkv_wkv6.comp b/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp similarity index 100% rename from ggml/src/ggml-vulkan/vulkan-shaders/rwkv_wkv6.comp rename to ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp From 60bbd4ebf174986fb1b918310ab65074847d79ed Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Fri, 13 Dec 2024 17:43:08 +0800 Subject: [PATCH 3/6] Apply code format changes Signed-off-by: Molly Sophia --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 103 +++++------------- .../vulkan-shaders/vulkan-shaders-gen.cpp | 2 +- ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp | 75 ++++++------- 3 files changed, 63 insertions(+), 117 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index da11e88cde17c..4c0fb4d46dd24 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -524,15 +524,13 @@ struct vk_op_pool2d_push_constants { int32_t p0; int32_t p1; }; - struct vk_op_rwkv_wkv6_push_constants { - uint32_t B; // Batch size (原n_seqs) - uint32_t T; // Sequence length - uint32_t C; // Total channels - uint32_t H; // Number of heads (原HEADS) + uint32_t B; + uint32_t T; + uint32_t C; + uint32_t H; }; - // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -1952,19 +1950,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline( - device, - device->pipeline_rwkv_wkv6_f32, - "rwkv_wkv6_f32", - rwkv_wkv6_f32_len, - rwkv_wkv6_f32_data, - "main", - 7, - sizeof(vk_op_rwkv_wkv6_push_constants), - {1, 1, 1}, // work group - {device->subgroup_size}, - 1 - ); + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); for (auto &c : compiles) { c.wait(); @@ -5348,28 +5334,14 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const }, dryrun); } +static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) { + const ggml_tensor * k = dst->src[0]; + const ggml_tensor * v = dst->src[1]; + const ggml_tensor * r = dst->src[2]; + const ggml_tensor * tf = dst->src[3]; + const ggml_tensor * td = dst->src[4]; + const ggml_tensor * state = dst->src[5]; - -template -static void ggml_vk_op_f32_rwkv6( - ggml_backend_vk_context * ctx, - vk_context& subctx, - ggml_tensor * dst, - const PC&& pc, - bool dryrun = false) { - - // Get source tensors - const ggml_tensor * k = dst->src[0]; // keys - const ggml_tensor * v = dst->src[1]; // values - const ggml_tensor * r = dst->src[2]; // reset gates - const ggml_tensor * tf = dst->src[3]; // time first - const ggml_tensor * td = dst->src[4]; // time decay - const ggml_tensor * state = dst->src[5]; // states - - VK_LOG_DEBUG("ggml_vk_op_f32_rwkv6(" << k << ", " << v << ", " << r << ", " - << tf << ", " << td << ", " << state << ", " << dst << ")"); - - // Verify input types GGML_ASSERT(!ggml_is_quantized(k->type)); GGML_ASSERT(!ggml_is_quantized(v->type)); GGML_ASSERT(!ggml_is_quantized(r->type)); @@ -5378,7 +5350,6 @@ static void ggml_vk_op_f32_rwkv6( GGML_ASSERT(!ggml_is_quantized(state->type)); GGML_ASSERT(dst->buffer != nullptr); - // Get pipeline vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6); GGML_ASSERT(pipeline != nullptr); @@ -5387,7 +5358,6 @@ static void ggml_vk_op_f32_rwkv6( return; } - // Get buffer contexts ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; @@ -5396,7 +5366,6 @@ static void ggml_vk_op_f32_rwkv6( ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context; ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context; - // Get device buffers vk_buffer d_D = dst_buf_ctx->dev_buffer; vk_buffer d_K = k_buf_ctx->dev_buffer; vk_buffer d_V = v_buf_ctx->dev_buffer; @@ -5405,7 +5374,6 @@ static void ggml_vk_op_f32_rwkv6( vk_buffer d_TD = td_buf_ctx->dev_buffer; vk_buffer d_State = state_buf_ctx->dev_buffer; - // Calculate buffer offsets const uint64_t k_offset = vk_tensor_offset(k); const uint64_t v_offset = vk_tensor_offset(v); const uint64_t r_offset = vk_tensor_offset(r); @@ -5414,7 +5382,6 @@ static void ggml_vk_op_f32_rwkv6( const uint64_t state_offset = vk_tensor_offset(state); const uint64_t dst_offset = vk_tensor_offset(dst); - // Calculate buffer sizes const uint64_t k_size = ggml_nbytes(k); const uint64_t v_size = ggml_nbytes(v); const uint64_t r_size = ggml_nbytes(r); @@ -5423,14 +5390,12 @@ static void ggml_vk_op_f32_rwkv6( const uint64_t state_size = ggml_nbytes(state); const uint64_t dst_size = ggml_nbytes(dst); - // Set work elements based on tensor dimensions std::array elements = { - (uint32_t)(pc.B*pc.H), // B * H workgroups - 1, // 每个workgroup 64个线程 + (uint32_t)(pc.B * pc.H), + 1, 1 }; - // Synchronize buffers and dispatch compute pipeline ggml_vk_sync_buffers(subctx); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_K, k_offset, k_size }, @@ -5440,35 +5405,27 @@ static void ggml_vk_op_f32_rwkv6( vk_subbuffer{ d_TD, td_offset, td_size }, vk_subbuffer{ d_State, state_offset, state_size }, vk_subbuffer{ d_D, dst_offset, dst_size } - }, sizeof(PC), &pc, elements); -} - -static void ggml_vk_rwkv_wkv6( - ggml_backend_vk_context * ctx, - vk_context& subctx, - ggml_tensor * dst, - bool dryrun = false) { - - // Extract dimensions from tensors - const size_t T = dst->src[0]->ne[3]; // Sequence length - const size_t C = dst->ne[0]; // Channel dimension - const size_t HEADS = dst->src[0]->ne[2]; // Number of heads - const size_t n_seqs = dst->src[5]->ne[1]; // Batch size - - // Call implementation with push constants - ggml_vk_op_f32_rwkv6( + }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements); +} + +static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t seq_length = dst->src[0]->ne[3]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[2]; + const size_t n_seqs = dst->src[5]->ne[1]; + + ggml_vk_op_f32_rwkv6( ctx, subctx, dst, { - (uint32_t)n_seqs, // B - (uint32_t)T, // T - (uint32_t)C, // C - (uint32_t)HEADS, // H + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, }, dryrun ); } - static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { int * op_params = (int *)dst->op_params; @@ -8344,10 +8301,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { } else if (tensor->op == GGML_OP_LEAKY_RELU) { const float * op_params = (const float *)tensor->op_params; tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false); - } else if (tensor->op == GGML_OP_RWKV_WKV6) { + } else if (tensor->op == GGML_OP_RWKV_WKV6) { tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor->src[5]); - } + } else { std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index eff60f3c3badb..7a0d7285dcb23 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -479,7 +479,7 @@ void process_shaders() { string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"C_TYPE", "float"}, {"D_TYPE", "float"}, {"E_TYPE", "float"}, {"F_TYPE", "float"}, {"S_TYPE", "float"}})); + string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); for (auto &c : compiles) { c.wait(); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp b/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp index 6465f2da92449..8beb7ff6e2763 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp @@ -1,96 +1,85 @@ #version 450 - -layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; +#define BLOCK_SIZE 64 +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; layout(push_constant) uniform Parameters { - uint B; // Batch size - uint T; // Sequence length - uint C; // Total number of channels - uint H; // Number of heads + uint B; + uint T; + uint C; + uint H; }; -layout(set = 0, binding = 0) readonly buffer KBuf { float k[]; }; -layout(set = 0, binding = 1) readonly buffer VBuf { float v[]; }; -layout(set = 0, binding = 2) readonly buffer RBuf { float r[]; }; -layout(set = 0, binding = 3) readonly buffer TimeFBuf { float tf[]; }; -layout(set = 0, binding = 4) readonly buffer TimeDBuf { float td[]; }; -layout(set = 0, binding = 5) readonly buffer StateBuf { float state_in[]; }; -layout(set = 0, binding = 6) buffer DstBuf { float dst[]; }; +layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; }; +layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; }; +layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; }; +layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; }; +layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; }; +layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; }; +layout(binding = 6) buffer DstBuf { A_TYPE dst[]; }; -shared float _k[64], _r[64], _tf[64], _td[64]; +shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE]; void main() { - const uint head_size = 64; + const uint head_size = BLOCK_SIZE; const uint batch_id = gl_WorkGroupID.x / H; const uint head_id = gl_WorkGroupID.x % H; const uint tid = gl_LocalInvocationID.x; - + const uint state_size = C * head_size; const uint n_seq_tokens = T / B; if (tid >= head_size || batch_id >= B || head_id >= H) { return; } - - // Load state - float state[64]; // Use fixed size matching head_size + + A_TYPE state[BLOCK_SIZE]; for (uint i = 0; i < head_size; i++) { - state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + i * head_size + tid]; } - - _k[tid] = 0.0; - _r[tid] = 0.0; - _td[tid] = 0.0; + barrier(); _tf[tid] = tf[head_id * head_size + tid]; barrier(); - - // Main loop const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; - + for (uint t = start_t; t < end_t; t += C) { barrier(); _k[tid] = k[t]; _r[tid] = r[t]; _td[tid] = td[t]; barrier(); - - const float v_val = v[t]; - float y = 0.0; - + + const A_TYPE v_val = v[t]; + A_TYPE y = 0.0; + for (uint j = 0; j < head_size; j += 4) { - // Load values in blocks of 4 vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]); vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); - - // Compute kv products + vec4 kv = k_vec * v_val; - - // Accumulate results + vec4 temp = tf_vec * kv + s_vec; y += dot(r_vec, temp); - - // Update state + s_vec = s_vec * td_vec + kv; state[j] = s_vec.x; state[j+1] = s_vec.y; state[j+2] = s_vec.z; state[j+3] = s_vec.w; } - + dst[t] = y; } - - // Write back state + for (uint i = 0; i < head_size; i++) { - dst[T * C + batch_id * state_size + head_id * head_size * head_size + dst[T * C + batch_id * state_size + head_id * head_size * head_size + i * head_size + tid] = state[i]; } -} \ No newline at end of file +} From 6ea605ddfcdb13e96651b08faa252e62e5c55733 Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Mon, 16 Dec 2024 17:35:44 +0800 Subject: [PATCH 4/6] add [[unroll]] and remove unnecessary conditions --- ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp b/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp index 8beb7ff6e2763..35cc6c45f90a5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp @@ -1,5 +1,7 @@ #version 450 +#extension GL_EXT_control_flow_attributes : require + #define BLOCK_SIZE 64 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; @@ -29,12 +31,12 @@ void main() { const uint state_size = C * head_size; const uint n_seq_tokens = T / B; - if (tid >= head_size || batch_id >= B || head_id >= H) { + if (batch_id >= B || head_id >= H) { return; } A_TYPE state[BLOCK_SIZE]; - for (uint i = 0; i < head_size; i++) { + [[unroll]] for (uint i = 0; i < head_size; i++) { state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + i * head_size + tid]; } @@ -56,7 +58,7 @@ void main() { const A_TYPE v_val = v[t]; A_TYPE y = 0.0; - for (uint j = 0; j < head_size; j += 4) { + [[unroll]] for (uint j = 0; j < head_size; j += 4) { vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); @@ -78,7 +80,7 @@ void main() { dst[t] = y; } - for (uint i = 0; i < head_size; i++) { + [[unroll]] for (uint i = 0; i < head_size; i++) { dst[T * C + batch_id * state_size + head_id * head_size * head_size + i * head_size + tid] = state[i]; } From 353c5f8c7b6a3583b3632d09a1e8893c5f3d2954 Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Mon, 16 Dec 2024 18:43:22 +0800 Subject: [PATCH 5/6] add uma support --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 69 +++++++++++++++++++++------- 1 file changed, 53 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 79a4c6d9a6118..3aa1fc07b9cb3 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5471,21 +5471,58 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context; ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context; - vk_buffer d_D = dst_buf_ctx->dev_buffer; - vk_buffer d_K = k_buf_ctx->dev_buffer; - vk_buffer d_V = v_buf_ctx->dev_buffer; - vk_buffer d_R = r_buf_ctx->dev_buffer; - vk_buffer d_TF = tf_buf_ctx->dev_buffer; - vk_buffer d_TD = td_buf_ctx->dev_buffer; - vk_buffer d_State = state_buf_ctx->dev_buffer; - - const uint64_t k_offset = vk_tensor_offset(k); - const uint64_t v_offset = vk_tensor_offset(v); - const uint64_t r_offset = vk_tensor_offset(r); - const uint64_t tf_offset = vk_tensor_offset(tf); - const uint64_t td_offset = vk_tensor_offset(td); - const uint64_t state_offset = vk_tensor_offset(state); - const uint64_t dst_offset = vk_tensor_offset(dst); + ggml_vk_sync_buffers(subctx); + + vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State; + uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset; + bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, k->data, d_K, k_offset); + ggml_vk_host_get(ctx->device, v->data, d_V, v_offset); + ggml_vk_host_get(ctx->device, r->data, d_R, r_offset); + ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset); + ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset); + ggml_vk_host_get(ctx->device, state->data, d_State, state_offset); + ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); + + K_uma = d_K != nullptr; + V_uma = d_V != nullptr; + R_uma = d_R != nullptr; + TF_uma = d_TF != nullptr; + TD_uma = d_TD != nullptr; + STATE_uma = d_State != nullptr; + DST_uma = d_D != nullptr; + } + + if (!K_uma) { + d_K = k_buf_ctx->dev_buffer; + k_offset = vk_tensor_offset(k) + k->view_offs; + } + if (!V_uma) { + d_V = v_buf_ctx->dev_buffer; + v_offset = vk_tensor_offset(v) + v->view_offs; + } + if (!R_uma) { + d_R = r_buf_ctx->dev_buffer; + r_offset = vk_tensor_offset(r) + r->view_offs; + } + if (!TF_uma) { + d_TF = tf_buf_ctx->dev_buffer; + tf_offset = vk_tensor_offset(tf) + tf->view_offs; + } + if (!TD_uma) { + d_TD = td_buf_ctx->dev_buffer; + td_offset = vk_tensor_offset(td) + td->view_offs; + } + if (!STATE_uma) { + d_State = state_buf_ctx->dev_buffer; + state_offset = vk_tensor_offset(state) + state->view_offs; + } + if (!DST_uma) { + d_D = dst_buf_ctx->dev_buffer; + dst_offset = vk_tensor_offset(dst) + dst->view_offs; + } const uint64_t k_size = ggml_nbytes(k); const uint64_t v_size = ggml_nbytes(v); @@ -5501,7 +5538,7 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc 1 }; - ggml_vk_sync_buffers(subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_K, k_offset, k_size }, vk_subbuffer{ d_V, v_offset, v_size }, From aa13d6990566b6cf3b1d3a2f9d976cc035dade80 Mon Sep 17 00:00:00 2001 From: Zhiyuan Li Date: Mon, 16 Dec 2024 19:34:17 +0800 Subject: [PATCH 6/6] fix erros in EditorConfig Checker --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3aa1fc07b9cb3..08944d76c5213 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -5472,7 +5472,7 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context; ggml_vk_sync_buffers(subctx); - + vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State; uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset; bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false; @@ -5538,7 +5538,6 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc 1 }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_K, k_offset, k_size }, vk_subbuffer{ d_V, v_offset, v_size },