@@ -240,6 +240,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
+    vk_pipeline pipeline_rwkv_wkv6_f32;

     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -523,6 +524,15 @@ struct vk_op_pool2d_push_constants {
     int32_t p0; int32_t p1;
 };

+
+struct vk_op_rwkv_wkv6_push_constants {
+    uint32_t B; // Batch size (originally n_seqs)
+    uint32_t T; // Sequence length
+    uint32_t C; // Total channels
+    uint32_t H; // Number of heads (originally HEADS)
+};
+
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
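For orientation: the four push-constant fields describe the WKV6 problem size handed to the shader, with C expected to equal H * head_size. A hypothetical fill (the concrete numbers below are illustrative only, not taken from this patch):

    // Illustrative values for an RWKV6 layer with 32 heads of size 64:
    // C = H * head_size = 2048 channels, B sequences of T tokens each.
    vk_op_rwkv_wkv6_push_constants pc = {
        /*B=*/ 1,    // batch size (n_seqs)
        /*T=*/ 128,  // sequence length
        /*C=*/ 2048, // total channels
        /*H=*/ 32,   // number of heads
    };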
@@ -1942,6 +1952,20 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);

+    ggml_vk_create_pipeline(
+        device,
+        device->pipeline_rwkv_wkv6_f32,
+        "rwkv_wkv6_f32",
+        rwkv_wkv6_f32_len,
+        rwkv_wkv6_f32_data,
+        "main",
+        7,
+        sizeof(vk_op_rwkv_wkv6_push_constants),
+        {64, 1, 1}, // workgroup size
+        {device->subgroup_size},
+        1
+    );
+
     for (auto &c : compiles) {
         c.wait();
     }
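A note on the creation parameters (my reading of the call above, not spelled out in the patch): 7 is the number of storage-buffer bindings the shader consumes (k, v, r, tf, td, state, plus the destination), {64, 1, 1} fixes a 64-thread workgroup, and the device subgroup size is passed as a specialization constant. A small compile-time check consistent with the push-constant struct defined earlier might look like:

    // Hypothetical sanity check: the shader-side push-constant block is assumed
    // to be four tightly packed 32-bit values (B, T, C, H).
    static_assert(sizeof(vk_op_rwkv_wkv6_push_constants) == 4 * sizeof(uint32_t),
                  "rwkv_wkv6 push constants must be 16 bytes");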
@@ -4917,6 +4941,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_pool2d_f32;
         }
         return nullptr;
+    case GGML_OP_RWKV_WKV6:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rwkv_wkv6_f32;
+        }
+        return nullptr;
     case GGML_OP_LEAKY_RELU:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_leaky_relu_f32;
@@ -5319,6 +5348,127 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }

+
+
+template<typename PC>
+static void ggml_vk_op_f32_rwkv6(
+    ggml_backend_vk_context * ctx,
+    vk_context& subctx,
+    ggml_tensor * dst,
+    const PC&& pc,
+    bool dryrun = false) {
+
+    // Get source tensors
+    const ggml_tensor * k = dst->src[0];      // keys
+    const ggml_tensor * v = dst->src[1];      // values
+    const ggml_tensor * r = dst->src[2];      // receptance
+    const ggml_tensor * tf = dst->src[3];     // time first
+    const ggml_tensor * td = dst->src[4];     // time decay
+    const ggml_tensor * state = dst->src[5];  // states
+
+    VK_LOG_DEBUG("ggml_vk_op_f32_rwkv6(" << k << ", " << v << ", " << r << ", "
+                 << tf << ", " << td << ", " << state << ", " << dst << ")");
+
+    // Verify input types
+    GGML_ASSERT(!ggml_is_quantized(k->type));
+    GGML_ASSERT(!ggml_is_quantized(v->type));
+    GGML_ASSERT(!ggml_is_quantized(r->type));
+    GGML_ASSERT(!ggml_is_quantized(tf->type));
+    GGML_ASSERT(!ggml_is_quantized(td->type));
+    GGML_ASSERT(!ggml_is_quantized(state->type));
+    GGML_ASSERT(dst->buffer != nullptr);
+
+    // Get pipeline
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
+    GGML_ASSERT(pipeline != nullptr);
+
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        return;
+    }
+
+    // Get buffer contexts
+    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
+    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
+    ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
+    ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
+    ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
+    ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
+
+    // Get device buffers
+    vk_buffer d_D = dst_buf_ctx->dev_buffer;
+    vk_buffer d_K = k_buf_ctx->dev_buffer;
+    vk_buffer d_V = v_buf_ctx->dev_buffer;
+    vk_buffer d_R = r_buf_ctx->dev_buffer;
+    vk_buffer d_TF = tf_buf_ctx->dev_buffer;
+    vk_buffer d_TD = td_buf_ctx->dev_buffer;
+    vk_buffer d_State = state_buf_ctx->dev_buffer;
+
+    // Calculate buffer offsets
+    const uint64_t k_offset = vk_tensor_offset(k);
+    const uint64_t v_offset = vk_tensor_offset(v);
+    const uint64_t r_offset = vk_tensor_offset(r);
+    const uint64_t tf_offset = vk_tensor_offset(tf);
+    const uint64_t td_offset = vk_tensor_offset(td);
+    const uint64_t state_offset = vk_tensor_offset(state);
+    const uint64_t dst_offset = vk_tensor_offset(dst);
+
+    // Calculate buffer sizes
+    const uint64_t k_size = ggml_nbytes(k);
+    const uint64_t v_size = ggml_nbytes(v);
+    const uint64_t r_size = ggml_nbytes(r);
+    const uint64_t tf_size = ggml_nbytes(tf);
+    const uint64_t td_size = ggml_nbytes(td);
+    const uint64_t state_size = ggml_nbytes(state);
+    const uint64_t dst_size = ggml_nbytes(dst);
+
+    // Set work elements based on tensor dimensions
+    std::array<uint32_t, 3> elements = {
+        (uint32_t)(pc.B * pc.H), // B * H workgroups
+        1,                       // 64 threads per workgroup
+        1
+    };
+
+    // Synchronize buffers and dispatch compute pipeline
+    ggml_vk_sync_buffers(subctx);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+        vk_subbuffer{ d_K, k_offset, k_size },
+        vk_subbuffer{ d_V, v_offset, v_size },
+        vk_subbuffer{ d_R, r_offset, r_size },
+        vk_subbuffer{ d_TF, tf_offset, tf_size },
+        vk_subbuffer{ d_TD, td_offset, td_size },
+        vk_subbuffer{ d_State, state_offset, state_size },
+        vk_subbuffer{ d_D, dst_offset, dst_size }
+    }, sizeof(PC), &pc, elements);
+}
+
+static void ggml_vk_rwkv_wkv6(
+    ggml_backend_vk_context * ctx,
+    vk_context& subctx,
+    ggml_tensor * dst,
+    bool dryrun = false) {
+
+    // Extract dimensions from tensors
+    const size_t T = dst->src[0]->ne[3];      // Sequence length
+    const size_t C = dst->ne[0];              // Channel dimension
+    const size_t HEADS = dst->src[0]->ne[2];  // Number of heads
+    const size_t n_seqs = dst->src[5]->ne[1]; // Batch size
+
+    // Call implementation with push constants
+    ggml_vk_op_f32_rwkv6<vk_op_rwkv_wkv6_push_constants>(
+        ctx, subctx, dst,
+        {
+            (uint32_t)n_seqs, // B
+            (uint32_t)T,      // T
+            (uint32_t)C,      // C
+            (uint32_t)HEADS,  // H
+        },
+        dryrun
+    );
+}
+
+
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     int * op_params = (int *)dst->op_params;

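The implied dispatch geometry is one workgroup per (sequence, head) pair with 64 threads each, which assumes each head spans 64 channels (head_size = C / H). A minimal sketch of that assumption as a check, using a hypothetical helper that is not part of the patch:

    // Hypothetical: verify the geometry the fixed 64-thread workgroup relies on.
    static bool rwkv_wkv6_geometry_ok(const vk_op_rwkv_wkv6_push_constants & pc) {
        const uint32_t head_size    = pc.C / pc.H; // channels handled per workgroup
        const uint32_t n_workgroups = pc.B * pc.H; // matches elements[0] in the dispatch
        return head_size == 64 && n_workgroups > 0;
    }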
@@ -6464,6 +6614,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
         break;
@@ -6663,6 +6814,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_FLASH_ATTN_EXT:
         ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);

+        break;
+
+    case GGML_OP_RWKV_WKV6:
+        ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
+
         break;
     default:
         return false;
@@ -6743,6 +6899,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_REPEAT:
         buf = tensor->buffer;
@@ -7610,6 +7767,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_IM2COL:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_POOL_2D:
+        case GGML_OP_RWKV_WKV6:
         case GGML_OP_LEAKY_RELU:
             return true;
         default:
@@ -8186,7 +8344,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_LEAKY_RELU) {
         const float * op_params = (const float *)tensor->op_params;
         tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
-    } else {
+    }
+    // else if (tensor->op == GGML_OP_RWKV_WKV6) {
+    //     tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
+    //                                   tensor->src[4], tensor->src[5]);
+    // }
+    else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
         GGML_ABORT("fatal error");
     }