@@ -168,6 +168,7 @@ struct vk_device_struct {
168168 uint32_t subgroup_size;
169169 uint32_t shader_core_count;
170170 bool uma;
171+ bool float_controls_rte_fp16;
171172 bool coopmat2;
172173
173174 bool coopmat_support;
@@ -1922,17 +1923,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
19221923 ggml_vk_create_pipeline (device, device->pipeline_soft_max_f32_f16_wg512 , " soft_max_f32_f16_wg512" , soft_max_f32_f16_len, soft_max_f32_f16_data, " main" , 3 , sizeof (vk_op_soft_max_push_constants), {1 , 1 , 1 }, { 512 }, 1 );
19231924
19241925 ggml_vk_create_pipeline (device, device->pipeline_rope_norm_f32 , " rope_norm_f32" , rope_norm_f32_len, rope_norm_f32_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1925- ggml_vk_create_pipeline (device, device->pipeline_rope_norm_f16 , " rope_norm_f16" , rope_norm_f16_len, rope_norm_f16_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1926-
19271926 ggml_vk_create_pipeline (device, device->pipeline_rope_neox_f32 , " rope_neox_f32" , rope_neox_f32_len, rope_neox_f32_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1928- ggml_vk_create_pipeline (device, device->pipeline_rope_neox_f16 , " rope_neox_f16" , rope_neox_f16_len, rope_neox_f16_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1927+
1928+ if (device->float_controls_rte_fp16 ) {
1929+ ggml_vk_create_pipeline (device, device->pipeline_rope_norm_f16 , " rope_norm_f16" , rope_norm_f16_rte_len, rope_norm_f16_rte_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1930+ ggml_vk_create_pipeline (device, device->pipeline_rope_neox_f16 , " rope_neox_f16" , rope_neox_f16_rte_len, rope_neox_f16_rte_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1931+ } else {
1932+ ggml_vk_create_pipeline (device, device->pipeline_rope_norm_f16 , " rope_norm_f16" , rope_norm_f16_len, rope_norm_f16_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1933+ ggml_vk_create_pipeline (device, device->pipeline_rope_neox_f16 , " rope_neox_f16" , rope_neox_f16_len, rope_neox_f16_data, " main" , 4 , sizeof (vk_op_rope_push_constants), {1 , 512 , 1 }, {}, 1 );
1934+ }
19291935
19301936 ggml_vk_create_pipeline (device, device->pipeline_argsort_f32 , " argsort_f32" , argsort_f32_len, argsort_f32_data, " main" , 2 , sizeof (vk_op_argsort_push_constants), {1024 , 1 , 1 }, {}, 1 );
19311937
19321938 ggml_vk_create_pipeline (device, device->pipeline_sum_rows_f32 , " sum_rows_f32" , sum_rows_f32_len, sum_rows_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, { device->subgroup_size }, 1 );
19331939
19341940 ggml_vk_create_pipeline (device, device->pipeline_im2col_f32 , " im2col_f32" , im2col_f32_len, im2col_f32_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {256 , 1 , 1 }, {}, 1 );
1935- ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_len, im2col_f32_f16_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {256 , 1 , 1 }, {}, 1 );
1941+ if (device->float_controls_rte_fp16 ) {
1942+ ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {256 , 1 , 1 }, {}, 1 );
1943+ } else {
1944+ ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_len, im2col_f32_f16_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {256 , 1 , 1 }, {}, 1 );
1945+ }
19361946
19371947 ggml_vk_create_pipeline (device, device->pipeline_timestep_embedding_f32 , " timestep_embedding_f32" , timestep_embedding_f32_len, timestep_embedding_f32_data, " main" , 2 , sizeof (vk_op_timestep_embedding_push_constants), {256 , 1 , 1 }, {}, 1 );
19381948
@@ -2013,11 +2023,13 @@ static vk_device ggml_vk_get_device(size_t idx) {
20132023 vk::PhysicalDeviceDriverProperties driver_props;
20142024 vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
20152025 vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2026+ vk::PhysicalDeviceVulkan12Properties vk12_props;
20162027 props2.pNext = &props3;
20172028 props3.pNext = &subgroup_props;
20182029 subgroup_props.pNext = &driver_props;
2030+ driver_props.pNext = &vk12_props;
20192031
2020- VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props ;
2032+ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props ;
20212033
20222034 if (maintenance4_support) {
20232035 last_struct->pNext = (VkBaseOutStructure *)&props4;
@@ -2063,6 +2075,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
20632075 } else {
20642076 device->shader_core_count = 0 ;
20652077 }
2078+ device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16 ;
20662079
20672080 const bool force_disable_f16 = getenv (" GGML_VK_DISABLE_F16" ) != nullptr ;
20682081
0 commit comments