Skip to content

Commit 140ef26

Browse files
committed
vulkan: request round-to-even for fp16 in im2col/rope_head
Vulkan doesn't mandate a specific rounding mode, but the shader_float_controls feature allows rounding mode to be requested if the implementation supports it.
1 parent 26a8406 commit 140ef26

File tree

4 files changed

+31
-5
lines changed

4 files changed

+31
-5
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ struct vk_device_struct {
168168
uint32_t subgroup_size;
169169
uint32_t shader_core_count;
170170
bool uma;
171+
bool float_controls_rte_fp16;
171172
bool coopmat2;
172173

173174
bool coopmat_support;
@@ -1922,17 +1923,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
19221923
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
19231924

19241925
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1925-
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1926-
19271926
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1928-
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1927+
1928+
if (device->float_controls_rte_fp16) {
1929+
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1930+
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1931+
} else {
1932+
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1933+
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1934+
}
19291935

19301936
ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
19311937

19321938
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
19331939

19341940
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
1935-
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
1941+
if (device->float_controls_rte_fp16) {
1942+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
1943+
} else {
1944+
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
1945+
}
19361946

19371947
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
19381948

@@ -2013,11 +2023,13 @@ static vk_device ggml_vk_get_device(size_t idx) {
20132023
vk::PhysicalDeviceDriverProperties driver_props;
20142024
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
20152025
vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2026+
vk::PhysicalDeviceVulkan12Properties vk12_props;
20162027
props2.pNext = &props3;
20172028
props3.pNext = &subgroup_props;
20182029
subgroup_props.pNext = &driver_props;
2030+
driver_props.pNext = &vk12_props;
20192031

2020-
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
2032+
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
20212033

20222034
if (maintenance4_support) {
20232035
last_struct->pNext = (VkBaseOutStructure *)&props4;
@@ -2063,6 +2075,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
20632075
} else {
20642076
device->shader_core_count = 0;
20652077
}
2078+
device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
20662079

20672080
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
20682081

ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
#version 450
22

33
#extension GL_EXT_shader_16bit_storage : require
4+
#extension GL_EXT_spirv_intrinsics: enable
5+
6+
#if RTE16
7+
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
8+
#endif
49

510
layout (push_constant) uniform parameter
611
{

ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
#include "types.comp"
22

33
#extension GL_EXT_shader_16bit_storage : require
4+
#extension GL_EXT_spirv_intrinsics: enable
5+
6+
#if RTE16
7+
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
8+
#endif
49

510
layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
611

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,16 +458,19 @@ void process_shaders() {
458458

459459
string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
460460
string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
461+
string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
461462

462463
string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
463464
string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
465+
string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}});
464466

465467
string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}});
466468

467469
string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
468470

469471
string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
470472
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
473+
string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
471474

472475
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
473476

0 commit comments

Comments
 (0)