Skip to content

Commit d8c0a7b

Browse files
authored
vulkan: Fix mismatch in TOPK_MOE unit test (ggml-org#17541)
* Fix shader to support 2D workgroup mapping to a single subgroup * Set required_subgroup_size topk_moe shader requires static WARP_SIZE and actual subgroup size to match
1 parent 933414c commit d8c0a7b

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4174,9 +4174,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
41744174
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
41754175

41764176
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
4177-
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
4178-
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
4179-
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
4177+
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true, device->subgroup_size);
4178+
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true, device->subgroup_size);
4179+
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true, device->subgroup_size);
41804180
}
41814181

41824182
for (auto &c : compiles) {

ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,25 +75,26 @@ void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit
7575
}
7676

7777
void main() {
78-
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
78+
const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID;
7979
if (row >= n_rows) {
8080
return;
8181
}
8282

8383
const uint logits_offset = n_experts * row;
8484
const uint weights_offset = n_expert_used * row;
8585
const uint ids_offset = n_experts * row;
86+
const uint lane = gl_SubgroupInvocationID;
8687

8788
float wt[experts_per_thread];
8889

8990
[[unroll]]
9091
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
91-
const uint expert = i + gl_LocalInvocationID.x;
92+
const uint expert = i + lane;
9293
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
9394
}
9495

9596
if (!late_softmax) {
96-
softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
97+
softmax_warp_inplace(wt, n_experts, lane, false);
9798
}
9899

99100
// at this point, each thread holds a portion of softmax,
@@ -111,11 +112,11 @@ void main() {
111112

112113
for (int k = 0; k < n_expert_used; k++) {
113114
float max_val = wt[0];
114-
uint max_expert = gl_LocalInvocationID.x;
115+
uint max_expert = lane;
115116

116117
[[unroll]]
117118
for (int i = 1; i < experts_per_thread; i++) {
118-
const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
119+
const uint expert = lane + i * WARP_SIZE;
119120
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
120121
max_val = wt[i];
121122
max_expert = expert;
@@ -132,11 +133,11 @@ void main() {
132133
}
133134
}
134135

135-
if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
136+
if ((k & (WARP_SIZE - 1)) == lane) {
136137
output_weights[k / WARP_SIZE] = max_val;
137138
}
138139

139-
if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
140+
if ((max_expert & (WARP_SIZE - 1)) == lane) {
140141
wt[max_expert / WARP_SIZE] = -INFINITY;
141142

142143
ids[ids_offset + k] = max_expert;
@@ -158,12 +159,12 @@ void main() {
158159
}
159160

160161
if (late_softmax) {
161-
softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
162+
softmax_warp_inplace(output_weights, n_expert_used, lane, true);
162163
}
163164

164165
[[unroll]]
165166
for (uint i = 0; i < experts_per_thread; ++i) {
166-
uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
167+
uint idx = i * WARP_SIZE + lane;
167168
if (idx < n_expert_used) {
168169
weights[weights_offset + idx] = output_weights[i];
169170
}

0 commit comments

Comments
 (0)