@@ -1201,6 +1201,14 @@ struct vk_staging_memcpy {
12011201 size_t n;
12021202};
12031203
1204+ struct vk_staging_memset {
1205+ vk_staging_memset(void * _dst, uint32_t _val, size_t _n) : dst(_dst), val(_val), n(_n) {}
1206+
1207+ void * dst;
1208+ uint32_t val;
1209+ size_t n;
1210+ };
1211+
12041212struct vk_context_struct {
12051213 vk_submission * s;
12061214 std::vector<vk_sequence> seqs;
@@ -1209,6 +1217,7 @@ struct vk_context_struct {
12091217
12101218 std::vector<vk_staging_memcpy> in_memcpys;
12111219 std::vector<vk_staging_memcpy> out_memcpys;
1220+ std::vector<vk_staging_memset> memsets;
12121221
12131222 vk_command_pool * p {};
12141223};
@@ -1600,7 +1609,9 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
16001609 }
16011610
16021611 vk::ComputePipelineCreateInfo compute_pipeline_create_info(
1603- vk::PipelineCreateFlags{},
1612+ device->pipeline_executable_properties_support ?
1613+ vk::PipelineCreateFlagBits::eCaptureStatisticsKHR :
1614+ vk::PipelineCreateFlags{},
16041615 pipeline_shader_create_info,
16051616 pipeline->layout);
16061617
@@ -3396,7 +3407,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
33963407 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
33973408 ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
33983409
3399- CREATE_UNARY(exp)
34003410 CREATE_UNARY(gelu)
34013411 CREATE_UNARY(gelu_erf)
34023412 CREATE_UNARY(gelu_quick)
@@ -3408,6 +3418,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
34083418 CREATE_UNARY(hardswish)
34093419#undef CREATE_UNARY
34103420
3421+ #define CREATE_UNARY_RTE(name) \
3422+ if (device->float_controls_rte_fp16) { \
3423+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
3424+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
3425+ } else { \
3426+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
3427+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
3428+ }
3429+ CREATE_UNARY_RTE(exp)
3430+ #undef CREATE_UNARY_RTE
3431+
34113432#define CREATE_GLU(name) \
34123433 if (device->float_controls_rte_fp16) { \
34133434 ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
@@ -5224,6 +5245,14 @@ static void deferred_memcpy(void * dst, const void * src, size_t size, std::vect
52245245 }
52255246}
52265247
5248+ static void deferred_memset(void * dst, uint32_t val, size_t size, std::vector<vk_staging_memset>* memsets = nullptr) {
5249+ if (memsets == nullptr) {
5250+ memset(dst, val, size);
5251+ } else {
5252+ memsets->emplace_back(dst, val, size);
5253+ }
5254+ }
5255+
52275256static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) {
52285257 if (device->sync_staging == nullptr || device->sync_staging->size < size) {
52295258 VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")");
@@ -5419,6 +5448,10 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
54195448 memcpy(cpy.dst, cpy.src, cpy.n);
54205449 }
54215450
5451+ for (auto& mset : subctx->memsets) {
5452+ memset(mset.dst, mset.val, mset.n);
5453+ }
5454+
54225455 ggml_vk_submit(subctx, dst->device->fence);
54235456 VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences");
54245457 dst->device->device.resetFences({ dst->device->fence });
@@ -5558,12 +5591,25 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
55585591static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
55595592 VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")");
55605593
5594+ if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
5595+ dst->device->uma) {
5596+ deferred_memset((uint8_t*)dst->ptr + offset, c, size, &ctx->memsets);
5597+ return;
5598+ }
5599+
5600+ // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers
55615601 ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
55625602}
55635603
55645604static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
55655605 VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")");
55665606
5607+ if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible &&
5608+ dst->device->uma) {
5609+ memset((uint8_t*)dst->ptr + offset, c, size);
5610+ return;
5611+ }
5612+
55675613 std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
55685614 vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
55695615 ggml_vk_ctx_begin(dst->device, subctx);
@@ -11198,6 +11244,10 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1119811244 memcpy(cpy.dst, cpy.src, cpy.n);
1119911245 }
1120011246
11247+ for (auto& mset : subctx->memsets) {
11248+ memset(mset.dst, mset.val, mset.n);
11249+ }
11250+
1120111251 if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) {
1120211252 ggml_vk_submit(subctx, ctx->almost_ready_fence);
1120311253 ctx->almost_ready_fence_pending = true;
@@ -11220,6 +11270,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1122011270 }
1122111271 subctx->in_memcpys.clear();
1122211272 subctx->out_memcpys.clear();
11273+ subctx->memsets.clear();
1122311274 }
1122411275
1122511276 return true;
0 commit comments