@@ -5382,9 +5382,13 @@ static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffe
 }
 
 static vk_subbuffer ggml_vk_tensor_subbuffer(
-    const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false,
-    vk_buffer buffer = nullptr, size_t offset = 0) {
+    const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) {
 
+    vk_buffer buffer = nullptr;
+    size_t offset = 0;
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
+    }
     if (!buffer) {
         auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
         buffer = buf_ctx->dev_buffer;
@@ -5403,17 +5407,6 @@ static vk_subbuffer ggml_vk_tensor_subbuffer(
     return vk_subbuffer{buffer, offset, size};
 }
 
-static vk_subbuffer ggml_vk_tensor_subbuffer_uma(
-    const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) {
-
-    vk_buffer buffer = nullptr;
-    size_t offset = 0;
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
-    }
-    return ggml_vk_tensor_subbuffer(ctx, tensor, allow_misalign, std::move(buffer), offset);
-}
-
 static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
     vk_submission s;
     s.buffer = ggml_vk_create_cmd_buffer(device, p);
@@ -7772,12 +7765,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    vk_subbuffer q_buf = ggml_vk_tensor_subbuffer_uma (ctx, q);
-    vk_subbuffer k_buf = ggml_vk_tensor_subbuffer_uma (ctx, k);
-    vk_subbuffer v_buf = ggml_vk_tensor_subbuffer_uma (ctx, v);
-    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer_uma (ctx, dst);
-    vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer_uma (ctx, mask) : q_buf;
-    vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer_uma (ctx, sinks) : q_buf;
+    vk_subbuffer q_buf = ggml_vk_tensor_subbuffer (ctx, q);
+    vk_subbuffer k_buf = ggml_vk_tensor_subbuffer (ctx, k);
+    vk_subbuffer v_buf = ggml_vk_tensor_subbuffer (ctx, v);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer (ctx, dst);
+    vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer (ctx, mask) : q_buf;
+    vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer (ctx, sinks) : q_buf;
 
     uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
 
@@ -8529,10 +8522,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
     const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
 
-    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer_uma (ctx, src0, op_supports_incontiguous);
-    vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer_uma (ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
-    vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer_uma (ctx, src2, op_supports_incontiguous) : vk_subbuffer{};
-    vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer_uma (ctx, src3, op_supports_incontiguous) : vk_subbuffer{};
+    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer (ctx, src0, op_supports_incontiguous);
+    vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer (ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
+    vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer (ctx, src2, op_supports_incontiguous) : vk_subbuffer{};
+    vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer (ctx, src3, op_supports_incontiguous) : vk_subbuffer{};
     vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous);
 
     // Compute misalignment offset for descriptors and store it in in push constants.
@@ -9022,10 +9015,10 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
         return;
     }
 
-    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer_uma (ctx, dst);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer (ctx, dst);
     vk_subbuffer src_buf[7] = {};
     for (int i = 0; i < num_srcs; i++) {
-        src_buf[i] = ggml_vk_tensor_subbuffer_uma (ctx, dst->src[i]);
+        src_buf[i] = ggml_vk_tensor_subbuffer (ctx, dst->src[i]);
     }
 
     std::array<uint32_t, 3> elements = {
@@ -9126,10 +9119,10 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
         n_head, head_dim, n_group, n_tok
     };
 
-    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer_uma (ctx, dst);
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer (ctx, dst);
     vk_subbuffer src_buf[7] = {};
     for (int i = 0; i < 7 && dst->src[i] != nullptr; i++) {
-        src_buf[i] = ggml_vk_tensor_subbuffer_uma (ctx, dst->src[i]);
+        src_buf[i] = ggml_vk_tensor_subbuffer (ctx, dst->src[i]);
     }
 
     std::array<uint32_t, 3> elements;
@@ -9191,11 +9184,11 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont
         return;
     }
 
-    vk_subbuffer x_buf = ggml_vk_tensor_subbuffer_uma (ctx, x);
-    vk_subbuffer g_buf = ggml_vk_tensor_subbuffer_uma (ctx, g);
-    vk_subbuffer gm_buf = ggml_vk_tensor_subbuffer_uma (ctx, gm);
-    vk_subbuffer gv_buf = ggml_vk_tensor_subbuffer_uma (ctx, gv);
-    vk_subbuffer p_buf = ggml_vk_tensor_subbuffer_uma (ctx, p);
+    vk_subbuffer x_buf = ggml_vk_tensor_subbuffer (ctx, x);
+    vk_subbuffer g_buf = ggml_vk_tensor_subbuffer (ctx, g);
+    vk_subbuffer gm_buf = ggml_vk_tensor_subbuffer (ctx, gm);
+    vk_subbuffer gv_buf = ggml_vk_tensor_subbuffer (ctx, gv);
+    vk_subbuffer p_buf = ggml_vk_tensor_subbuffer (ctx, p);
 
     std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };
 
@@ -9537,9 +9530,9 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
         return;
     }
 
-    vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer_uma (ctx, logits);
-    vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer_uma (ctx, weights);
-    vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer_uma (ctx, ids);
+    vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer (ctx, logits);
+    vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer (ctx, weights);
+    vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer (ctx, ids);
 
     vk_op_topk_moe_push_constants pc {};
     pc.n_rows = n_rows;
0 commit comments