Skip to content

Commit 6784772

Browse files
committed
fix: detect and cache memory budget extension availability on init
1 parent ea5d796 commit 6784772

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,6 +1355,7 @@ struct vk_instance_t {
13551355
PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {};
13561356

13571357
std::vector<size_t> device_indices;
1358+
std::vector<bool> device_supports_membudget;
13581359
vk_device devices[GGML_VK_MAX_DEVICES];
13591360
};
13601361

@@ -4202,15 +4203,16 @@ static void ggml_vk_instance_init() {
42024203
vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT");
42034204
vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT");
42044205
vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT");
4205-
42064206
}
42074207

42084208
vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr;
42094209

4210+
std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
4211+
42104212
// Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan
42114213
char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES");
42124214
if (devices_env != nullptr) {
4213-
size_t num_available_devices = vk_instance.instance.enumeratePhysicalDevices().size();
4215+
size_t num_available_devices = devices.size();
42144216

42154217
std::string devices(devices_env);
42164218
std::replace(devices.begin(), devices.end(), ',', ' ');
@@ -4225,8 +4227,6 @@ static void ggml_vk_instance_init() {
42254227
vk_instance.device_indices.push_back(tmp);
42264228
}
42274229
} else {
4228-
std::vector<vk::PhysicalDevice> devices = vk_instance.instance.enumeratePhysicalDevices();
4229-
42304230
// If no vulkan devices are found, return early
42314231
if (devices.empty()) {
42324232
GGML_LOG_INFO("ggml_vulkan: No devices found.\n");
@@ -4331,6 +4331,19 @@ static void ggml_vk_instance_init() {
43314331
GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size());
43324332

43334333
for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
4334+
vk::PhysicalDevice vkdev = devices[vk_instance.device_indices[i]];
4335+
std::vector<vk::ExtensionProperties> extensionprops = vkdev.enumerateDeviceExtensionProperties();
4336+
4337+
bool membudget_supported = false;
4338+
for (const auto & ext : extensionprops) {
4339+
if (std::string(ext.extensionName.data()) == VK_EXT_MEMORY_BUDGET_EXTENSION_NAME) {
4340+
membudget_supported = true;
4341+
break;
4342+
}
4343+
}
4344+
4345+
vk_instance.device_supports_membudget.push_back(membudget_supported);
4346+
43344347
ggml_vk_print_gpu_info(i);
43354348
}
43364349
}
@@ -11441,23 +11454,16 @@ void ggml_backend_vk_get_device_description(int device, char * description, size
1144111454

1144211455
void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
1144311456
GGML_ASSERT(device < (int) vk_instance.device_indices.size());
11457+
GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size());
1144411458

1144511459
vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
1144611460
vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
11461+
bool membudget_supported = vk_instance.device_supports_membudget[device];
1144711462

11448-
std::vector<vk::ExtensionProperties> extensionprops = vkdev.enumerateDeviceExtensionProperties();
1144911463
vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops;
1145011464
vk::PhysicalDeviceMemoryProperties2 memprops2 = {};
11451-
bool membudget_extension_supported = false;
11452-
11453-
for (const auto & ext : extensionprops) {
11454-
if (std::string(ext.extensionName.data()) == VK_EXT_MEMORY_BUDGET_EXTENSION_NAME) {
11455-
membudget_extension_supported = true;
11456-
break;
11457-
}
11458-
}
1145911465

11460-
if (membudget_extension_supported) {
11466+
if (membudget_supported) {
1146111467
memprops2.pNext = &budgetprops;
1146211468
vkdev.getMemoryProperties2(&memprops2);
1146311469
}
@@ -11468,7 +11474,7 @@ void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total
1146811474
if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
1146911475
*total = heap.size;
1147011476

11471-
if (membudget_extension_supported && i < budgetprops.heapUsage.size()) {
11477+
if (membudget_supported && i < budgetprops.heapUsage.size()) {
1147211478
*free = *total - budgetprops.heapUsage[i];
1147311479
} else {
1147411480
*free = heap.size;

0 commit comments

Comments
 (0)