@@ -524,15 +524,13 @@ struct vk_op_pool2d_push_constants {
524524 int32_t p0; int32_t p1;
525525};
526526
527-
528527struct vk_op_rwkv_wkv6_push_constants { // push-constant block passed by ggml_vk_op_f32_rwkv6 when dispatching the rwkv_wkv6_f32 pipeline
529- uint32_t B; // Batch size (formerly n_seqs)
530- uint32_t T; // Sequence length
531- uint32_t C; // Total channels
532- uint32_t H; // Number of heads (formerly HEADS)
528+ uint32_t B; // batch size: filled from n_seqs by ggml_vk_rwkv_wkv6; B * H is the first dispatch dimension
529+ uint32_t T; // sequence length (seq_length at the call site)
530+ uint32_t C; // channel count (n_embed at the call site, read from dst->ne[0])
531+ uint32_t H; // number of heads (n_heads at the call site)
533532};
534533
535-
536534// Allow pre-recording command buffers
537535struct vk_staging_memcpy {
538536 vk_staging_memcpy (void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1952,19 +1950,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
19521950
19531951 ggml_vk_create_pipeline (device, device->pipeline_pool2d_f32 , " pool2d_f32" , pool2d_f32_len, pool2d_f32_data, " main" , 2 , sizeof (vk_op_pool2d_push_constants), {512 , 1 , 1 }, {}, 1 );
19541952
1955- ggml_vk_create_pipeline (
1956- device,
1957- device->pipeline_rwkv_wkv6_f32 ,
1958- " rwkv_wkv6_f32" ,
1959- rwkv_wkv6_f32_len,
1960- rwkv_wkv6_f32_data,
1961- " main" ,
1962- 7 ,
1963- sizeof (vk_op_rwkv_wkv6_push_constants),
1964- {1 , 1 , 1 }, // work group
1965- {device->subgroup_size },
1966- 1
1967- );
1953+ ggml_vk_create_pipeline (device, device->pipeline_rwkv_wkv6_f32 , " rwkv_wkv6_f32" , rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, " main" , 7 , sizeof (vk_op_rwkv_wkv6_push_constants), {1 , 1 , 1 }, {device->subgroup_size }, 1 );
19681954
19691955 for (auto &c : compiles) {
19701956 c.wait ();
@@ -5348,28 +5334,14 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
53485334 }, dryrun);
53495335}
53505336
5337+ static void ggml_vk_op_f32_rwkv6 (ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false ) {
5338+ const ggml_tensor * k = dst->src [0 ];
5339+ const ggml_tensor * v = dst->src [1 ];
5340+ const ggml_tensor * r = dst->src [2 ];
5341+ const ggml_tensor * tf = dst->src [3 ];
5342+ const ggml_tensor * td = dst->src [4 ];
5343+ const ggml_tensor * state = dst->src [5 ];
53515344
5352-
5353- template <typename PC>
5354- static void ggml_vk_op_f32_rwkv6 (
5355- ggml_backend_vk_context * ctx,
5356- vk_context& subctx,
5357- ggml_tensor * dst,
5358- const PC&& pc,
5359- bool dryrun = false ) {
5360-
5361- // Get source tensors
5362- const ggml_tensor * k = dst->src [0 ]; // keys
5363- const ggml_tensor * v = dst->src [1 ]; // values
5364- const ggml_tensor * r = dst->src [2 ]; // reset gates
5365- const ggml_tensor * tf = dst->src [3 ]; // time first
5366- const ggml_tensor * td = dst->src [4 ]; // time decay
5367- const ggml_tensor * state = dst->src [5 ]; // states
5368-
5369- VK_LOG_DEBUG (" ggml_vk_op_f32_rwkv6(" << k << " , " << v << " , " << r << " , "
5370- << tf << " , " << td << " , " << state << " , " << dst << " )" );
5371-
5372- // Verify input types
53735345 GGML_ASSERT (!ggml_is_quantized (k->type ));
53745346 GGML_ASSERT (!ggml_is_quantized (v->type ));
53755347 GGML_ASSERT (!ggml_is_quantized (r->type ));
@@ -5378,7 +5350,6 @@ static void ggml_vk_op_f32_rwkv6(
53785350 GGML_ASSERT (!ggml_is_quantized (state->type ));
53795351 GGML_ASSERT (dst->buffer != nullptr );
53805352
5381- // Get pipeline
53825353 vk_pipeline pipeline = ggml_vk_op_get_pipeline (ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
53835354 GGML_ASSERT (pipeline != nullptr );
53845355
@@ -5387,7 +5358,6 @@ static void ggml_vk_op_f32_rwkv6(
53875358 return ;
53885359 }
53895360
5390- // Get buffer contexts
53915361 ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer ->context ;
53925362 ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer ->context ;
53935363 ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer ->context ;
@@ -5396,7 +5366,6 @@ static void ggml_vk_op_f32_rwkv6(
53965366 ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer ->context ;
53975367 ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer ->context ;
53985368
5399- // Get device buffers
54005369 vk_buffer d_D = dst_buf_ctx->dev_buffer ;
54015370 vk_buffer d_K = k_buf_ctx->dev_buffer ;
54025371 vk_buffer d_V = v_buf_ctx->dev_buffer ;
@@ -5405,7 +5374,6 @@ static void ggml_vk_op_f32_rwkv6(
54055374 vk_buffer d_TD = td_buf_ctx->dev_buffer ;
54065375 vk_buffer d_State = state_buf_ctx->dev_buffer ;
54075376
5408- // Calculate buffer offsets
54095377 const uint64_t k_offset = vk_tensor_offset (k);
54105378 const uint64_t v_offset = vk_tensor_offset (v);
54115379 const uint64_t r_offset = vk_tensor_offset (r);
@@ -5414,7 +5382,6 @@ static void ggml_vk_op_f32_rwkv6(
54145382 const uint64_t state_offset = vk_tensor_offset (state);
54155383 const uint64_t dst_offset = vk_tensor_offset (dst);
54165384
5417- // Calculate buffer sizes
54185385 const uint64_t k_size = ggml_nbytes (k);
54195386 const uint64_t v_size = ggml_nbytes (v);
54205387 const uint64_t r_size = ggml_nbytes (r);
@@ -5423,14 +5390,12 @@ static void ggml_vk_op_f32_rwkv6(
54235390 const uint64_t state_size = ggml_nbytes (state);
54245391 const uint64_t dst_size = ggml_nbytes (dst);
54255392
5426- // Set work elements based on tensor dimensions
54275393 std::array<uint32_t , 3 > elements = {
5428- (uint32_t )(pc.B * pc.H ), // B * H workgroups
5429- 1 , // 每个workgroup 64个线程
5394+ (uint32_t )(pc.B * pc.H ),
5395+ 1 ,
54305396 1
54315397 };
54325398
5433- // Synchronize buffers and dispatch compute pipeline
54345399 ggml_vk_sync_buffers (subctx);
54355400 ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, {
54365401 vk_subbuffer{ d_K, k_offset, k_size },
@@ -5440,35 +5405,27 @@ static void ggml_vk_op_f32_rwkv6(
54405405 vk_subbuffer{ d_TD, td_offset, td_size },
54415406 vk_subbuffer{ d_State, state_offset, state_size },
54425407 vk_subbuffer{ d_D, dst_offset, dst_size }
5443- }, sizeof (PC), &pc, elements);
5444- }
5445-
5446- static void ggml_vk_rwkv_wkv6 (
5447- ggml_backend_vk_context * ctx,
5448- vk_context& subctx,
5449- ggml_tensor * dst,
5450- bool dryrun = false ) {
5451-
5452- // Extract dimensions from tensors
5453- const size_t T = dst->src [0 ]->ne [3 ]; // Sequence length
5454- const size_t C = dst->ne [0 ]; // Channel dimension
5455- const size_t HEADS = dst->src [0 ]->ne [2 ]; // Number of heads
5456- const size_t n_seqs = dst->src [5 ]->ne [1 ]; // Batch size
5457-
5458- // Call implementation with push constants
5459- ggml_vk_op_f32_rwkv6<vk_op_rwkv_wkv6_push_constants>(
5408+ }, sizeof (vk_op_rwkv_wkv6_push_constants), &pc, elements);
5409+ }
5410+
5411+ static void ggml_vk_rwkv_wkv6 (ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false ) { // reads the WKV6 dimensions from dst and its sources, then forwards them as push constants to ggml_vk_op_f32_rwkv6
5412+ const size_t seq_length = dst->src [0 ]->ne [3 ]; // ne[3] of src[0] (the k tensor)
5413+ const size_t n_embed = dst->ne [0 ]; // ne[0] of the destination tensor
5414+ const size_t n_heads = dst->src [0 ]->ne [2 ]; // ne[2] of src[0] (the k tensor)
5415+ const size_t n_seqs = dst->src [5 ]->ne [1 ]; // ne[1] of src[5] (the state tensor)
5416+
5417+ ggml_vk_op_f32_rwkv6 (
54605418 ctx, subctx, dst,
54615419 {
5462- (uint32_t )n_seqs, // B
5463- (uint32_t )T, // T
5464- (uint32_t )C, // C
5465- (uint32_t )HEADS, // H
5420+ (uint32_t )n_seqs, // B
5421+ (uint32_t )seq_length, // T
5422+ (uint32_t )n_embed, // C
5423+ (uint32_t )n_heads, // H
54665424 },
54675425 dryrun
54685426 );
54695427}
54705428
5471-
54725429static void ggml_vk_concat (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
54735430 int * op_params = (int *)dst->op_params ;
54745431
@@ -8344,10 +8301,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
83448301 } else if (tensor->op == GGML_OP_LEAKY_RELU) {
83458302 const float * op_params = (const float *)tensor->op_params ;
83468303 tensor_clone = ggml_leaky_relu (ggml_ctx, src0_clone, op_params[0 ], false );
8347- } else if (tensor->op == GGML_OP_RWKV_WKV6) {
8304+ } else if (tensor->op == GGML_OP_RWKV_WKV6) {
83488305 tensor_clone = ggml_rwkv_wkv6 (ggml_ctx, tensor->src [0 ], tensor->src [1 ], tensor->src [2 ], tensor->src [3 ],
83498306 tensor->src [4 ], tensor->src [5 ]);
8350- }
8307+ }
83518308 else {
83528309 std::cerr << " Missing vk_check_results OP: " << ggml_op_name (tensor->op ) << std::endl;
83538310 GGML_ABORT (" fatal error" );
0 commit comments