diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 493ee9c9a44..0af127fabb0 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4353,6 +4353,7 @@ static vk_device ggml_vk_get_device(size_t idx) { device->coopmat_support = false; device->integer_dot_product = false; bool bfloat16_support = false; + bool buffer_device_address_khr = false; for (const auto& properties : ext_props) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { @@ -4397,6 +4398,8 @@ static vk_device ggml_vk_get_device(size_t idx) { } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 && getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) { device->memory_priority = true; + } else if (strcmp("VK_KHR_buffer_device_address", properties.extensionName) == 0) { + buffer_device_address_khr = true; } } @@ -4791,7 +4794,9 @@ static vk_device ggml_vk_get_device(size_t idx) { throw std::runtime_error("Unsupported device"); } - device_extensions.push_back("VK_KHR_16bit_storage"); + if (fp16_storage) { + device_extensions.push_back("VK_KHR_16bit_storage"); + } #ifdef GGML_VULKAN_VALIDATE device_extensions.push_back("VK_KHR_shader_non_semantic_info"); @@ -4801,6 +4806,11 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_shader_float16_int8"); } + // Required for physical devices that only support Vulkan 1.1 + if (device->buffer_device_address && buffer_device_address_khr) { + device_extensions.push_back("VK_KHR_buffer_device_address"); + } + #if defined(VK_KHR_cooperative_matrix) if (device->coopmat_support) { // Query supported shapes @@ -4924,6 +4934,10 @@ static vk_device ggml_vk_get_device(size_t idx) { device_create_info.setPNext(&device_features2); device->device = device->physical_device.createDevice(device_create_info); + // optionally initialize the dispatcher with a vk::Device to get + // device-specific function pointers (needed on Android) + VULKAN_HPP_DEFAULT_DISPATCHER.init(device->device); + // Queues ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);