
Commit ea92ae9

Merge branch 'master' of https://github.com/piDack/llama.cpp into glm_asr_support
2 parents: ccdc372 + 77ad854

25 files changed: +743, -59 lines

common/arg.cpp

Lines changed: 4 additions & 1 deletion
@@ -724,7 +724,7 @@ static void add_rpc_devices(const std::string & servers) {
         }
     }

-bool common_params_parse(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map) {
+bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map) {
     common_params dummy_params;
     common_params_context ctx_arg = common_params_parser_init(dummy_params, ex, nullptr);

@@ -733,6 +733,9 @@ bool common_params_parse(int argc, char ** argv, llama_example ex, std::map<comm
         for (const auto & arg : opt.args) {
             arg_to_options[arg] = &opt;
         }
+        for (const auto & arg : opt.args_neg) {
+            arg_to_options[arg] = &opt;
+        }
     }

     // TODO @ngxson : find a way to deduplicate this code

common/arg.h

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e

 // parse input arguments from CLI into a map
 // TODO: support repeated args in the future
-bool common_params_parse(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map);
+bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map<common_arg, std::string> & out_map);

 // initialize argument parser context - used by test-arg-parser and preset
 common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr);
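
For context, the renamed overload collects recognized CLI flags into a map instead of filling a common_params struct. Below is a minimal caller sketch, assuming the declarations above and the LLAMA_EXAMPLE_COMMON example id from common.h; the printing loop and field access are illustrative only, not code from this commit:

    #include <cstdio>
    #include <map>
    #include <string>

    #include "arg.h" // common/arg.h: common_arg, llama_example, common_params_to_map

    int main(int argc, char ** argv) {
        std::map<common_arg, std::string> args;

        // Collect recognized flags (including the new args_neg spellings) into a map.
        if (!common_params_to_map(argc, argv, LLAMA_EXAMPLE_COMMON, args)) {
            fprintf(stderr, "failed to parse arguments\n");
            return 1;
        }

        for (const auto & [opt, value] : args) {
            // opt.args holds the flag spellings; value is the raw string that followed the flag.
            fprintf(stdout, "%s = %s\n", opt.args.empty() ? "?" : opt.args[0], value.c_str());
        }
        return 0;
    }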

examples/model-conversion/scripts/causal/run-org-model.py

Lines changed: 1 addition & 1 deletion
@@ -200,7 +200,7 @@ def fn(_m, input, output):
     logits = outputs.logits

     # Extract logits for the last token (next token prediction)
-    last_logits = logits[0, -1, :].cpu().numpy()
+    last_logits = logits[0, -1, :].float().cpu().numpy()

     print(f"Logits shape: {logits.shape}")
     print(f"Last token logits shape: {last_logits.shape}")

ggml/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -54,6 +54,10 @@ if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
     # TODO
 else()
     set(GGML_STANDALONE OFF)
+
+    if (NOT CMAKE_RUNTIME_OUTPUT_DIRECTORY)
+        set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+    endif()
 endif()

 if (EMSCRIPTEN)

ggml/src/ggml-cpu/arch/arm/repack.cpp

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@

 #define UNUSED GGML_UNUSED

+#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD))
 static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
                                              int16x8_t * out_mins,
                                              int8_t * out_scales) {
@@ -46,6 +47,7 @@ static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in,
     scales_u32[1] = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4);
     memcpy(out_scales, scales_u32, 8);
 }
+#endif

 void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
     assert(QK8_0 == 32);

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 93 additions & 2 deletions
@@ -659,6 +659,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_log[2];
     vk_pipeline pipeline_tri[2];
+    vk_pipeline pipeline_diag[2];
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_roll_f32;
@@ -722,6 +723,11 @@ struct vk_device_struct {
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
     vk_pipeline pipeline_soft_max_back_f32;
+
+    vk_pipeline pipeline_soft_max_large1_f32, pipeline_soft_max_large1_f32_f16;
+    vk_pipeline pipeline_soft_max_large2_f32, pipeline_soft_max_large2_f32_f16;
+    vk_pipeline pipeline_soft_max_large3_f32, pipeline_soft_max_large3_f32_f16;
+
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16, pipeline_rope_norm_f32_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16, pipeline_rope_neox_f32_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
@@ -3732,6 +3738,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -3919,6 +3926,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

+    ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

     ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
@@ -3998,6 +4008,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);

+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32, "soft_max_large1_f32", soft_max_large1_f32_len, soft_max_large1_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32, "soft_max_large2_f32", soft_max_large2_f32_len, soft_max_large2_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32, "soft_max_large3_f32", soft_max_large3_f32_len, soft_max_large3_f32_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large1_f32_f16, "soft_max_large1_f32_f16", soft_max_large1_f32_f16_len, soft_max_large1_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large2_f32_f16, "soft_max_large2_f32_f16", soft_max_large2_f32_f16_len, soft_max_large2_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_large3_f32_f16, "soft_max_large3_f32_f16", soft_max_large3_f32_f16_len, soft_max_large3_f32_f16_data, "main", 6, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 128, 4 }, 1, true);
+
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -8278,6 +8295,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     switch (op) {
     case GGML_OP_GET_ROWS:
         GGML_ASSERT(src1->type == GGML_TYPE_I32);
+        if (src0->type == GGML_TYPE_I32) {
+            // i32 src only supports i32 result
+            GGML_ASSERT(dst->type == GGML_TYPE_I32);
+            return ctx->device->pipeline_get_rows[src0->type];
+        }
         if (dst->type == GGML_TYPE_F16) {
             return ctx->device->pipeline_get_rows[src0->type];
         }
@@ -8404,6 +8426,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
         }
         return nullptr;
+    case GGML_OP_DIAG:
+        if (src0->type == dst->type &&
+            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
+            return ctx->device->pipeline_diag[dst->type == GGML_TYPE_F16];
+        }
+        return nullptr;
     case GGML_OP_CLAMP:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_clamp_f32;
@@ -9097,6 +9125,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_COS:
     case GGML_OP_LOG:
     case GGML_OP_TRI:
+    case GGML_OP_DIAG:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_ROLL:
@@ -9784,6 +9813,12 @@ static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const
     ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
 }

+static void ggml_vk_diag(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst));
+
+    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_DIAG, std::move(p));
+}
+
 static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
     p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -10117,7 +10152,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

-    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, {
+    vk_op_soft_max_push_constants pc {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -10128,7 +10163,55 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
         n_head_log2,
         nrows_x,
         src2 != nullptr
-    });
+    };
+
+    if (ncols <= 16384) {
+        ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SOFT_MAX, std::move(pc));
+    } else {
+
+        vk_subbuffer buf_a = ggml_vk_tensor_subbuffer(ctx, src0);
+        vk_subbuffer buf_b = src1 ? ggml_vk_tensor_subbuffer(ctx, src1) : buf_a;
+        vk_subbuffer buf_c = src2 ? ggml_vk_tensor_subbuffer(ctx, src2) : buf_a;
+        vk_subbuffer buf_d = ggml_vk_tensor_subbuffer(ctx, dst);
+
+        uint32_t elems_per_wg = 128 * 4;
+        uint32_t num_wgs = CEIL_DIV(ncols, elems_per_wg);
+        size_t tmp_size = num_wgs * nrows_x * sizeof(float);
+
+        if (ctx->prealloc_size_x < tmp_size) {
+            ctx->prealloc_size_x = tmp_size;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_size_y < tmp_size) {
+            ctx->prealloc_size_y = tmp_size;
+            ggml_vk_preallocate_buffers(ctx, subctx);
+        }
+        if (ctx->prealloc_x_need_sync || ctx->prealloc_y_need_sync) {
+            ggml_vk_sync_buffers(ctx, subctx);
+        }
+
+        vk_subbuffer buf_x = { ctx->prealloc_x, 0, tmp_size };
+        vk_subbuffer buf_y = { ctx->prealloc_y, 0, tmp_size };
+
+        std::array<uint32_t, 3> elements = { num_wgs, nrows_x, 1 };
+
+        vk_pipeline pipeline1 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large1_f32_f16 : ctx->device->pipeline_soft_max_large1_f32;
+        vk_pipeline pipeline2 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large2_f32_f16 : ctx->device->pipeline_soft_max_large2_f32;
+        vk_pipeline pipeline3 = src1 && src1->type == GGML_TYPE_F16 ? ctx->device->pipeline_soft_max_large3_f32_f16 : ctx->device->pipeline_soft_max_large3_f32;
+
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline3, 1);
+
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
+        ggml_vk_sync_buffers(ctx, subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
+        ggml_vk_sync_buffers(ctx, subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline3, { buf_a, buf_b, buf_c, buf_d, buf_x, buf_y }, pc, elements);
+
+        ctx->prealloc_x_need_sync = true;
+        ctx->prealloc_y_need_sync = true;
+    }
 }

 static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
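
For rows wider than 16384 columns, the hunk above replaces the single-dispatch softmax with three dispatches that communicate through the preallocated x/y buffers. Only the first shader (per-workgroup maxima) appears later in this commit, so the CPU sketch below illustrates the general three-pass split-softmax pattern it presumably follows; the function name and block handling are hypothetical, not the shader code itself:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Three-pass softmax over one row of n values, processed in blocks of `block` elements.
    // Pass 1: per-block maxima; pass 2: global max and per-block sums of exp; pass 3: normalize.
    static void softmax_split(const float * x, float * y, size_t n, size_t block) {
        const size_t nblocks = (n + block - 1) / block;
        std::vector<float> blk_max(nblocks, -INFINITY);
        std::vector<float> blk_sum(nblocks, 0.0f);

        // pass 1: partial maxima (one "workgroup" per block)
        for (size_t b = 0; b < nblocks; ++b) {
            for (size_t i = b * block; i < std::min(n, (b + 1) * block); ++i) {
                blk_max[b] = std::max(blk_max[b], x[i]);
            }
        }

        // pass 2: reduce to the global max, then per-block sums of exp(x - max)
        float gmax = -INFINITY;
        for (float m : blk_max) gmax = std::max(gmax, m);
        for (size_t b = 0; b < nblocks; ++b) {
            for (size_t i = b * block; i < std::min(n, (b + 1) * block); ++i) {
                blk_sum[b] += std::exp(x[i] - gmax);
            }
        }

        // pass 3: reduce the sums and write the normalized result
        float gsum = 0.0f;
        for (float s : blk_sum) gsum += s;
        for (size_t i = 0; i < n; ++i) {
            y[i] = std::exp(x[i] - gmax) / gsum;
        }
    }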
@@ -11864,6 +11947,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_TRI:
         ggml_vk_tri(ctx, compute_ctx, src0, node);

+        break;
+    case GGML_OP_DIAG:
+        ggml_vk_diag(ctx, compute_ctx, src0, node);
+
         break;
     case GGML_OP_CLAMP:
         ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -13883,6 +13970,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_I32:
             return true;
         default:
             return false;
@@ -14007,6 +14095,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_LOG:
         case GGML_OP_TRI:
+        case GGML_OP_DIAG:
             return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                    op->type == op->src[0]->type;
         case GGML_OP_ARGSORT:
@@ -14597,6 +14686,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
         tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_TRI) {
         tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
+    } else if (tensor->op == GGML_OP_DIAG) {
+        tensor_clone = ggml_diag(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_CLAMP) {
         const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
ggml/src/ggml-vulkan/vulkan-shaders/ (new compute shader; the file name is not captured in this view, but it is the diag shader built as diag_f32/diag_f16 above)

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+#version 450
+
+#include "rte.glsl"
+#include "types.glsl"
+#include "generic_unary_head.glsl"
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L);
+    const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10;
+    const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L);
+    const uint i12_offset = i12*p.ne11*p.ne10;
+    const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L);
+    const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10;
+
+    if (i10 == i11) {
+        const float val = float(data_a[get_aoffset() + i13*p.nb03 + i12*p.nb02 + 0*p.nb01 + i10*p.nb00]);
+        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
+    } else {
+        data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
+    }
+}
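
As a reading aid: the shader above implements GGML_OP_DIAG element by element. For every (i12, i13) slice it reads row 0 of the source and writes it onto the diagonal of the destination, zeroing everything else. A rough CPU equivalent for contiguous f32 data (illustrative, not part of the commit):

    #include <cstddef>

    // dst is an ne0 x ne0 matrix per (i2, i3) slice; src provides one row of ne0 values per slice.
    // Mirrors the shader: dst[i1][i0] = (i0 == i1) ? src[0][i0] : 0.
    static void diag_ref_f32(const float * src, float * dst,
                             size_t ne0, size_t ne2, size_t ne3) {
        for (size_t i3 = 0; i3 < ne3; ++i3) {
            for (size_t i2 = 0; i2 < ne2; ++i2) {
                const float * s = src + (i3 * ne2 + i2) * ne0;        // row 0 of this slice
                float       * d = dst + (i3 * ne2 + i2) * ne0 * ne0;  // ne0 x ne0 output slice
                for (size_t i1 = 0; i1 < ne0; ++i1) {
                    for (size_t i0 = 0; i0 < ne0; ++i0) {
                        d[i1 * ne0 + i0] = (i0 == i1) ? s[i0] : 0.0f;
                    }
                }
            }
        }
    }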

ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp

Lines changed: 2 additions & 2 deletions
@@ -26,9 +26,9 @@ void main() {
     const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;

 #if defined(DATA_A_BF16)
-    FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
+    TEMP_TYPE v = TEMP_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
 #else
-    FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
+    TEMP_TYPE v = TEMP_TYPE(data_a[a_offset + i00]);
 #endif
 #ifndef OPTIMIZATION_ERROR_WORKAROUND
     data_d[d_offset + i00] = D_TYPE(v);
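
For reference, GET_ROWS gathers whole rows of src0 at the i32 indices held in src1; the switch from FLOAT_TYPE to TEMP_TYPE presumably lets the i32 variant added above carry integer values through without a float intermediate. A simplified CPU analogue of the gather (illustrative, 2-D case only, names hypothetical):

    #include <cstddef>
    #include <cstdint>
    #include <cstring>

    // Copy src0[rows[r]] into dst[r] for each requested row.
    // With T = int32_t no float conversion happens, which is what keeping an
    // integer temporary preserves on the GPU side.
    template <typename T>
    static void get_rows_ref(const T * src0, const int32_t * rows, T * dst,
                             size_t ncols, size_t nrows_requested) {
        for (size_t r = 0; r < nrows_requested; ++r) {
            const T * src_row = src0 + static_cast<size_t>(rows[r]) * ncols;
            std::memcpy(dst + r * ncols, src_row, ncols * sizeof(T));
        }
    }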
ggml/src/ggml-vulkan/vulkan-shaders/ (new compute shader; the file name is not captured in this view, but it is the first soft_max_large pass, built as soft_max_large1_* above)

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+#version 450
+
+#include "soft_max_large_common.glsl"
+
+void main() {
+    const uint tid = gl_LocalInvocationID.x;
+    const uint rowx = gl_WorkGroupID.y;
+    const uint wg_start = gl_WorkGroupID.x * BLOCK_SIZE * num_iters;
+
+    const uint32_t i03 = rowx / (p.ne01 * p.ne02);
+    const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01;
+    const uint32_t i01 = rowx % p.ne01;
+
+    uint rowy_start = 0;
+    if (p.KY > 0) {
+        rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13;
+    }
+
+    if (rowx >= p.nrows_x) {
+        return;
+    }
+
+    float slope = get_slope(rowx);
+
+    // Find max
+    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
+
+    [[unroll]] for (uint col0 = wg_start, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
+        const uint col = col0 + tid;
+
+        FLOAT_TYPE a = FLOAT_TYPE(0);
+        if (col < p.KX) {
+            a = data_a[rowx * p.KX + col];
+        }
+
+        FLOAT_TYPE b = FLOAT_TYPE(0);
+        if (p.KY > 0 && col < p.KX) {
+            b = data_b[rowy_start + col];
+        }
+
+        FLOAT_TYPE v = a * p.scale + slope * b;
+
+        if (col < p.KX) {
+            max_val = max(max_val, v);
+        }
+    }
+
+    // reduce across the workgroup
+    vals[tid] = max_val;
+    barrier();
+    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            vals[tid] = max(vals[tid], vals[tid + s]);
+        }
+        barrier();
+    }
+
+    if (tid == 0) {
+        max_val = vals[0];
+        data_m[rowx * gl_NumWorkGroups.x + gl_WorkGroupID.x] = max_val;
+    }
+}
