Commit 3832f54

Merge pull request #231 from menloresearch/update-dev-from-master-2025-09-01-00-42
Sync master with upstream release b6337
2 parents 358fecc + 0d161f0

22 files changed: +540 additions, -257 deletions

ci/run.sh

Lines changed: 10 additions & 10 deletions
@@ -386,10 +386,10 @@ function gg_run_open_llama_7b_v2 {

     (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log

-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log

     function check_ppl {
         qnt="$1"

@@ -520,8 +520,8 @@ function gg_run_pythia_1_4b {

     (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log

-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log

     function check_ppl {
         qnt="$1"

@@ -651,10 +651,10 @@ function gg_run_pythia_2_8b {

     (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log

-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
-    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
+    (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log

     function check_ppl {
         qnt="$1"

common/arg.cpp

Lines changed: 7 additions & 7 deletions
@@ -2962,20 +2962,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.endpoint_metrics = true;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_METRICS"));
-    add_opt(common_arg(
-        {"--slots"},
-        string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
-        [](common_params & params) {
-            params.endpoint_slots = true;
-        }
-    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
     add_opt(common_arg(
         {"--props"},
         string_format("enable changing global properties via POST /props (default: %s)", params.endpoint_props ? "enabled" : "disabled"),
         [](common_params & params) {
            params.endpoint_props = true;
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_PROPS"));
+    add_opt(common_arg(
+        {"--slots"},
+        string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
+        [](common_params & params) {
+            params.endpoint_slots = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
     add_opt(common_arg(
         {"--no-slots"},
         "disables slots monitoring endpoint",

common/common.h

Lines changed: 1 addition & 1 deletion
@@ -444,7 +444,7 @@ struct common_params {

     // "advanced" endpoints are disabled by default for better security
     bool webui = true;
-    bool endpoint_slots = false;
+    bool endpoint_slots = true;
     bool endpoint_props = false; // only control POST requests, not GET
     bool endpoint_metrics = false;

ggml/include/ggml-backend.h

Lines changed: 3 additions & 0 deletions
@@ -307,6 +307,9 @@ extern "C" {
     GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
     GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);

+    // Split graph without allocating it
+    GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+
     // Allocate and compute graph on the backend scheduler
     GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
     GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
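
Note: with ggml_backend_sched_split_graph now exported, a caller can ask the scheduler for its backend assignments before any compute buffers are allocated. A minimal C++ sketch of one possible use, assuming `sched` and `graph` were built elsewhere and using the graph iteration helpers ggml_graph_n_nodes/ggml_graph_node from ggml.h; the wrapper name print_sched_assignments is illustrative, not part of the patch:

#include <cstdio>
#include "ggml.h"
#include "ggml-backend.h"

// Split the graph (assign backends, build splits) without allocating it,
// then report which backend each node was assigned to.
static void print_sched_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
    ggml_backend_sched_split_graph(sched, graph);

    for (int i = 0; i < ggml_graph_n_nodes(graph); ++i) {
        struct ggml_tensor * node    = ggml_graph_node(graph, i);
        ggml_backend_t       backend = ggml_backend_sched_get_tensor_backend(sched, node);
        std::printf("%s -> %s\n", node->name, backend ? ggml_backend_name(backend) : "(unassigned)");
    }

    // the usual path still applies afterwards:
    // ggml_backend_sched_alloc_graph(sched, graph);
    // ggml_backend_sched_graph_compute(sched, graph);
}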

ggml/src/ggml-backend.cpp

Lines changed: 3 additions & 1 deletion
@@ -902,7 +902,7 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
 }

 // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
-static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     // reset splits
     sched->n_splits = 0;
     sched->n_graph_inputs = 0;

@@ -1687,6 +1687,8 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
     GGML_ASSERT(sched);
     GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);

+    ggml_backend_sched_reset(sched);
+
     ggml_backend_sched_synchronize(sched);

     ggml_backend_sched_split_graph(sched, measure_graph);

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 11 additions & 39 deletions
@@ -523,13 +523,6 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,

@@ -1562,13 +1555,6 @@ @implementation GGMLMetalClass
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40, flash_attn_ext_vec_f16_h40, has_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40, flash_attn_ext_vec_bf16_h40, has_simdgroup_reduction && use_bfloat);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40, flash_attn_ext_vec_q4_0_h40, has_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40, flash_attn_ext_vec_q4_1_h40, has_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40, flash_attn_ext_vec_q5_0_h40, has_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40, flash_attn_ext_vec_q5_1_h40, has_simdgroup_reduction);
-    GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40, flash_attn_ext_vec_q8_0_h40, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
     GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
@@ -1909,9 +1895,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_ARANGE:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
-            if (op->src[0]->ne[0] == 32) {
-                // head size == 32 (e.g. bert-bge-small)
-                // TODO: not sure if it is worth adding kernels for this size
+            // for new head sizes, add checks here
+            if (op->src[0]->ne[0] != 40 &&
+                op->src[0]->ne[0] != 64 &&
+                op->src[0]->ne[0] != 80 &&
+                op->src[0]->ne[0] != 96 &&
+                op->src[0]->ne[0] != 112 &&
+                op->src[0]->ne[0] != 128 &&
+                op->src[0]->ne[0] != 192 &&
+                op->src[0]->ne[0] != 256) {
                 return false;
             }
             if (op->src[0]->ne[0] == 576) {
@@ -5138,10 +5130,8 @@ static int ggml_metal_encode_node(

             bool use_vec_kernel = false;

-            // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
-            //       for now avoiding mainly to keep the number of templates/kernels a bit lower
-            //       these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
-            if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
+            // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
+            if (ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112)) {
                 switch (src1->type) {
                     case GGML_TYPE_F16:
                         {
@@ -5329,24 +5319,6 @@ static int ggml_metal_encode_node(
             use_vec_kernel = true;

             switch (ne00) {
-                case 40:
-                    {
-                        switch (src1->type) {
-                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40].pipeline; break;
-                            case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40].pipeline; break;
-                            case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40].pipeline; break;
-                            case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40].pipeline; break;
-                            case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40].pipeline; break;
-                            case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40].pipeline; break;
-                            case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40].pipeline; break;
-                            default:
-                                {
-                                    GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
-                                    GGML_LOG_ERROR("add template specialization for this type\n");
-                                    GGML_ABORT("add template specialization for this type");
-                                }
-                        }
-                    } break;
                 case 64:
                     {
                         switch (src1->type) {
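
Note: the rewritten condition in the hunk at 5130 replaces the old modulo-based check with an explicit rule: take the non-vec flash-attention kernel when the batch of query rows is large (ne01 >= 20) or when the head size has no vec specialization (40, 80, 112). A standalone C++ restatement of that rule, purely illustrative; the function name is not from the patch:

#include <cstdint>

// Mirrors the updated selection logic: small batches with a supported head
// size use the vec kernel, everything else falls back to the non-vec kernel.
static bool use_flash_attn_ext_vec_kernel(int64_t ne01 /* query rows */, int64_t ne00 /* head size */) {
    const bool large_batch     = ne01 >= 20;
    const bool no_vec_for_head = ne00 == 40 || ne00 == 80 || ne00 == 112;
    return !(large_batch || no_vec_for_head);
}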

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 3 additions & 10 deletions
@@ -4803,6 +4803,9 @@ kernel void kernel_flash_attn_ext_vec(
         ushort3 ntg[[threads_per_threadgroup]],
         ushort  tiisg[[thread_index_in_simdgroup]],
         ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    static_assert(DK % 32 == 0, "DK must be divisible by 32");
+    static_assert(DV % 32 == 0, "DV must be divisible by 32");
+
     const short nsg = ntg.y; // number of simdgroups
     const short iwg = tgpig[2]%nwg;

@@ -5160,16 +5163,6 @@ kernel void kernel_flash_attn_ext_vec(

 typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;

-template [[host_name("kernel_flash_attn_ext_vec_f16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 40, 40, 8>;
-#if defined(GGML_METAL_USE_BF16)
-template [[host_name("kernel_flash_attn_ext_vec_bf16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 40, 40, 8>;
-#endif
-template [[host_name("kernel_flash_attn_ext_vec_q4_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 40, 40, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q4_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 40, 40, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 40, 40, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q5_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 40, 40, 8>;
-template [[host_name("kernel_flash_attn_ext_vec_q8_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 40, 40, 8>;
-
 template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 8>;
 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 8>;
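
Note: the new static_asserts require the vec kernel's head dimensions DK and DV to be multiples of 32, which is consistent with dropping the h40 instantiations above: 40 (like 80 and 112) fails the check, while 64, 96, 128, 192, 256 and the 576/512 case all pass. A small C++ check of that divisibility rule over the head sizes named in this commit; the helper is illustrative only:

#include <cstdio>

// Mirrors the new static_asserts: a vec-kernel instantiation needs DK and DV
// to be divisible by 32.
static bool vec_kernel_dims_ok(int dk, int dv) {
    return dk % 32 == 0 && dv % 32 == 0;
}

int main() {
    const int head_sizes[] = {40, 64, 80, 96, 112, 128, 192, 256};
    for (int hs : head_sizes) {
        std::printf("head size %3d: %s\n", hs,
                    vec_kernel_dims_ok(hs, hs) ? "vec kernel allowed" : "non-vec path only");
    }
    std::printf("hk=576 hv=512: %s\n",
                vec_kernel_dims_ok(576, 512) ? "vec kernel allowed" : "non-vec path only");
    return 0;
}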
