@@ -5200,7 +5200,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
     device->pinned_memory.erase(device->pinned_memory.begin() + index);
 }
 
-static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
+static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
     std::lock_guard<std::recursive_mutex> guard(device->mutex);
     buf = nullptr;
     buf_offset = 0;
@@ -8295,6 +8295,45 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
     GGML_UNUSED(src2);
 }
 
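+// Build a vk_subbuffer describing the device memory range that backs `tensor`.
+// If no explicit buffer/offset pair is supplied, both are derived from the
+// tensor's backend buffer context.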
+static vk_subbuffer ggml_vk_tensor_subbuffer(
+        const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool support_incontiguous,
+        vk_buffer buffer = nullptr, size_t offset = 0) {
+
+    if (!buffer) {
+        auto buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
+        buffer = buf_ctx->dev_buffer;
+        offset = vk_tensor_offset(tensor) + tensor->view_offs;
+    }
+    GGML_ASSERT(buffer != nullptr);
+
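+    // Align the descriptor offset down to the minimum storage buffer offset
+    // alignment; the residual misalignment is kept and, for ops that support
+    // incontiguous tensors, folded into the range size below.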
+    size_t misalign_bytes = offset & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
+    offset &= ~misalign_bytes;
+
+    size_t size;
+    if (support_incontiguous) {
+        size = ggml_nbytes(tensor) + misalign_bytes;
+        if (offset + size >= buffer->size) {
+            size = ggml_vk_get_max_buffer_range(ctx, buffer, offset);
+        }
+    } else {
+        size_t elem_size = ggml_type_size(tensor->type) / ggml_blck_size(tensor->type);
+        size = elem_size * ggml_nelements(tensor);
+    }
+
+    return vk_subbuffer{buffer, offset, size};
+}
+
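+// UMA variant: on unified-memory devices, first try to resolve `tensor->data`
+// to a pinned host buffer via ggml_vk_host_get before falling back to the
+// tensor's device buffer.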
+static vk_subbuffer ggml_vk_tensor_subbuffer_uma(
+        const ggml_backend_vk_context * ctx, const ggml_tensor * tensor, bool support_incontiguous) {
+
+    vk_buffer buffer = nullptr;
+    size_t offset = 0;
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, tensor->data, buffer, offset);
+    }
+    return ggml_vk_tensor_subbuffer(ctx, tensor, support_incontiguous, std::move(buffer), offset);
+}
+
 template<typename PC>
 static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
@@ -8356,60 +8395,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
     const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
 
-    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
-    ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr;
-    ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? (ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr;
-
-    vk_buffer d_X = nullptr;
-    size_t x_buf_offset = 0;
-    vk_buffer d_Y = nullptr;
-    size_t y_buf_offset = 0;
-    vk_buffer d_Z = nullptr;
-    size_t z_buf_offset = 0;
-
-    bool src0_uma = false;
-    bool src1_uma = false;
-    bool src2_uma = false;
-
-    if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset);
-        src0_uma = d_X != nullptr;
-        if (use_src1) {
-            ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset);
-            src1_uma = d_Y != nullptr;
-        }
-        if (use_src2) {
-            ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset);
-            src2_uma = d_Z != nullptr;
-        }
-    }
-
-    vk_buffer d_D = dst_buf_ctx->dev_buffer;
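+    // Descriptor ranges for all operands; the subbuffer helpers handle the
+    // UMA host-buffer lookup, offset alignment and range clamping that used
+    // to be done inline here.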
+    vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer_uma(ctx, src0, op_supports_incontiguous);
+    vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer_uma(ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
+    vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer_uma(ctx, src2, op_supports_incontiguous) : vk_subbuffer{};
+    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous);
 
-    GGML_ASSERT(d_D != nullptr);
-    uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
-    if(!src0_uma) {
-        d_X = src0_buf_ctx->dev_buffer;
-        x_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
-        GGML_ASSERT(d_X != nullptr);
-    }
-    if (use_src1 && !src1_uma) {
-        d_Y = src1_buf_ctx->dev_buffer;
-        y_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
-        GGML_ASSERT(d_Y != nullptr);
-    }
-    if (use_src2 && !src2_uma) {
-        d_Z = src2_buf_ctx->dev_buffer;
-        z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
-        GGML_ASSERT(d_Z != nullptr);
-    }
-    // Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets.
+    // Compute misalignment offset for descriptors and store it in push constants.
     init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst);
-    x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
-    y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
-    z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
-    d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
 
     std::array<uint32_t, 3> elements;
 
@@ -8609,100 +8601,47 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         break;
     }
 
-    uint64_t x_sz, y_sz, z_sz, d_sz;
-
-    if (op_supports_incontiguous) {
-        x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0);
-        y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0;
-        z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0;
-        d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst);
-
-        if (x_buf_offset + x_sz >= d_X->size) {
-            x_sz = ggml_vk_get_max_buffer_range(ctx, d_X, x_buf_offset);
-        }
-        if (use_src1 && y_buf_offset + y_sz >= d_Y->size) {
-            y_sz = ggml_vk_get_max_buffer_range(ctx, d_Y, y_buf_offset);
-        }
-        if (use_src2 && z_buf_offset + z_sz >= d_Z->size) {
-            z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset);
-        }
-        if (d_buf_offset + d_sz >= d_D->size) {
-            d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset);
-        }
-    } else {
-        x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03;
-        y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0;
-        z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0;
-        d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3;
-    }
-
     if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) {
-        vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X;
-        size_t a_buf_offset = ctx->do_add_rms_partials ? ctx->prealloc_size_add_rms_partials_offset : 0;
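+        // Partial sums for fused add+rms_norm live in a preallocated buffer;
+        // bind it as the fourth descriptor in place of src0 when active.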
+        vk_subbuffer a_buf = src0_buf;
+        if (ctx->do_add_rms_partials) {
+            a_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset);
+        }
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-            { vk_subbuffer{ d_X, x_buf_offset, x_sz },
-              vk_subbuffer{ d_Y, y_buf_offset, y_sz },
-              vk_subbuffer{ d_D, d_buf_offset, d_sz },
-              ggml_vk_subbuffer(ctx, d_A, a_buf_offset),
-            }, pc, elements);
+            { src0_buf, src1_buf, dst_buf, a_buf }, pc, elements);
     } else if (op == GGML_OP_GLU) {
         // Empty src1 is possible in glu, but the shader needs a buffer
-        vk_subbuffer subbuf_y;
-        if (use_src1) {
-            subbuf_y = { d_Y, y_buf_offset, y_sz };
-        } else {
-            subbuf_y = { d_X, 0, x_sz };
-        }
-
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, dst_buf }, pc, elements);
     } else if (op == GGML_OP_SOFT_MAX) {
         // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
-        vk_subbuffer subbuf_y;
-        if (use_src1) {
-            subbuf_y = { d_Y, y_buf_offset, y_sz };
-        } else {
-            subbuf_y = { d_X, 0, x_sz };
-        }
-
-        vk_subbuffer subbuf_z;
-        if (use_src2) {
-            subbuf_z = { d_Z, z_buf_offset, z_sz };
-        } else {
-            subbuf_z = { d_X, 0, x_sz };
-        }
-
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        vk_subbuffer subbuf1 = use_src1 ? src1_buf : src0_buf;
+        vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, subbuf1, subbuf2, dst_buf }, pc, elements);
     } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
         // Empty src2 is possible in rope, but the shader needs a buffer
-        vk_subbuffer subbuf_z;
-        if (use_src2) {
-            subbuf_z = { d_Z, z_buf_offset, z_sz };
-        } else {
-            subbuf_z = { d_X, 0, x_sz };
-        }
-
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        vk_subbuffer subbuf2 = use_src2 ? src2_buf : src0_buf;
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, subbuf2, dst_buf }, pc, elements);
     } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
         if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
             // buffer device address path doesn't use dst buffer
-            d_sz = 1;
+            dst_buf.size = 1;
         }
         // im2col uses only src1 and dst buffers
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src1_buf, dst_buf }, pc, elements);
     } else if (op == GGML_OP_COUNT_EQUAL) {
         // count_equal assumes that destination buffer is initialized with zeroes
-        ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
+        ggml_vk_buffer_memset_async(subctx, dst_buf.buffer, dst_buf.offset, 0, dst_buf.size);
         ggml_vk_sync_buffers(ctx, subctx);
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
     } else if (op == GGML_OP_OPT_STEP_SGD) {
         // OPT_STEP_SGD works on src0, it does not need dst
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf }, pc, elements);
     } else if (use_src2) {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, src2_buf, dst_buf }, pc, elements);
     } else if (use_src1) {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, src1_buf, dst_buf }, pc, elements);
     } else {
-        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src0_buf, dst_buf }, pc, elements);
     }
 }
 