@@ -5471,21 +5471,58 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
54715471 ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer ->context ;
54725472 ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer ->context ;
54735473
5474- vk_buffer d_D = dst_buf_ctx->dev_buffer ;
5475- vk_buffer d_K = k_buf_ctx->dev_buffer ;
5476- vk_buffer d_V = v_buf_ctx->dev_buffer ;
5477- vk_buffer d_R = r_buf_ctx->dev_buffer ;
5478- vk_buffer d_TF = tf_buf_ctx->dev_buffer ;
5479- vk_buffer d_TD = td_buf_ctx->dev_buffer ;
5480- vk_buffer d_State = state_buf_ctx->dev_buffer ;
5481-
5482- const uint64_t k_offset = vk_tensor_offset (k);
5483- const uint64_t v_offset = vk_tensor_offset (v);
5484- const uint64_t r_offset = vk_tensor_offset (r);
5485- const uint64_t tf_offset = vk_tensor_offset (tf);
5486- const uint64_t td_offset = vk_tensor_offset (td);
5487- const uint64_t state_offset = vk_tensor_offset (state);
5488- const uint64_t dst_offset = vk_tensor_offset (dst);
5474+ ggml_vk_sync_buffers (subctx);
5475+
5476+ vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
5477+ uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
5478+ bool K_uma = false , V_uma = false , R_uma = false , TF_uma = false , TD_uma = false , STATE_uma = false , DST_uma = false ;
5479+
5480+ if (ctx->device ->uma ) {
5481+ ggml_vk_host_get (ctx->device , k->data , d_K, k_offset);
5482+ ggml_vk_host_get (ctx->device , v->data , d_V, v_offset);
5483+ ggml_vk_host_get (ctx->device , r->data , d_R, r_offset);
5484+ ggml_vk_host_get (ctx->device , tf->data , d_TF, tf_offset);
5485+ ggml_vk_host_get (ctx->device , td->data , d_TD, td_offset);
5486+ ggml_vk_host_get (ctx->device , state->data , d_State, state_offset);
5487+ ggml_vk_host_get (ctx->device , dst->data , d_D, dst_offset);
5488+
5489+ K_uma = d_K != nullptr ;
5490+ V_uma = d_V != nullptr ;
5491+ R_uma = d_R != nullptr ;
5492+ TF_uma = d_TF != nullptr ;
5493+ TD_uma = d_TD != nullptr ;
5494+ STATE_uma = d_State != nullptr ;
5495+ DST_uma = d_D != nullptr ;
5496+ }
5497+
5498+ if (!K_uma) {
5499+ d_K = k_buf_ctx->dev_buffer ;
5500+ k_offset = vk_tensor_offset (k) + k->view_offs ;
5501+ }
5502+ if (!V_uma) {
5503+ d_V = v_buf_ctx->dev_buffer ;
5504+ v_offset = vk_tensor_offset (v) + v->view_offs ;
5505+ }
5506+ if (!R_uma) {
5507+ d_R = r_buf_ctx->dev_buffer ;
5508+ r_offset = vk_tensor_offset (r) + r->view_offs ;
5509+ }
5510+ if (!TF_uma) {
5511+ d_TF = tf_buf_ctx->dev_buffer ;
5512+ tf_offset = vk_tensor_offset (tf) + tf->view_offs ;
5513+ }
5514+ if (!TD_uma) {
5515+ d_TD = td_buf_ctx->dev_buffer ;
5516+ td_offset = vk_tensor_offset (td) + td->view_offs ;
5517+ }
5518+ if (!STATE_uma) {
5519+ d_State = state_buf_ctx->dev_buffer ;
5520+ state_offset = vk_tensor_offset (state) + state->view_offs ;
5521+ }
5522+ if (!DST_uma) {
5523+ d_D = dst_buf_ctx->dev_buffer ;
5524+ dst_offset = vk_tensor_offset (dst) + dst->view_offs ;
5525+ }
54895526
54905527 const uint64_t k_size = ggml_nbytes (k);
54915528 const uint64_t v_size = ggml_nbytes (v);
@@ -5501,7 +5538,7 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
55015538 1
55025539 };
55035540
5504- ggml_vk_sync_buffers (subctx);
5541+
55055542 ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, {
55065543 vk_subbuffer{ d_K, k_offset, k_size },
55075544 vk_subbuffer{ d_V, v_offset, v_size },
0 commit comments