@@ -1355,6 +1355,7 @@ struct vk_instance_t {
13551355 PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {};
13561356
13571357 std::vector<size_t> device_indices;
1358+ std::vector<bool> device_supports_membudget;
13581359 vk_device devices[GGML_VK_MAX_DEVICES];
13591360};
13601361
@@ -4202,15 +4203,16 @@ static void ggml_vk_instance_init() {
42024203 vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
42034204 vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
42044205 vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
4205-
42064206 }
42074207
42084208 vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
42094209
4210+ std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
4211+
42104212 // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
42114213 char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES");
42124214 if (devices_env != nullptr) {
4213- size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices() .size();
4215+ size_t num_available_devices = devices .size();
42144216
42154217 std::string devices(devices_env);
42164218 std::replace(devices.begin(), devices.end(), ',', ' ');
@@ -4225,8 +4227,6 @@ static void ggml_vk_instance_init() {
42254227 vk_instance.device_indices.push_back(tmp);
42264228 }
42274229 } else {
4228- std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
4229-
42304230 // If no vulkan devices are found, return early
42314231 if (devices.empty()) {
42324232 GGML_LOG_INFO("ggml_vulkan: No devices found.\n");
@@ -4331,6 +4331,19 @@ static void ggml_vk_instance_init() {
43314331 GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size());
43324332
43334333 for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
4334+ vk::PhysicalDevice vkdev = devices[vk_instance.device_indices[i]];
4335+ std::vector<vk::ExtensionProperties> extensionprops = vkdev.enumerateDeviceExtensionProperties();
4336+
4337+ bool membudget_supported = false;
4338+ for (const auto & ext : extensionprops) {
4339+ if (std::string(ext.extensionName.data()) == VK_EXT_MEMORY_BUDGET_EXTENSION_NAME) {
4340+ membudget_supported = true;
4341+ break;
4342+ }
4343+ }
4344+
4345+ vk_instance.device_supports_membudget.push_back(membudget_supported);
4346+
43344347 ggml_vk_print_gpu_info(i);
43354348 }
43364349}
@@ -11441,23 +11454,16 @@ void ggml_backend_vk_get_device_description(int device, char * description, size
1144111454
1144211455void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
1144311456 GGML_ASSERT(device < (int) vk_instance.device_indices.size());
11457+ GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
1144411458
1144511459 vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
1144611460 vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
11461+ bool membudget_supported = vk_instance.device_supports_membudget[device];
1144711462
11448- std::vector<vk::ExtensionProperties> extensionprops = vkdev.enumerateDeviceExtensionProperties();
1144911463 vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
1145011464 vk::PhysicalDeviceMemoryProperties2 memprops2 = {};
11451- bool membudget_extension_supported = false;
11452-
11453- for (const auto & ext : extensionprops) {
11454- if (std::string(ext.extensionName.data()) == VK_EXT_MEMORY_BUDGET_EXTENSION_NAME) {
11455- membudget_extension_supported = true;
11456- break;
11457- }
11458- }
1145911465
11460- if (membudget_extension_supported ) {
11466+ if (membudget_supported ) {
1146111467 memprops2.pNext = &budgetprops;
1146211468 vkdev.getMemoryProperties2(&memprops2);
1146311469 }
@@ -11468,7 +11474,7 @@ void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total
1146811474 if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
1146911475 *total = heap.size;
1147011476
11471- if (membudget_extension_supported && i < budgetprops.heapUsage.size()) {
11477+ if (membudget_supported && i < budgetprops.heapUsage.size()) {
1147211478 *free = *total - budgetprops.heapUsage[i];
1147311479 } else {
1147411480 *free = heap.size;
0 commit comments