Skip to content

Commit cbdf516

Browse files
committed
Combine UMA handling into ggml_vk_tensor_subbuffer
1 parent edacb52 commit cbdf516

File tree

1 file changed

+28
-35
lines changed

1 file changed

+28
-35
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 28 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5382,9 +5382,13 @@ static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffe
53825382
}
53835383

53845384
static vk_subbuffer ggml_vk_tensor_subbuffer(
5385-
const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false,
5386-
vk_buffer buffer = nullptr, size_t offset = 0) {
5385+
const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) {
53875386

5387+
vk_buffer buffer = nullptr;
5388+
size_t offset = 0;
5389+
if (ctx->device->uma) {
5390+
ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
5391+
}
53885392
if (!buffer) {
53895393
auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
53905394
buffer = buf_ctx->dev_buffer;
@@ -5403,17 +5407,6 @@ static vk_subbuffer ggml_vk_tensor_subbuffer(
54035407
return vk_subbuffer{buffer, offset, size};
54045408
}
54055409

5406-
static vk_subbuffer ggml_vk_tensor_subbuffer_uma(
5407-
const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool allow_misalign = false) {
5408-
5409-
vk_buffer buffer = nullptr;
5410-
size_t offset = 0;
5411-
if (ctx->device->uma) {
5412-
ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
5413-
}
5414-
return ggml_vk_tensor_subbuffer(ctx, tensor, allow_misalign, std::move(buffer), offset);
5415-
}
5416-
54175410
static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
54185411
vk_submission s;
54195412
s.buffer = ggml_vk_create_cmd_buffer(device, p);
@@ -7772,12 +7765,12 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
77727765
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
77737766
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
77747767

7775-
vk_subbuffer q_buf = ggml_vk_tensor_subbuffer_uma(ctx, q);
7776-
vk_subbuffer k_buf = ggml_vk_tensor_subbuffer_uma(ctx, k);
7777-
vk_subbuffer v_buf = ggml_vk_tensor_subbuffer_uma(ctx, v);
7778-
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer_uma(ctx, dst);
7779-
vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer_uma(ctx, mask) : q_buf;
7780-
vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer_uma(ctx, sinks) : q_buf;
7768+
vk_subbuffer q_buf = ggml_vk_tensor_subbuffer(ctx, q);
7769+
vk_subbuffer k_buf = ggml_vk_tensor_subbuffer(ctx, k);
7770+
vk_subbuffer v_buf = ggml_vk_tensor_subbuffer(ctx, v);
7771+
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
7772+
vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
7773+
vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
77817774

77827775
uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
77837776

@@ -8529,10 +8522,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
85298522

85308523
const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
85318524

8532-
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer_uma(ctx, src0, op_supports_incontiguous);
8533-
vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer_uma(ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
8534-
vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer_uma(ctx, src2, op_supports_incontiguous) : vk_subbuffer{};
8535-
vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer_uma(ctx, src3, op_supports_incontiguous) : vk_subbuffer{};
8525+
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, op_supports_incontiguous);
8526+
vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
8527+
vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, op_supports_incontiguous) : vk_subbuffer{};
8528+
vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, op_supports_incontiguous) : vk_subbuffer{};
85368529
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous);
85378530

85388531
// Compute misalignment offset for descriptors and store it in in push constants.
@@ -9022,10 +9015,10 @@ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx
90229015
return;
90239016
}
90249017

9025-
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer_uma(ctx, dst);
9018+
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
90269019
vk_subbuffer src_buf[7] = {};
90279020
for (int i = 0; i < num_srcs; i++) {
9028-
src_buf[i] = ggml_vk_tensor_subbuffer_uma(ctx, dst->src[i]);
9021+
src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
90299022
}
90309023

90319024
std::array<uint32_t, 3> elements = {
@@ -9126,10 +9119,10 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
91269119
n_head, head_dim, n_group, n_tok
91279120
};
91289121

9129-
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer_uma(ctx, dst);
9122+
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
91309123
vk_subbuffer src_buf[7] = {};
91319124
for (int i = 0; i < 7 && dst->src[i] != nullptr; i++) {
9132-
src_buf[i] = ggml_vk_tensor_subbuffer_uma(ctx, dst->src[i]);
9125+
src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
91339126
}
91349127

91359128
std::array<uint32_t, 3> elements;
@@ -9191,11 +9184,11 @@ static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_cont
91919184
return;
91929185
}
91939186

9194-
vk_subbuffer x_buf = ggml_vk_tensor_subbuffer_uma(ctx, x);
9195-
vk_subbuffer g_buf = ggml_vk_tensor_subbuffer_uma(ctx, g);
9196-
vk_subbuffer gm_buf = ggml_vk_tensor_subbuffer_uma(ctx, gm);
9197-
vk_subbuffer gv_buf = ggml_vk_tensor_subbuffer_uma(ctx, gv);
9198-
vk_subbuffer p_buf = ggml_vk_tensor_subbuffer_uma(ctx, p);
9187+
vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x);
9188+
vk_subbuffer g_buf = ggml_vk_tensor_subbuffer(ctx, g);
9189+
vk_subbuffer gm_buf = ggml_vk_tensor_subbuffer(ctx, gm);
9190+
vk_subbuffer gv_buf = ggml_vk_tensor_subbuffer(ctx, gv);
9191+
vk_subbuffer p_buf = ggml_vk_tensor_subbuffer(ctx, p);
91999192

92009193
std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };
92019194

@@ -9537,9 +9530,9 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
95379530
return;
95389531
}
95399532

9540-
vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer_uma(ctx, logits);
9541-
vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer_uma(ctx, weights);
9542-
vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer_uma(ctx, ids);
9533+
vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);
9534+
vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
9535+
vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);
95439536

95449537
vk_op_topk_moe_push_constants pc {};
95459538
pc.n_rows = n_rows;

0 commit comments

Comments
 (0)