Skip to content

Commit 8423d01

Browse files
authored
vulkan: Optimize SSM_SCAN (ggml-org#16645)
1 parent 5cca254 commit 8423d01

File tree

3 files changed

+42
-21
lines changed

3 files changed

+42
-21
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3623,8 +3623,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
36233623

36243624
ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
36253625

3626-
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
3627-
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
3626+
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
3627+
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
3628+
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
3629+
} else {
3630+
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
3631+
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
3632+
}
36283633

36293634
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
36303635

ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#version 450
22

33
#extension GL_EXT_control_flow_attributes : require
4+
#if USE_SUBGROUP_ADD
5+
#extension GL_KHR_shader_subgroup_arithmetic : enable
6+
#endif
47

58
#include "types.glsl"
69

@@ -84,35 +87,47 @@ void main() {
8487
}
8588

8689
barrier();
87-
for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) {
88-
[[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
89-
const uint k = (tid % (w >> 1)) +
90-
(D_STATE * (tid / (w >> 1))) +
91-
j * D_STATE * (D_STATE / (w >> 1));
92-
if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) {
93-
stateC[k] += stateC[k + (w >> 1)];
90+
[[unroll]]
91+
for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) {
92+
[[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) {
93+
const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w);
94+
if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) {
95+
stateC[k] += stateC[k + w];
9496
}
9597
}
9698
barrier();
9799
}
98100

99-
[[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) {
101+
[[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) {
100102
const uint idx = (tid % SUBGROUP_SIZE) +
101103
D_STATE * (tid / SUBGROUP_SIZE) +
102104
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
105+
const uint max_idx = SUBGROUP_SIZE - 1 +
106+
D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) +
107+
j * D_STATE * (D_STATE / SUBGROUP_SIZE);
103108

104-
uint lane = tid % SUBGROUP_SIZE;
105-
106-
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
107-
if (idx + offset < SPLIT_H * D_STATE) {
108-
stateC[idx] += stateC[idx + offset];
109+
if (idx < SPLIT_H * D_STATE ||
110+
max_idx < SPLIT_H * D_STATE) {
111+
float sc;
112+
#if USE_SUBGROUP_ADD
113+
sc = stateC[idx];
114+
sc = subgroupAdd(sc);
115+
#else
116+
[[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) {
117+
if (idx + offset < SPLIT_H * D_STATE) {
118+
stateC[idx] += stateC[idx + offset];
119+
}
120+
barrier();
109121
}
110-
barrier();
111-
}
122+
if (tid % SUBGROUP_SIZE == 0) {
123+
sc = stateC[idx];
124+
}
125+
#endif
112126

113-
if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) {
114-
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
115-
d[y_base_idx + i * stride_y + k] = stateC[idx];
127+
if (tid % SUBGROUP_SIZE == 0) {
128+
const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE);
129+
d[y_base_idx + i * stride_y + k] = sc;
130+
}
116131
}
117132
}
118133

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -916,7 +916,8 @@ void process_shaders() {
916916
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
917917
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
918918

919-
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
919+
string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}});
920+
string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}});
920921

921922
string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
922923

0 commit comments

Comments
 (0)