@@ -423,6 +423,7 @@ struct vk_device_struct {
     bool multi_add;
     bool shader_int64;
     bool buffer_device_address;
+    bool ext_external_memory_host;
 
     bool add_rms_fusion;
     uint32_t partials_binding_alignment;
@@ -680,6 +681,9 @@ struct vk_buffer_struct {
 
     vk_device device;
 
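+    // Set when the buffer wraps a host allocation imported via VK_EXT_external_memory_host;
+    // alignment_offset is the distance from the page-aligned base stored in ptr back to the
+    // caller's original pointer.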
+    bool from_host_ptr = false;
+    size_t alignment_offset = 0;
+
     ~vk_buffer_struct() {
         if (size == 0) {
             return;
@@ -1500,13 +1504,6 @@ struct ggml_backend_vk_context {
 
 static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
 
-static uint64_t vk_tensor_offset(const ggml_tensor * tensor) {
-    if (tensor->view_src) {
-        return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base;
-    }
-    return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
-}
-
 struct ggml_backend_vk_buffer_context {
     vk_device_ref device;
     vk_buffer dev_buffer;
@@ -1523,6 +1520,16 @@ struct ggml_backend_vk_buffer_context {
     }
 };
 
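+// Tensor offsets are computed against the buffer's base address: the real imported host pointer
+// for buffers created from a host allocation, or the vk_ptr_base sentinel for regular device buffers.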
+static uint64_t vk_tensor_offset(const ggml_tensor * tensor) {
+    ggml_backend_vk_buffer_context * buf_ctx = static_cast<ggml_backend_vk_buffer_context *>(tensor->buffer->context);
+    vk_buffer buf = buf_ctx->dev_buffer;
+
+    void * base_addr = buf->from_host_ptr ? buf->ptr : vk_ptr_base;
+    void * tensor_data = tensor->view_src ? tensor->view_src->data : tensor->data;
+
+    return (uint8_t *)tensor_data - (uint8_t *)base_addr;
+}
+
 #ifdef GGML_VULKAN_MEMORY_DEBUG
 static std::mutex log_mutex;
 
@@ -2180,6 +2187,76 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
     return buf;
 }
 
+static vk_buffer ggml_vk_create_buffer_from_host_ptr(vk_device& device, void * ptr, size_t size) {
+    if (!device->ext_external_memory_host) {
+        throw std::runtime_error("VK_EXT_external_memory_host extension not available");
+    }
+
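+    // The imported region must be page-aligned: round the pointer down to the mapping
+    // granularity reported by the device, remember the offset of the caller's pointer
+    // within that page, and pad the size up to a whole number of pages.
+    // Example: ptr = 0x1234 with page_size = 0x1000 gives aligned_ptr = 0x1000, offset = 0x234.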
+    const size_t page_size = device->physical_device.getProperties().limits.minMemoryMapAlignment;
+    uintptr_t ptr_addr = reinterpret_cast<uintptr_t>(ptr);
+    uintptr_t page_aligned_base = ptr_addr & ~(page_size - 1);
+    void* aligned_ptr = reinterpret_cast<void*>(page_aligned_base);
+    size_t offset = ptr_addr - page_aligned_base;
+    size_t aligned_size = (size + offset + page_size - 1) & ~(page_size - 1);
+
+    vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer |
+                                       vk::BufferUsageFlagBits::eTransferSrc |
+                                       vk::BufferUsageFlagBits::eShaderDeviceAddress;
+
+    vk_buffer buf = std::make_shared<vk_buffer_struct>();
+
+    vk::BufferCreateInfo buffer_create_info{{}, aligned_size, usage_flags, vk::SharingMode::eExclusive};
+    buf->buffer = device->device.createBuffer(buffer_create_info);
+
+    vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
+    vk::MemoryPropertyFlags req_flags = vk::MemoryPropertyFlagBits::eHostVisible |
+                                        vk::MemoryPropertyFlagBits::eHostCoherent;
+
+    vk::MemoryRequirements modified_req = mem_req;
+
+    vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
+    uint32_t memory_type_index = find_properties(&mem_props, &modified_req, req_flags);
+    if (memory_type_index == UINT32_MAX) {
+        device->device.destroyBuffer(buf->buffer);
+        throw vk::OutOfDeviceMemoryError("No compatible memory type found");
+    }
+
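+    // Import the caller's allocation instead of allocating new device memory by chaining
+    // VkImportMemoryHostPointerInfoEXT into the pNext of the memory allocation.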
+    VkImportMemoryHostPointerInfoEXT import_info = {
+        VK_STRUCTURE_TYPE_IMPORT_MEMORY_HOST_POINTER_INFO_EXT,
+        nullptr,
+        VK_EXTERNAL_MEMORY_HANDLE_TYPE_HOST_ALLOCATION_BIT_EXT,
+        aligned_ptr
+    };
+
+    VkMemoryAllocateInfo alloc_info = {
+        VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO,
+        &import_info,
+        aligned_size,
+        memory_type_index
+    };
+
+    buf->device_memory = device->device.allocateMemory(alloc_info);
+    device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
+
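+    // ptr keeps the page-aligned base that was actually imported; alignment_offset lets
+    // ggml_backend_vk_buffer_get_base() recover the caller's original pointer.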
+    buf->ptr = aligned_ptr;
+    buf->size = aligned_size;
+    buf->alignment_offset = offset;
+    buf->from_host_ptr = true;
+    buf->device = device;
+    buf->memory_property_flags = req_flags;
+
+    if (device->buffer_device_address) {
+        const vk::BufferDeviceAddressInfo addressInfo(buf->buffer);
+        buf->bda_addr = device->device.getBufferAddress(addressInfo);
+    }
+
+#ifdef GGML_VULKAN_MEMORY_DEBUG
+    device->memory_logger->log_allocation(buf, size);
+#endif
+
+    return buf;
+}
+
 static void ggml_vk_destroy_buffer(vk_buffer& buf) {
     if (buf == nullptr) {
         return;
@@ -3819,6 +3896,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
                 pipeline_robustness = true;
             } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
                 device->subgroup_size_control = true;
+            } else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) {
+                device->ext_external_memory_host = true;
 #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
             } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
                        !getenv("GGML_VK_DISABLE_COOPMAT")) {
@@ -4223,6 +4302,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device_extensions.push_back("VK_KHR_shader_float16_int8");
     }
 
+    if (device->ext_external_memory_host) {
+        device_extensions.push_back("VK_EXT_external_memory_host");
+    }
+
 #if defined(VK_KHR_cooperative_matrix)
     if (device->coopmat_support) {
         // Query supported shapes
@@ -11835,9 +11918,13 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
 }
 
 static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
-    return vk_ptr_base;
+    ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
+    vk_buffer buf = buf_ctx->dev_buffer;
 
-    UNUSED(buffer);
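+    // Imported host-pointer buffers expose their real address (restoring the sub-page offset);
+    // regular device buffers keep using the vk_ptr_base sentinel.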
+    if (buf->from_host_ptr) {
+        return (uint8_t*)buf->ptr + buf->alignment_offset;
+    }
+    return vk_ptr_base;
 }
 
 static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
@@ -12876,6 +12963,25 @@ static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(gg
     return ggml_backend_vk_host_buffer_type();
 }
 
+static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    VK_LOG_MEMORY("ggml_backend_vk_device_buffer_from_host_ptr(" << size << ")");
+
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    vk_device device = ggml_vk_get_device(ctx->device);
+
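+    // Importing host memory is only supported on UMA devices here, where the GPU can
+    // access system memory directly.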
+    if (!device->uma) {
+        GGML_ABORT("ggml_backend_vk_device_buffer_from_host_ptr works only with UMA devices");
+    }
+
+    vk_buffer dev_buffer = ggml_vk_create_buffer_from_host_ptr(device, ptr, size);
+
+    ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(dev_buffer), ctx->name);
+    ggml_backend_buffer_type_t buft = ggml_backend_vk_device_get_buffer_type(dev);
+
+    UNUSED(max_tensor_size);
+    return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size);
+}
+
 static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
 
@@ -12884,6 +12990,7 @@ static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_d
 
 static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
     ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    vk_device device = ggml_vk_get_device(ctx->device);
 
     props->name = ggml_backend_vk_device_get_name(dev);
     props->description = ggml_backend_vk_device_get_description(dev);
@@ -12893,7 +13000,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
     props->caps = {
         /* .async = */ false,
         /* .host_buffer = */ true,
-        /* .buffer_from_host_ptr = */ false,
+        /* .buffer_from_host_ptr = */ device->uma,
         /* .events = */ false,
     };
 }
@@ -13338,7 +13445,7 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
     /* .init_backend = */ ggml_backend_vk_device_init,
     /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
     /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
-    /* .buffer_from_host_ptr = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr,
     /* .supports_op = */ ggml_backend_vk_device_supports_op,
     /* .supports_buft = */ ggml_backend_vk_device_supports_buft,
     /* .offload_op = */ ggml_backend_vk_device_offload_op,