@@ -659,6 +659,7 @@ struct vk_device_struct {
659659 vk_pipeline pipeline_cos_f32;
660660 vk_pipeline pipeline_log[2];
661661 vk_pipeline pipeline_tri[2];
662+ vk_pipeline pipeline_diag[2];
662663 vk_pipeline pipeline_clamp_f32;
663664 vk_pipeline pipeline_pad_f32;
664665 vk_pipeline pipeline_roll_f32;
@@ -722,6 +723,11 @@ struct vk_device_struct {
722723 vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
723724 vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
724725 vk_pipeline pipeline_soft_max_back_f32;
726+
727+ vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16;
728+ vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16;
729+ vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16;
730+
725731 vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
726732 vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
727733 vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -3732,6 +3738,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
37323738 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
37333739 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
37343740 ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3741+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
37353742
37363743 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
37373744 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -3919,6 +3926,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
39193926 ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
39203927 ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
39213928
3929+ ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3930+ ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
3931+
39223932 ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
39233933
39243934 ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
@@ -3998,6 +4008,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
39984008 ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
39994009 ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
40004010
4011+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32, "soft_max_large1_f32", soft_max_large1_f32_len, soft_max_large1_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4012+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32, "soft_max_large2_f32", soft_max_large2_f32_len, soft_max_large2_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4013+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32, "soft_max_large3_f32", soft_max_large3_f32_len, soft_max_large3_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4014+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4015+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4016+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
4017+
40014018 ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
40024019 ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
40034020 ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -8278,6 +8295,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
82788295 switch (op) {
82798296 case GGML_OP_GET_ROWS:
82808297 GGML_ASSERT(src1->type == GGML_TYPE_I32);
8298+ if (src0->type == GGML_TYPE_I32) {
8299+ // i32 src only supports i32 result
8300+ GGML_ASSERT(dst->type == GGML_TYPE_I32);
8301+ return ctx->device->pipeline_get_rows[src0->type];
8302+ }
82818303 if (dst->type == GGML_TYPE_F16) {
82828304 return ctx->device->pipeline_get_rows[src0->type];
82838305 }
@@ -8404,6 +8426,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
84048426 return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
84058427 }
84068428 return nullptr;
8429+ case GGML_OP_DIAG:
8430+ if (src0->type == dst->type &&
8431+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
8432+ return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16];
8433+ }
8434+ return nullptr;
84078435 case GGML_OP_CLAMP:
84088436 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
84098437 return ctx->device->pipeline_clamp_f32;
@@ -9097,6 +9125,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
90979125 case GGML_OP_COS:
90989126 case GGML_OP_LOG:
90999127 case GGML_OP_TRI:
9128+ case GGML_OP_DIAG:
91009129 case GGML_OP_CLAMP:
91019130 case GGML_OP_PAD:
91029131 case GGML_OP_ROLL:
@@ -9784,6 +9813,12 @@ static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const
97849813 ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
97859814}
97869815
9816+ static void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
9817+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
9818+
9819+ ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p));
9820+ }
9821+
97879822static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
97889823 vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
97899824 p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -10117,7 +10152,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
1011710152 const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
1011810153 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
1011910154
10120- ggml_vk_op_f32< vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, {
10155+ vk_op_soft_max_push_constants pc {
1012110156 ncols,
1012210157 src1 != nullptr ? nrows_y : (uint32_t)0,
1012310158 (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -10128,7 +10163,55 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
1012810163 n_head_log2,
1012910164 nrows_x,
1013010165 src2 != nullptr
10131- });
10166+ };
10167+
10168+ if (ncols <= 16384) {
10169+ ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc));
10170+ } else {
10171+
10172+ vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0);
10173+ vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a;
10174+ vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a;
10175+ vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst);
10176+
10177+ uint32_t elems_per_wg = 128 * 4;
10178+ uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg);
10179+ size_t tmp_size = num_wgs * nrows_x * sizeof(float);
10180+
10181+ if (ctx->prealloc_size_x < tmp_size) {
10182+ ctx->prealloc_size_x = tmp_size;
10183+ ggml_vk_preallocate_buffers(ctx, subctx);
10184+ }
10185+ if (ctx->prealloc_size_y < tmp_size) {
10186+ ctx->prealloc_size_y = tmp_size;
10187+ ggml_vk_preallocate_buffers(ctx, subctx);
10188+ }
10189+ if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) {
10190+ ggml_vk_sync_buffers(ctx, subctx);
10191+ }
10192+
10193+ vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size };
10194+ vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size };
10195+
10196+ std::array<uint32_t, 3> elements = { num_wgs, nrows_x, 1 };
10197+
10198+ vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32;
10199+ vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32;
10200+ vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32;
10201+
10202+ ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
10203+ ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
10204+ ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1);
10205+
10206+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
10207+ ggml_vk_sync_buffers(ctx, subctx);
10208+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
10209+ ggml_vk_sync_buffers(ctx, subctx);
10210+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
10211+
10212+ ctx->prealloc_x_need_sync = true;
10213+ ctx->prealloc_y_need_sync = true;
10214+ }
1013210215}
1013310216
1013410217static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -11864,6 +11947,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
1186411947 case GGML_OP_TRI:
1186511948 ggml_vk_tri(ctx, compute_ctx, src0, node);
1186611949
11950+ break;
11951+ case GGML_OP_DIAG:
11952+ ggml_vk_diag(ctx, compute_ctx, src0, node);
11953+
1186711954 break;
1186811955 case GGML_OP_CLAMP:
1186911956 ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -13883,6 +13970,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1388313970 case GGML_TYPE_IQ4_XS:
1388413971 case GGML_TYPE_IQ4_NL:
1388513972 case GGML_TYPE_MXFP4:
13973+ case GGML_TYPE_I32:
1388613974 return true;
1388713975 default:
1388813976 return false;
@@ -14007,6 +14095,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1400714095 return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1400814096 case GGML_OP_LOG:
1400914097 case GGML_OP_TRI:
14098+ case GGML_OP_DIAG:
1401014099 return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
1401114100 op->type == op->src[0]->type;
1401214101 case GGML_OP_ARGSORT:
@@ -14597,6 +14686,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
1459714686 tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
1459814687 } else if (tensor->op == GGML_OP_TRI) {
1459914688 tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
14689+ } else if (tensor->op == GGML_OP_DIAG) {
14690+ tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
1460014691 } else if (tensor->op == GGML_OP_CLAMP) {
1460114692 const float * params = (const float *)tensor->op_params;
1460214693 tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
0 commit comments