diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 1f1136382e360..12ff864a26968 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -1185,6 +1185,14 @@ struct vk_staging_memcpy {
     size_t n;
 };
 
+struct vk_staging_memset {
+    vk_staging_memset(void * _dst, uint32_t _val, size_t _n) : dst(_dst), val(_val), n(_n) {}
+
+    void * dst;
+    uint32_t val;
+    size_t n;
+};
+
 struct vk_context_struct {
     vk_submission * s;
     std::vector<vk_sequence> seqs;
@@ -1193,6 +1201,7 @@ struct vk_context_struct {
 
     std::vector<vk_staging_memcpy> in_memcpys;
     std::vector<vk_staging_memcpy> out_memcpys;
+    std::vector<vk_staging_memset> memsets;
 
     vk_command_pool * p {};
 };
@@ -5194,6 +5203,14 @@ static void deferred_memcpy(void * dst, const void * src, size_t size, std::vect
     }
 }
 
+static void deferred_memset(void * dst, uint32_t val, size_t size, std::vector<vk_staging_memset>* memsets = nullptr) {
+    if (memsets == nullptr) {
+        memset(dst, val, size);
+    } else {
+        memsets->emplace_back(dst, val, size);
+    }
+}
+
 static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
     if (device->sync_staging == nullptr || device->sync_staging->size < size) {
         VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
@@ -5389,6 +5406,10 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
 
+        for (auto& mset : subctx->memsets) {
+            memset(mset.dst, mset.val, mset.n);
+        }
+
         ggml_vk_submit(subctx, dst->device->fence);
         VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
         dst->device->device.resetFences({ dst->device->fence });
@@ -5528,12 +5549,25 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
 static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
     VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
 
+    if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
+        dst->device->uma) {
+        deferred_memset((uint8_t*)dst->ptr + offset, c, size, &ctx->memsets);
+        return;
+    }
+
+    // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers
     ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
 }
 
 static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
     VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
 
+    if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
+        dst->device->uma) {
+        memset((uint8_t*)dst->ptr + offset, c, size);
+        return;
+    }
+
     std::lock_guard<std::mutex> guard(dst->device->mutex);
     vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
     ggml_vk_ctx_begin(dst->device, subctx);
@@ -11168,6 +11202,10 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
             memcpy(cpy.dst, cpy.src, cpy.n);
         }
+        for (auto& mset : subctx->memsets) {
+            memset(mset.dst, mset.val, mset.n);
+        }
+
         if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) {
             ggml_vk_submit(subctx, ctx->almost_ready_fence);
             ctx->almost_ready_fence_pending = true;
         }
@@ -11190,6 +11228,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
         }
         subctx->in_memcpys.clear();
         subctx->out_memcpys.clear();
+        subctx->memsets.clear();
     }
 
     return true;
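
For readers skimming the diff, here is a minimal standalone sketch (plain C++, compiled independently of ggml; the name staged_memset and the toy buffer in main() are made up for illustration) of the record-and-replay pattern the new deferred_memset() follows: writes to host-visible memory are queued while the command context is being recorded and flushed in one place, next to the existing in_memcpys replay, instead of being issued immediately.

// Minimal sketch of the record-and-replay pattern used by deferred_memset().
// Not ggml code: staged_memset and the toy buffer below are illustrative only.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

struct staged_memset {
    staged_memset(void * _dst, uint32_t _val, size_t _n) : dst(_dst), val(_val), n(_n) {}

    void *   dst;
    uint32_t val;
    size_t   n;
};

// Run the memset immediately when no staging list is supplied,
// otherwise record it so it can be replayed later in one batch.
static void deferred_memset(void * dst, uint32_t val, size_t n,
                            std::vector<staged_memset> * memsets = nullptr) {
    if (memsets == nullptr) {
        memset(dst, val, n);
    } else {
        memsets->emplace_back(dst, val, n);
    }
}

int main() {
    std::vector<uint8_t> buffer(16, 0xAA);
    std::vector<staged_memset> pending;

    // Recording phase: the operation is queued, the buffer is untouched.
    deferred_memset(buffer.data(), 0, buffer.size(), &pending);
    printf("before replay: 0x%02X\n", (unsigned) buffer[0]);   // prints 0xAA

    // Replay phase: flush every staged memset, then clear the list.
    for (auto & m : pending) {
        memset(m.dst, m.val, m.n);
    }
    pending.clear();
    printf("after replay:  0x%02X\n", (unsigned) buffer[0]);   // prints 0x00
    return 0;
}

In the diff itself, the flush loops sit right before ggml_vk_submit() in ggml_vk_buffer_write_2d() and in ggml_vk_compute_forward(), and subctx->memsets.clear() resets the list once the graph has been handled.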