@@ -226,6 +226,9 @@ struct vk_device_struct {
226226 vk_pipeline pipeline_simpler_mul_mat_q6_k;
227227 vk_pipeline pipeline_simpler_mul_mat_q8_0;
228228
229+ vk_pipeline pipeline_simpler_soft_max_f16;
230+ vk_pipeline pipeline_simpler_soft_max_f32;
231+
229232 vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
230233 vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
231234 vk_pipeline pipeline_acc_f32;
@@ -516,6 +519,18 @@ struct vk_op_soft_max_push_constants {
516519 uint32_t nrows_x;
517520};
518521
522+ struct vk_op_simpler_soft_max_push_constants { // Push-constant payload ggml_vk_soft_max sends when device->simpler_shaders is set; its sizeof is the push-constant size given to the simpler_soft_max_f16/f32 pipelines. NOTE(review): field order and widths presumably must match the shader's push_constant block - confirm against the shader source.
523+ int32_t ne00; // filled from src0->ne[0] (cast to int32_t by the caller)
524+ int32_t ne01; // filled from src0->ne[1]
525+ int32_t ne02; // filled from src0->ne[2]
526+ float scale; // the caller's `scale` value, passed through unchanged
527+ float max_bias; // the caller's `max_bias` value, passed through unchanged
528+ float m0; // caller computes powf(2.0f, -max_bias / n_head_log2)
529+ float m1; // caller computes powf(2.0f, -(max_bias / 2.0f) / n_head_log2)
530+ uint32_t n_head_log2; // the caller's `n_head_log2` value; also the divisor used for m0/m1
531+ int32_t mask; // 1 when a src1 tensor is supplied, 0 when src1 is null
532+ };
533+
519534struct vk_op_argsort_push_constants {
520535 uint32_t ncols;
521536 uint32_t ncols_pad;
@@ -1983,6 +1998,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
19831998 ggml_vk_create_pipeline (device, device->pipeline_simpler_mul_mat_q6_k , " simpler_mul_mat_q6_k" , simpler_mul_mat_q6_k_len, simpler_mul_mat_q6_k_data, " main" , 3 , 18 * sizeof (uint32_t ), {1 , 1 , 1 }, {2 , device->subgroup_size }, 1 );
19841999 ggml_vk_create_pipeline (device, device->pipeline_simpler_mul_mat_q8_0 , " simpler_mul_mat_q8_0" , simpler_mul_mat_q8_0_len, simpler_mul_mat_q8_0_data, " main" , 3 , 18 * sizeof (uint32_t ), {1 , 1 , 1 }, {(device->subgroup_size * 2 ) / 8 }, 1 );
19852000
2001+ ggml_vk_create_pipeline (device, device->pipeline_simpler_soft_max_f16 , " simpler_soft_max_f16" , simpler_soft_max_f16_len, simpler_soft_max_f16_data, " main" , 3 , sizeof (vk_op_simpler_soft_max_push_constants), {1 , 1 , 1 }, {}, 1 );
2002+ ggml_vk_create_pipeline (device, device->pipeline_simpler_soft_max_f32 , " simpler_soft_max_f32" , simpler_soft_max_f32_len, simpler_soft_max_f32_data, " main" , 3 , sizeof (vk_op_simpler_soft_max_push_constants), {1 , 1 , 1 }, {}, 1 );
2003+
19862004 ggml_vk_create_pipeline (device, device->pipeline_norm_f32 , " norm_f32" , norm_f32_len, norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
19872005 ggml_vk_create_pipeline (device, device->pipeline_group_norm_f32 , " group_norm_f32" , group_norm_f32_len, group_norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
19882006 ggml_vk_create_pipeline (device, device->pipeline_rms_norm_f32 , " rms_norm_f32" , rms_norm_f32_len, rms_norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
@@ -5286,6 +5304,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
52865304 case GGML_OP_SOFT_MAX:
52875305 GGML_ASSERT (!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
52885306
5307+ if (ctx->device ->simpler_shaders ) {
5308+ if (src1 && src1->type == GGML_TYPE_F16) {
5309+ return ctx->device ->pipeline_simpler_soft_max_f16 ;
5310+ }
5311+ return ctx->device ->pipeline_simpler_soft_max_f32 ;
5312+ }
5313+
52895314 if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
52905315 return src0->ne [0 ] > 1024 ? ctx->device ->pipeline_soft_max_f32_wg512 : ctx->device ->pipeline_soft_max_f32 ;
52915316 }
@@ -5584,9 +5609,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
55845609 GGML_ASSERT (op_supports_incontiguous || (ggml_is_contiguous (src0) && (src1 == nullptr || ggml_is_contiguous (src1))));
55855610
55865611 switch (op) {
5612+ case GGML_OP_SOFT_MAX:
5613+ if (ctx->device ->simpler_shaders ) {
5614+ elements = { (uint32_t )src0->ne [1 ], (uint32_t )src0->ne [2 ], (uint32_t )src0->ne [3 ] };
5615+ break ;
5616+ }
5617+ // fall-through
55875618 case GGML_OP_NORM:
55885619 case GGML_OP_RMS_NORM:
5589- case GGML_OP_SOFT_MAX:
55905620 case GGML_OP_SUM_ROWS:
55915621 {
55925622 const uint32_t nr = ggml_nrows (src0);
@@ -6127,14 +6157,26 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
61276157 const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
61286158 const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
61296159
6130- ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr , dst, GGML_OP_SOFT_MAX, {
6131- ncols,
6132- src1 != nullptr ? nrows_y : (uint32_t )0 ,
6133- scale, max_bias,
6134- m0, m1,
6135- n_head_log2,
6136- nrows_x,
6137- }, dryrun);
6160+ if (ctx->device ->simpler_shaders ) {
6161+ ggml_vk_op_f32<vk_op_simpler_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr , dst, GGML_OP_SOFT_MAX, {
6162+ (int32_t ) src0->ne [0 ],
6163+ (int32_t ) src0->ne [1 ],
6164+ (int32_t ) src0->ne [2 ],
6165+ scale, max_bias,
6166+ m0, m1,
6167+ n_head_log2,
6168+ src1 == nullptr ? 0 : 1 ,
6169+ }, dryrun);
6170+ } else {
6171+ ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr , dst, GGML_OP_SOFT_MAX, {
6172+ ncols,
6173+ src1 != nullptr ? nrows_y : (uint32_t )0 ,
6174+ scale, max_bias,
6175+ m0, m1,
6176+ n_head_log2,
6177+ nrows_x,
6178+ }, dryrun);
6179+ }
61386180}
61396181
61406182static void ggml_vk_rope (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false ) {
0 commit comments