Skip to content

Commit 353c5f8

Browse files
committed
add uma support
1 parent 6ea605d commit 353c5f8

File tree

1 file changed

+53
-16
lines changed

1 file changed

+53
-16
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5471,21 +5471,58 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
54715471
ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
54725472
ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
54735473

5474-
vk_buffer d_D = dst_buf_ctx->dev_buffer;
5475-
vk_buffer d_K = k_buf_ctx->dev_buffer;
5476-
vk_buffer d_V = v_buf_ctx->dev_buffer;
5477-
vk_buffer d_R = r_buf_ctx->dev_buffer;
5478-
vk_buffer d_TF = tf_buf_ctx->dev_buffer;
5479-
vk_buffer d_TD = td_buf_ctx->dev_buffer;
5480-
vk_buffer d_State = state_buf_ctx->dev_buffer;
5481-
5482-
const uint64_t k_offset = vk_tensor_offset(k);
5483-
const uint64_t v_offset = vk_tensor_offset(v);
5484-
const uint64_t r_offset = vk_tensor_offset(r);
5485-
const uint64_t tf_offset = vk_tensor_offset(tf);
5486-
const uint64_t td_offset = vk_tensor_offset(td);
5487-
const uint64_t state_offset = vk_tensor_offset(state);
5488-
const uint64_t dst_offset = vk_tensor_offset(dst);
5474+
ggml_vk_sync_buffers(subctx);
5475+
5476+
vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
5477+
uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
5478+
bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
5479+
5480+
if (ctx->device->uma) {
5481+
ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
5482+
ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
5483+
ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
5484+
ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
5485+
ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
5486+
ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
5487+
ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
5488+
5489+
K_uma = d_K != nullptr;
5490+
V_uma = d_V != nullptr;
5491+
R_uma = d_R != nullptr;
5492+
TF_uma = d_TF != nullptr;
5493+
TD_uma = d_TD != nullptr;
5494+
STATE_uma = d_State != nullptr;
5495+
DST_uma = d_D != nullptr;
5496+
}
5497+
5498+
if (!K_uma) {
5499+
d_K = k_buf_ctx->dev_buffer;
5500+
k_offset = vk_tensor_offset(k) + k->view_offs;
5501+
}
5502+
if (!V_uma) {
5503+
d_V = v_buf_ctx->dev_buffer;
5504+
v_offset = vk_tensor_offset(v) + v->view_offs;
5505+
}
5506+
if (!R_uma) {
5507+
d_R = r_buf_ctx->dev_buffer;
5508+
r_offset = vk_tensor_offset(r) + r->view_offs;
5509+
}
5510+
if (!TF_uma) {
5511+
d_TF = tf_buf_ctx->dev_buffer;
5512+
tf_offset = vk_tensor_offset(tf) + tf->view_offs;
5513+
}
5514+
if (!TD_uma) {
5515+
d_TD = td_buf_ctx->dev_buffer;
5516+
td_offset = vk_tensor_offset(td) + td->view_offs;
5517+
}
5518+
if (!STATE_uma) {
5519+
d_State = state_buf_ctx->dev_buffer;
5520+
state_offset = vk_tensor_offset(state) + state->view_offs;
5521+
}
5522+
if (!DST_uma) {
5523+
d_D = dst_buf_ctx->dev_buffer;
5524+
dst_offset = vk_tensor_offset(dst) + dst->view_offs;
5525+
}
54895526

54905527
const uint64_t k_size = ggml_nbytes(k);
54915528
const uint64_t v_size = ggml_nbytes(v);
@@ -5501,7 +5538,7 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
55015538
1
55025539
};
55035540

5504-
ggml_vk_sync_buffers(subctx);
5541+
55055542
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
55065543
vk_subbuffer{ d_K, k_offset, k_size },
55075544
vk_subbuffer{ d_V, v_offset, v_size },

0 commit comments

Comments
 (0)