@@ -406,6 +406,8 @@ struct vk_device_struct {
406406 bool subgroup_ballot;
407407 bool subgroup_clustered;
408408 bool multi_add;
409+ bool shader_int64;
410+ bool buffer_device_address;
409411
410412 bool add_rms_fusion;
411413 uint32_t partials_binding_alignment;
@@ -653,6 +655,7 @@ struct vk_buffer_struct {
653655 vk::MemoryPropertyFlags memory_property_flags;
654656 void * ptr;
655657 size_t size = 0;
658+ vk::DeviceAddress bda_addr {};
656659
657660 vk_device device;
658661
@@ -985,6 +988,7 @@ struct vk_op_argsort_push_constants {
985988};
986989
987990struct vk_op_im2col_push_constants {
991+ uint64_t dst_addr;
988992 uint32_t batch_offset; uint32_t offset_delta;
989993 uint32_t IC;
990994 uint32_t IW; uint32_t IH;
@@ -998,6 +1002,7 @@ struct vk_op_im2col_push_constants {
9981002};
9991003
10001004struct vk_op_im2col_3d_push_constants {
1005+ uint64_t dst_addr;
10011006 uint32_t nb10;
10021007 uint32_t nb11;
10031008 uint32_t nb12;
@@ -2010,10 +2015,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
20102015 return buf;
20112016 }
20122017
2018+ vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
2019+ vk::MemoryAllocateFlags mem_flags {};
2020+ if (device->buffer_device_address) {
2021+ usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
2022+ mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
2023+ }
2024+
20132025 vk::BufferCreateInfo buffer_create_info{
20142026 vk::BufferCreateFlags(),
20152027 size,
2016- vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst ,
2028+ usage_flags ,
20172029 vk::SharingMode::eExclusive,
20182030 0,
20192031 nullptr,
@@ -2025,6 +2037,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
20252037
20262038 vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
20272039
2040+ const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
2041+
20282042 for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
20292043 const auto & req_flags = *it;
20302044
@@ -2036,7 +2050,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
20362050 buf->memory_property_flags = req_flags;
20372051
20382052 try {
2039- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
2053+ buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
20402054 break;
20412055 } catch (const vk::SystemError& e) {
20422056 // loop and retry
@@ -2064,6 +2078,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
20642078 buf->device = device;
20652079 buf->size = size;
20662080
2081+ if (device->buffer_device_address) {
2082+ const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
2083+ buf->bda_addr = device->device.getBufferAddress(addressInfo);
2084+ }
2085+
20672086#ifdef GGML_VULKAN_MEMORY_DEBUG
20682087 device->memory_logger->log_allocation(buf, size);
20692088#endif
@@ -3530,14 +3549,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
35303549
35313550 ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
35323551
3533- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3534- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3535- if (device->float_controls_rte_fp16) {
3536- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3537- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3552+ #define IM2COL(bda) \
3553+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3554+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3555+ if (device->float_controls_rte_fp16) { \
3556+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3557+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3558+ } else { \
3559+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3560+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3561+ }
3562+ if (device->shader_int64 && device->buffer_device_address) {
3563+ IM2COL(_bda)
35383564 } else {
3539- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3540- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3565+ IM2COL()
35413566 }
35423567
35433568 ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@@ -4015,6 +4040,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
40154040 device->vendor_id != VK_VENDOR_ID_INTEL &&
40164041 getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
40174042
4043+ device->shader_int64 = device_features2.features.shaderInt64;
4044+ device->buffer_device_address = vk12_features.bufferDeviceAddress;
4045+
40184046 if (device->subgroup_size_control) {
40194047 device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
40204048 device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
@@ -9443,7 +9471,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
94439471
94449472 const uint32_t pelements = OW * KW * KH;
94459473
9474+ const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
9475+ const vk_buffer d_buf = d_buf_ctx->dev_buffer;
9476+
9477+ const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
9478+
94469479 ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
9480+ dst_addr,
94479481 batch_offset, offset_delta,
94489482 IC, IW, IH, OW, OH, KW, KH,
94499483 pelements,
@@ -9479,8 +9513,14 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
94799513 const int64_t OH = ne2;
94809514 const int64_t OW = ne1;
94819515
9516+ const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
9517+ const vk_buffer d_buf = d_buf_ctx->dev_buffer;
9518+
9519+ const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
9520+
94829521 vk_op_im2col_3d_push_constants pc {};
94839522
9523+ pc.dst_addr = dst_addr;
94849524 pc.nb10 = nb10 / ggml_type_size(src1->type);
94859525 pc.nb11 = nb11 / ggml_type_size(src1->type);
94869526 pc.nb12 = nb12 / ggml_type_size(src1->type);
0 commit comments