@@ -408,6 +408,8 @@ struct vk_device_struct {
408408 bool subgroup_ballot;
409409 bool subgroup_clustered;
410410 bool multi_add;
411+ bool shader_int64;
412+ bool buffer_device_address;
411413
412414 bool add_rms_fusion;
413415 uint32_t partials_binding_alignment;
@@ -655,6 +657,7 @@ struct vk_buffer_struct {
655657 vk::MemoryPropertyFlags memory_property_flags;
656658 void * ptr;
657659 size_t size = 0;
660+ vk::DeviceAddress bda_addr {};
658661
659662 vk_device device;
660663
@@ -987,6 +990,7 @@ struct vk_op_argsort_push_constants {
987990};
988991
989992struct vk_op_im2col_push_constants {
993+ uint64_t dst_addr;
990994 uint32_t batch_offset; uint32_t offset_delta;
991995 uint32_t IC;
992996 uint32_t IW; uint32_t IH;
@@ -1000,6 +1004,7 @@ struct vk_op_im2col_push_constants {
10001004};
10011005
10021006struct vk_op_im2col_3d_push_constants {
1007+ uint64_t dst_addr;
10031008 uint32_t nb10;
10041009 uint32_t nb11;
10051010 uint32_t nb12;
@@ -2012,10 +2017,17 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
20122017 return buf;
20132018 }
20142019
2020+ vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst;
2021+ vk::MemoryAllocateFlags mem_flags {};
2022+ if (device->buffer_device_address) {
2023+ usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress;
2024+ mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress;
2025+ }
2026+
20152027 vk::BufferCreateInfo buffer_create_info{
20162028 vk::BufferCreateFlags(),
20172029 size,
2018- vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst ,
2030+ usage_flags ,
20192031 vk::SharingMode::eExclusive,
20202032 0,
20212033 nullptr,
@@ -2027,6 +2039,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
20272039
20282040 vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
20292041
2042+ const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags };
2043+
20302044 for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
20312045 const auto & req_flags = *it;
20322046
@@ -2038,7 +2052,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
20382052 buf->memory_property_flags = req_flags;
20392053
20402054 try {
2041- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
2055+ buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info });
20422056 break;
20432057 } catch (const vk::SystemError& e) {
20442058 // loop and retry
@@ -2066,6 +2080,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
20662080 buf->device = device;
20672081 buf->size = size;
20682082
2083+ if (device->buffer_device_address) {
2084+ const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
2085+ buf->bda_addr = device->device.getBufferAddress(addressInfo);
2086+ }
2087+
20692088#ifdef GGML_VULKAN_MEMORY_DEBUG
20702089 device->memory_logger->log_allocation(buf, size);
20712090#endif
@@ -3532,14 +3551,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
35323551
35333552 ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
35343553
3535- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3536- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3537- if (device->float_controls_rte_fp16) {
3538- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3539- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3554+ #define IM2COL(bda) \
3555+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3556+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3557+ if (device->float_controls_rte_fp16) { \
3558+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3559+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3560+ } else { \
3561+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
3562+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
3563+ }
3564+ if (device->shader_int64 && device->buffer_device_address) {
3565+ IM2COL(_bda)
35403566 } else {
3541- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
3542- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
3567+ IM2COL()
35433568 }
35443569
35453570 ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@@ -4017,6 +4042,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
40174042 device->vendor_id != VK_VENDOR_ID_INTEL &&
40184043 getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
40194044
4045+ device->shader_int64 = device_features2.features.shaderInt64;
4046+ device->buffer_device_address = vk12_features.bufferDeviceAddress;
4047+
40204048 if (device->subgroup_size_control) {
40214049 device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
40224050 device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
@@ -8635,6 +8663,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
86358663
86368664 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
86378665 } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
8666+ if (ctx->device->shader_int64 && ctx->device->buffer_device_address) {
8667+ // buffer device address path doesn't use dst buffer
8668+ d_sz = 1;
8669+ }
86388670 // im2col uses only src1 and dst buffers
86398671 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
86408672 } else if (op == GGML_OP_COUNT_EQUAL) {
@@ -9486,7 +9518,13 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
94869518
94879519 const uint32_t pelements = OW * KW * KH;
94889520
9521+ const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
9522+ const vk_buffer d_buf = d_buf_ctx->dev_buffer;
9523+
9524+ const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
9525+
94899526 ggml_vk_op_f32<vk_op_im2col_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, {
9527+ dst_addr,
94909528 batch_offset, offset_delta,
94919529 IC, IW, IH, OW, OH, KW, KH,
94929530 pelements,
@@ -9522,8 +9560,14 @@ static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx,
95229560 const int64_t OH = ne2;
95239561 const int64_t OW = ne1;
95249562
9563+ const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
9564+ const vk_buffer d_buf = d_buf_ctx->dev_buffer;
9565+
9566+ const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs;
9567+
95259568 vk_op_im2col_3d_push_constants pc {};
95269569
9570+ pc.dst_addr = dst_addr;
95279571 pc.nb10 = nb10 / ggml_type_size(src1->type);
95289572 pc.nb11 = nb11 / ggml_type_size(src1->type);
95299573 pc.nb12 = nb12 / ggml_type_size(src1->type);
0 commit comments