@@ -274,6 +274,10 @@ struct vk_device_struct {
274274
275275 ggml_backend_buffer_type buffer_type;
276276
277+ VkMemoryType memory_type[VK_MAX_MEMORY_TYPES] {};
278+ VkMemoryHeap memory_heap[VK_MAX_MEMORY_HEAPS] {};
279+ std::atomic<VkDeviceSize> memory_usage[VK_MAX_MEMORY_TYPES];
280+
277281#ifdef GGML_VULKAN_MEMORY_DEBUG
278282 std::unique_ptr<vk_memory_logger> memory_logger;
279283#endif
@@ -311,6 +315,7 @@ struct vk_buffer_struct {
311315 vk::Buffer buffer = VK_NULL_HANDLE;
312316 vk::DeviceMemory device_memory = VK_NULL_HANDLE;
313317 vk::MemoryPropertyFlags memory_property_flags;
318+ uint32_t memory_type_index {};
314319 void * ptr;
315320 size_t size = 0 ;
316321
@@ -322,6 +327,7 @@ struct vk_buffer_struct {
322327 }
323328 VK_LOG_DEBUG (" ~vk_buffer_struct(" << buffer << " , " << size << " )" );
324329
330+ device->memory_usage [memory_type_index] -= size;
325331 device->device .freeMemory (device_memory);
326332 device->device .destroyBuffer (buffer);
327333 }
@@ -1229,6 +1235,18 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor
12291235 uint32_t memory_type_index = UINT32_MAX;
12301236
12311237 memory_type_index = find_properties (&mem_props, &mem_req, req_flags);
1238+
1239+ // Avoid allocating "too much" host visible vidmem. Large HVV allocations may be contiguous
1240+ // and can fall back to sysmem due to fragmentation.
1241+ if (!device->uma &&
1242+ req_flags == (vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent)) {
1243+ uint32_t heap_index = device->memory_type [memory_type_index].heapIndex ;
1244+ if (device->memory_usage [memory_type_index] + size*2 >= device->memory_heap [heap_index].size ) {
1245+ req_flags = fallback_flags;
1246+ memory_type_index = find_properties (&mem_props, &mem_req, req_flags);
1247+ }
1248+ }
1249+
12321250 buf->memory_property_flags = req_flags;
12331251
12341252 if (memory_type_index == UINT32_MAX && fallback_flags) {
@@ -1262,6 +1280,9 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor
12621280 throw e;
12631281 }
12641282 }
1283+ device->memory_usage [memory_type_index] += mem_req.size ;
1284+
1285+ buf->memory_type_index = memory_type_index;
12651286 buf->ptr = nullptr ;
12661287
12671288 if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
@@ -2186,6 +2207,16 @@ static vk_device ggml_vk_get_device(size_t idx) {
21862207 }
21872208
21882209 device->physical_device = physical_devices[dev_num];
2210+
2211+ vk::PhysicalDeviceMemoryProperties memprops = device->physical_device .getMemoryProperties ();
2212+ for (uint32_t i = 0 ; i < VK_MAX_MEMORY_HEAPS; ++i) {
2213+ device->memory_heap [i] = memprops.memoryHeaps [i];
2214+ }
2215+ for (uint32_t i = 0 ; i < VK_MAX_MEMORY_TYPES; ++i) {
2216+ device->memory_type [i] = memprops.memoryTypes [i];
2217+ device->memory_usage [i] = 0 ;
2218+ }
2219+
21892220 const std::vector<vk::ExtensionProperties> ext_props = device->physical_device .enumerateDeviceExtensionProperties ();
21902221
21912222 bool fp16_storage = false ;
@@ -3231,6 +3262,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
32313262 if (!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
32323263 fprintf (stderr, " WARNING: failed to allocate %.2f MB of pinned memory\n " ,
32333264 size/1024.0 /1024.0 );
3265+ device->memory_usage [buf->memory_type_index ] -= buf->size ;
32343266 device->device .freeMemory (buf->device_memory );
32353267 device->device .destroyBuffer (buf->buffer );
32363268 return nullptr ;
0 commit comments