@@ -228,6 +228,9 @@ struct vk_device_struct {
228228 vk_pipeline pipeline_simpler_mul_mat_q6_k;
229229 vk_pipeline pipeline_simpler_mul_mat_q8_0;
230230
231+ vk_pipeline pipeline_simpler_soft_max_f16;
232+ vk_pipeline pipeline_simpler_soft_max_f32;
233+
231234 vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
232235 vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
233236 vk_pipeline pipeline_acc_f32;
@@ -518,6 +521,18 @@ struct vk_op_soft_max_push_constants {
518521 uint32_t nrows_x;
519522};
520523
524+ struct vk_op_simpler_soft_max_push_constants {
525+ int32_t ne00;
526+ int32_t ne01;
527+ int32_t ne02;
528+ float scale;
529+ float max_bias;
530+ float m0;
531+ float m1;
532+ uint32_t n_head_log2;
533+ int32_t mask;
534+ };
535+
521536struct vk_op_argsort_push_constants {
522537 uint32_t ncols;
523538 uint32_t ncols_pad;
@@ -2088,6 +2103,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
20882103 ggml_vk_create_pipeline (device, device->pipeline_simpler_mul_mat_q6_k , " simpler_mul_mat_q6_k" , simpler_mul_mat_q6_k_len, simpler_mul_mat_q6_k_data, " main" , 3 , 18 * sizeof (uint32_t ), {1 , 1 , 1 }, {2 , device->subgroup_size }, 1 );
20892104 ggml_vk_create_pipeline (device, device->pipeline_simpler_mul_mat_q8_0 , " simpler_mul_mat_q8_0" , simpler_mul_mat_q8_0_len, simpler_mul_mat_q8_0_data, " main" , 3 , 18 * sizeof (uint32_t ), {1 , 1 , 1 }, {(device->subgroup_size * 2 ) / 8 }, 1 );
20902105
2106+ ggml_vk_create_pipeline (device, device->pipeline_simpler_soft_max_f16 , " simpler_soft_max_f16" , simpler_soft_max_f16_len, simpler_soft_max_f16_data, " main" , 3 , sizeof (vk_op_simpler_soft_max_push_constants), {1 , 1 , 1 }, {}, 1 );
2107+ ggml_vk_create_pipeline (device, device->pipeline_simpler_soft_max_f32 , " simpler_soft_max_f32" , simpler_soft_max_f32_len, simpler_soft_max_f32_data, " main" , 3 , sizeof (vk_op_simpler_soft_max_push_constants), {1 , 1 , 1 }, {}, 1 );
2108+
20912109 ggml_vk_create_pipeline (device, device->pipeline_norm_f32 , " norm_f32" , norm_f32_len, norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
20922110 ggml_vk_create_pipeline (device, device->pipeline_group_norm_f32 , " group_norm_f32" , group_norm_f32_len, group_norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
20932111 ggml_vk_create_pipeline (device, device->pipeline_rms_norm_f32 , " rms_norm_f32" , rms_norm_f32_len, rms_norm_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, {}, 1 );
@@ -5440,6 +5458,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
54405458 case GGML_OP_SOFT_MAX:
54415459 GGML_ASSERT (!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
54425460
5461+ if (ctx->device ->simpler_shaders ) {
5462+ if (src1 && src1->type == GGML_TYPE_F16) {
5463+ return ctx->device ->pipeline_simpler_soft_max_f16 ;
5464+ }
5465+ return ctx->device ->pipeline_simpler_soft_max_f32 ;
5466+ }
5467+
54435468 if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
54445469 return src0->ne [0 ] > 1024 ? ctx->device ->pipeline_soft_max_f32_wg512 : ctx->device ->pipeline_soft_max_f32 ;
54455470 }
@@ -5738,9 +5763,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
57385763 GGML_ASSERT (op_supports_incontiguous || (ggml_is_contiguous (src0) && (src1 == nullptr || ggml_is_contiguous (src1))));
57395764
57405765 switch (op) {
5766+ case GGML_OP_SOFT_MAX:
5767+ if (ctx->device ->simpler_shaders ) {
5768+ elements = { (uint32_t )src0->ne [1 ], (uint32_t )src0->ne [2 ], (uint32_t )src0->ne [3 ] };
5769+ break ;
5770+ }
5771+ // fall-through
57415772 case GGML_OP_NORM:
57425773 case GGML_OP_RMS_NORM:
5743- case GGML_OP_SOFT_MAX:
57445774 case GGML_OP_SUM_ROWS:
57455775 {
57465776 const uint32_t nr = ggml_nrows (src0);
@@ -6281,14 +6311,26 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
62816311 const float m0 = powf (2 .0f , -(max_bias ) / n_head_log2);
62826312 const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_head_log2);
62836313
6284- ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr , dst, GGML_OP_SOFT_MAX, {
6285- ncols,
6286- src1 != nullptr ? nrows_y : (uint32_t )0 ,
6287- scale, max_bias,
6288- m0, m1,
6289- n_head_log2,
6290- nrows_x,
6291- }, dryrun);
6314+ if (ctx->device ->simpler_shaders ) {
6315+ ggml_vk_op_f32<vk_op_simpler_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr , dst, GGML_OP_SOFT_MAX, {
6316+ (int32_t ) src0->ne [0 ],
6317+ (int32_t ) src0->ne [1 ],
6318+ (int32_t ) src0->ne [2 ],
6319+ scale, max_bias,
6320+ m0, m1,
6321+ n_head_log2,
6322+ src1 == nullptr ? 0 : 1 ,
6323+ }, dryrun);
6324+ } else {
6325+ ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr , dst, GGML_OP_SOFT_MAX, {
6326+ ncols,
6327+ src1 != nullptr ? nrows_y : (uint32_t )0 ,
6328+ scale, max_bias,
6329+ m0, m1,
6330+ n_head_log2,
6331+ nrows_x,
6332+ }, dryrun);
6333+ }
62926334}
62936335
62946336static void ggml_vk_rope (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false ) {
0 commit comments