@@ -328,6 +328,7 @@ struct vk_device_struct {
328328 uint64_t max_memory_allocation_size;
329329 uint64_t suballocation_block_size;
330330 bool fp16;
331+ bool bf16;
331332 bool pipeline_robustness;
332333 vk::Device device;
333334 uint32_t vendor_id;
@@ -3273,6 +3274,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
32733274
32743275 device->fp16 = device->fp16 && vk12_features.shaderFloat16;
32753276
3277+ #if defined(VK_KHR_shader_bfloat16)
3278+ device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
3279+ #else
3280+ device->bf16 = false;
3281+ #endif
3282+
32763283 device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
32773284
32783285 if (device->subgroup_size_control) {
@@ -3615,6 +3622,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
36153622 bool coopmat_support = false;
36163623 bool coopmat2_support = false;
36173624 bool integer_dot_product = false;
3625+ bool bfloat16_support = false;
36183626
36193627 for (auto properties : ext_props) {
36203628 if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@@ -3635,6 +3643,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
36353643 } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
36363644 !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
36373645 integer_dot_product = true;
3646+ #endif
3647+ #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
3648+ } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
3649+ !getenv("GGML_VK_DISABLE_BFLOAT16")) {
3650+ bfloat16_support = true;
36383651#endif
36393652 }
36403653 }
@@ -3701,10 +3714,25 @@ static void ggml_vk_print_gpu_info(size_t idx) {
37013714 last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
37023715 }
37033716
3717+ #if defined(VK_KHR_shader_bfloat16)
3718+ VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
3719+ bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
3720+ if (bfloat16_support) {
3721+ last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
3722+ last_struct = (VkBaseOutStructure *)&bfloat16_features;
3723+ }
3724+ #endif
3725+
37043726 vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
37053727
37063728 fp16 = fp16 && vk12_features.shaderFloat16;
37073729
3730+ #if defined(VK_KHR_shader_bfloat16)
3731+ bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type;
3732+ #else
3733+ bool bf16 = false;
3734+ #endif
3735+
37083736 uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
37093737 const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
37103738 const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
@@ -3722,8 +3750,8 @@ static void ggml_vk_print_gpu_info(size_t idx) {
37223750 std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
37233751
37243752 std::string device_name = props2.properties.deviceName.data();
3725- GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
3726- idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
3753+ GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
3754+ idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size,
37273755 props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
37283756
37293757 if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
0 commit comments