Skip to content

Commit a30024b

Browse files
committed
vulkan: add optimization for SSM scan
Use subgroupAdd if available. Signed-off-by: Giuseppe Scrivano <[email protected]>
1 parent cd3dc24 commit a30024b

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3607,8 +3607,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
36073607

36083608
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
36093609

3610-
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
3611-
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
3610+
const uint32_t ssm_variant = device->subgroup_arithmetic ? 1 : 0;
3611+
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", arr_ssm_scan_f32_len[ssm_variant], arr_ssm_scan_f32_data[ssm_variant], "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
3612+
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", arr_ssm_scan_f32_len[ssm_variant], arr_ssm_scan_f32_data[ssm_variant], "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
36123613

36133614
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
36143615

ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
#version 450
22

33
#extension GL_EXT_control_flow_attributes : require
4+
#if USE_SUBGROUP_ADD
5+
#extension GL_KHR_shader_subgroup_basic : require
6+
#extension GL_KHR_shader_subgroup_arithmetic : require
7+
#endif
48

59
#include "types.glsl"
610

@@ -97,10 +101,18 @@ void main() {
97101
}
98102

99103
[[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
104+
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
100105
const uint idx = (tid % SUBGROUP_SIZE) +
101106
D_STATE * (tid / SUBGROUP_SIZE) +
102107
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
103-
108+
// Guard must use the same macro as the extension block at the top of this
// shader and the define passed for the "ssm_scan_f32_subgroup" variant
// (USE_SUBGROUP_ADD); a misspelled name is never defined, so the
// subgroupAdd() path would never be compiled in.
#if USE_SUBGROUP_ADD
109+
if (idx < SPLIT_H * D_STATE) {
110+
float sum = subgroupAdd(stateC[idx]);
111+
if (tid % SUBGROUP_SIZE == 0) {
112+
d[y_base_idx + i * stride_y + k] = sum;
113+
}
114+
}
115+
#else
104116
uint lane = tid % SUBGROUP_SIZE;
105117

106118
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
@@ -111,9 +123,9 @@ void main() {
111123
}
112124

113125
if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
114-
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
115126
d[y_base_idx + i * stride_y + k] = stateC[idx];
116127
}
128+
#endif
117129
}
118130

119131
barrier();

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -917,6 +917,7 @@ void process_shaders() {
917917
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
918918

919919
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
920+
string_to_spv("ssm_scan_f32_subgroup", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
920921

921922
string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
922923

@@ -1043,6 +1044,13 @@ void write_output_files() {
10431044
}
10441045
}
10451046

1047+
hdr << "extern const void * arr_ssm_scan_f32_data[2];\n";
1048+
hdr << "extern const uint64_t arr_ssm_scan_f32_len[2];\n";
1049+
if (basename(input_filepath) == "ssm_scan.comp") {
1050+
src << "const void * arr_ssm_scan_f32_data[2] = {ssm_scan_f32_data, ssm_scan_f32_subgroup_data};\n";
1051+
src << "const uint64_t arr_ssm_scan_f32_len[2] = {ssm_scan_f32_len, ssm_scan_f32_subgroup_len};\n";
1052+
}
1053+
10461054
if (input_filepath == "") {
10471055
write_file_if_changed(target_hpp, hdr.str());
10481056
}

0 commit comments

Comments
 (0)