Skip to content

Commit 21dbbe5

Browse files
committed
vulkan: Further soft_max optimizations
Restore the workgroup size of 512 case, use it for >1024. Use unrollable loops for more iteration counts.
1 parent c7b8ab7 commit 21dbbe5

File tree

2 files changed

+29
-26
lines changed

2 files changed

+29
-26
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ struct vk_device_struct {
218218
vk_pipeline pipeline_tanh_f32;
219219
vk_pipeline pipeline_diag_mask_inf_f32;
220220
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
221+
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
221222
vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
222223
vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
223224
vk_pipeline pipeline_argsort_f32;
@@ -1498,7 +1499,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
14981499
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
14991500

15001501
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1502+
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
15011503
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1504+
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
15021505

15031506
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
15041507
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -3933,10 +3936,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
39333936
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
39343937

39353938
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
3936-
return ctx->device->pipeline_soft_max_f32;
3939+
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
39373940
}
39383941
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
3939-
return ctx->device->pipeline_soft_max_f32_f16;
3942+
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
39403943
}
39413944
return nullptr;
39423945
case GGML_OP_ROPE:

ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -152,30 +152,30 @@ void main() {
152152
// instantiate the soft_max function for several different
153153
// dimensions, to allow loop unrolling
154154
uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE;
155-
switch (num_blocks) {
156-
case 1:
157-
soft_max(1);
158-
break;
159-
case 2:
160-
soft_max(2);
161-
break;
162-
case 3:
163-
soft_max(3);
164-
break;
165-
case 4:
166-
soft_max(4);
167-
break;
168-
case 5:
169-
case 6:
170-
case 7:
171-
case 8:
172-
soft_max(8);
173-
break;
174-
case 16:
175-
soft_max(16);
176-
break;
177-
default:
155+
if (num_blocks > 32) {
178156
soft_max(num_blocks);
179-
break;
157+
} else if (num_blocks > 16) {
158+
soft_max(32);
159+
} else if (num_blocks > 8) {
160+
soft_max(16);
161+
} else if (num_blocks > 4) {
162+
soft_max(8);
163+
} else {
164+
switch (num_blocks) {
165+
case 1:
166+
soft_max(1);
167+
break;
168+
case 2:
169+
soft_max(2);
170+
break;
171+
case 3:
172+
soft_max(3);
173+
break;
174+
case 4:
175+
soft_max(4);
176+
break;
177+
default:
178+
break;
179+
}
180180
}
181181
}

0 commit comments

Comments
 (0)