From c12bbde37258c6d053322d1cedd4a9672109ae58 Mon Sep 17 00:00:00 2001
From: Diego Devesa
Date: Fri, 25 Jul 2025 01:07:26 -0700
Subject: [PATCH 1/9] sched : fix multiple evaluations of the same graph with pipeline parallelism (#14855)

ggml-ci
---
 ggml/src/ggml-backend.cpp | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index b7498b8d40238..eaf41e5a6c84d 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -647,6 +647,7 @@ struct ggml_backend_sched {
     // pipeline parallelism support
     int n_copies;
     int cur_copy;
+    int next_copy;
     ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES];
     struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
     int n_graph_inputs;
@@ -1433,8 +1434,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
         }
     }

-    sched->cur_copy = (sched->cur_copy + 1) % sched->n_copies;
-
     return GGML_STATUS_SUCCESS;
 }

@@ -1535,10 +1534,10 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
 bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
     GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);

-    ggml_backend_sched_split_graph(sched, measure_graph);
-
     ggml_backend_sched_synchronize(sched);

+    ggml_backend_sched_split_graph(sched, measure_graph);
+
     if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) {
         return false;
     }
@@ -1550,6 +1549,10 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
 bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);

+    GGML_ASSERT(!sched->is_alloc);
+
+    sched->cur_copy = sched->next_copy;
+    sched->next_copy = (sched->next_copy + 1) % sched->n_copies;

     ggml_backend_sched_split_graph(sched, graph);

@@ -1590,7 +1593,7 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
         // if the graph is not already allocated, always use copy 0 after a synchronization
         // this ensures that during generation the same copy is used every time,
         // which avoids changes in the graph that could cause CUDA or other graphs to be disabled
-        sched->cur_copy = 0;
+        sched->next_copy = 0;
     }
 }

From 64bf1c3744053cf7def10aeed21ff48883ee755b Mon Sep 17 00:00:00 2001
From: Chris Rohlf
Date: Fri, 25 Jul 2025 06:17:02 -0400
Subject: [PATCH 2/9] rpc : check for null buffers in get/set/copy tensor endpoints (#14868)

---
 ggml/src/ggml-rpc/ggml-rpc.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp
index f468f796d5773..29bc421d58f5c 100644
--- a/ggml/src/ggml-rpc/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc/ggml-rpc.cpp
@@ -1055,7 +1055,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
-    if (tensor == nullptr) {
+    if (tensor == nullptr || tensor->buffer == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
         return false;
     }
@@ -1124,7 +1124,7 @@ bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rp
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
-    if (tensor == nullptr) {
+    if (tensor == nullptr || tensor->buffer == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
         return false;
     }
@@ -1192,7 +1192,7 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<
     GGML_ASSERT(ctx_ptr != nullptr);
     ggml_context * ctx = ctx_ptr.get();
     ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
-    if (tensor == nullptr) {
+    if (tensor == nullptr || tensor->buffer == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__);
         return false;
     }
@@ -1229,7 +1229,7 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
     ggml_tensor * src = deserialize_tensor(ctx, &request.src);
     ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);

-    if (src == nullptr || dst == nullptr) {
+    if (src == nullptr || dst == nullptr || src->buffer == nullptr || dst->buffer == nullptr) {
         GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__);
         return false;
     }

From 749e0d27f0247337869f4698f59dd7fafba94326 Mon Sep 17 00:00:00 2001
From: kiwi <122582483+kiwi142857@users.noreply.github.com>
Date: Fri, 25 Jul 2025 19:08:04 +0800
Subject: [PATCH 3/9] mtmd : fix 32-bit narrowing issue in export-lora and mtmd clip (#14503)

* [fix] Fix 32-bit narrowing issue in export-lora and mtmd clip
* Update export-lora.cpp
* Update clip.cpp
* Update export-lora.cpp
* format: use space to replace tab
---
 tools/export-lora/export-lora.cpp | 2 +-
 tools/mtmd/clip.cpp | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/tools/export-lora/export-lora.cpp b/tools/export-lora/export-lora.cpp
index 24dc85cf27336..f038019b007b4 100644
--- a/tools/export-lora/export-lora.cpp
+++ b/tools/export-lora/export-lora.cpp
@@ -148,7 +148,7 @@ struct lora_merge_ctx {
         ctx_out = gguf_init_empty();
         struct ggml_init_params params = {
-            /*.mem_size =*/ gguf_get_n_tensors(base_model.ctx_gguf)*ggml_tensor_overhead(),
+            /*.mem_size =*/ static_cast<size_t>(gguf_get_n_tensors(base_model.ctx_gguf)*ggml_tensor_overhead()),
             /*.mem_buffer =*/ NULL,
             /*.no_alloc =*/ true,
         };
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index be191404cfc75..e8e3b0a013dbd 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -2315,7 +2315,7 @@ struct clip_model_loader {
         // create data context
         struct ggml_init_params params = {
-            /*.mem_size =*/ (gguf_get_n_tensors(ctx_gguf.get()) + 1) * ggml_tensor_overhead(),
+            /*.mem_size =*/ static_cast<size_t>(gguf_get_n_tensors(ctx_gguf.get()) + 1) * ggml_tensor_overhead(),
             /*.mem_buffer =*/ NULL,
             /*.no_alloc =*/ true,
         };

From c1dbea752a630169c104118fd0c82f3ffcb19c91 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 25 Jul 2025 14:28:06 +0300
Subject: [PATCH 4/9] context : restore preemptive sched reset when LLAMA_SET_ROWS=0 (#14870)

ggml-ci
---
 src/llama-context.cpp | 14 +++++++++++++-
 src/llama-context.h | 4 ++++
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index a91d157e29803..84f9ccab4ec2f 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -105,7 +105,7 @@ llama_context::llama_context(
     {
         const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
-        const bool supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false;
+        supports_set_rows = LLAMA_SET_ROWS ?
(atoi(LLAMA_SET_ROWS) != 0) : false; if (!supports_set_rows && !cparams.kv_unified) { LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__); @@ -899,6 +899,12 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + if (!supports_set_rows) { + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + ggml_backend_sched_reset(sched.get()); + } + // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { //cross.t_embd = t_embd; @@ -1229,6 +1235,12 @@ int llama_context::decode(const llama_batch & batch_inp) { // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); + if (!supports_set_rows) { + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + ggml_backend_sched_reset(sched.get()); + } + return 0; } diff --git a/src/llama-context.h b/src/llama-context.h index fdbe61207e8ce..5c3a1c09886ea 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -287,6 +287,10 @@ struct llama_context { bool has_evaluated_once = false; + // env: LLAMA_SET_ROWS (temporary) + // ref: https://github.com/ggml-org/llama.cpp/pull/14285 + bool supports_set_rows = false; + // perf mutable int64_t t_start_us = 0; mutable int64_t t_load_us = 0; From e2b7621e7c265a6739225125cf9c534f471b3472 Mon Sep 17 00:00:00 2001 From: Oliver Simons Date: Fri, 25 Jul 2025 13:29:57 +0200 Subject: [PATCH 5/9] ggml : remove invalid portPos specifiers from dot files (#14838) Neither "g" nor "x" are valid portPos specifiers per the official [graphviz documents](https://graphviz.org/docs/attr-types/portPos/): > If a compass point is used, it must have the form "n","ne","e","se","s","sw","w","nw","c","_". I tested locally for it to fall back to default portPos specifier if an invalid portPos is specified. As a consequence, we can remove associated code. --- ggml/src/ggml.c | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5ae1c527df639..124cf3e8b6025 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -6640,20 +6640,18 @@ static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgr static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) { struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node); struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent); - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n", + fprintf(fp, " \"%p\" -> \"%p\" [ arrowhead = %s; style = %s; label = \"%s\"; ]\n", gparent0 ? (void *) gparent0 : (void *) parent, - gparent0 ? "g" : "x", gparent ? (void *) gparent : (void *) node, - gparent ? "g" : "x", gparent ? "empty" : "vee", gparent ? 
"dashed" : "solid", label); } static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) { - fprintf(fp, " \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n", - (void *) parent, "x", - (void *) node, "x", + fprintf(fp, " \"%p\" -> \"%p\" [ label = \"%s\"; ]\n", + (void *) parent, + (void *) node, label); } From e7fecba93416c8aaf343932b5e36dd5e208a815e Mon Sep 17 00:00:00 2001 From: wooksong Date: Fri, 25 Jul 2025 23:25:05 +0900 Subject: [PATCH 6/9] =?UTF-8?q?docs=20:=20update=20HOWTO=E2=80=91add?= =?UTF-8?q?=E2=80=91model.md=20for=20ModelBase=20and=20new=20model=20class?= =?UTF-8?q?es=20(#14874)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This patch updates the example in docs/development/HOWTO-add-model.md to reflect recent changes after `TextModel` and `MmprojModel` were introduced. It replaces the outdated `Model` base class with `TextModel` or `MmprojModel` and updates the registration example accordingly. Signed-off-by: Wook Song --- docs/development/HOWTO-add-model.md | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/docs/development/HOWTO-add-model.md b/docs/development/HOWTO-add-model.md index 51e0b0b20f58d..5989b873a611d 100644 --- a/docs/development/HOWTO-add-model.md +++ b/docs/development/HOWTO-add-model.md @@ -23,11 +23,19 @@ The convert script reads the model configuration, tokenizer, tensor names+data a The required steps to implement for an HF model are: -1. Define the model `Model.register` annotation in a new `Model` subclass, example: +1. Define the model `ModelBase.register` annotation in a new `TextModel` or `MmprojModel` subclass, example: ```python -@Model.register("MyModelForCausalLM") -class MyModel(Model): +@ModelBase.register("MyModelForCausalLM") +class MyModel(TextModel): + model_arch = gguf.MODEL_ARCH.MYMODEL +``` + +or + +```python +@ModelBase.register("MyModelForConditionalGeneration") +class MyModel(MmprojModel): model_arch = gguf.MODEL_ARCH.MYMODEL ``` @@ -75,9 +83,10 @@ block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = { `transformer.blocks.{bid}.norm_1` will be mapped to `blk.{bid}.attn_norm` in GGUF. Depending on the model configuration, tokenizer, code and tensors layout, you will have to override: -- `Model#set_gguf_parameters` -- `Model#set_vocab` -- `Model#write_tensors` +- `TextModel#set_gguf_parameters` +- `MmprojModel#set_gguf_parameters` +- `ModelBase#set_vocab` +- `ModelBase#modify_tensors` NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the convention and several tools like `quantize` expect this to proceed the weights. 
From ce111d39d666f4a3b6a561ad020f6feb8cc67790 Mon Sep 17 00:00:00 2001
From: lhez
Date: Fri, 25 Jul 2025 08:12:13 -0700
Subject: [PATCH 7/9] opencl: add fused `rms_norm_mul` (#14841)

* opencl: add fused `rms_norm` + `mul`

* opencl: improve workgroup size for `rms_norm_mul`
---
 ggml/src/ggml-opencl/ggml-opencl.cpp | 163 ++++++++++++++++++++++-
 ggml/src/ggml-opencl/kernels/rms_norm.cl | 79 +++++++++++
 2 files changed, 240 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index 63ac4a989b08b..c87a32383c873 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -333,6 +333,7 @@ struct ggml_backend_opencl_context {
     size_t max_alloc_size;
     bool fp16_support;
     bool has_vector_subgroup_broadcast;
+    bool disable_fusion;
     ggml_cl_compiler_version adreno_cl_compiler_version;

     int adreno_wave_size;
@@ -411,7 +412,7 @@ struct ggml_backend_opencl_context {
     cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
         kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
     cl_kernel kernel_norm;
-    cl_kernel kernel_rms_norm;
+    cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
     cl_kernel kernel_group_norm;
     cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
     cl_kernel kernel_soft_max, kernel_soft_max_4;
@@ -1100,7 +1101,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_rms_norm =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);

-        CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
+        CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err));
         GGML_LOG_CONT(".");
     }
@@ -2110,6 +2112,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
     CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
 #endif // GGML_OPENCL_USE_ADRENO_KERNELS

+    backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
+
     dev_ctx->backend_ctx = backend_ctx.release();
     return dev_ctx->backend_ctx;
 }
@@ -2279,7 +2283,45 @@ static void sync_with_other_backends(ggml_backend_t backend) {
     sync_with_other_backends(backend_ctx);
 }

+static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx+1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+
+        // rms_norm only supports f32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+
+        // if rms_norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] &&
+            !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
+            return false;
+        }
+
+        // rms_norm assumes contiguous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
+
 static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
@@ -2292,6 +2334,12 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
             continue;
         }

+        if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
+            i++;
+            continue;
+        }
+
         bool ok = ggml_cl_compute_forward(backend, node);
         if (!ok) {
             GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
@@ -4455,6 +4503,117 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
     backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
 }

+static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) {
+    GGML_ASSERT(mul_tensor);
+    GGML_ASSERT(rms_norm_tensor);
+
+    // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm)
+    const ggml_tensor * src0 = rms_norm_tensor->src[0];
+    const ggml_tensor * src1;
+    if (mul_tensor->src[0] == rms_norm_tensor) {
+        src1 = mul_tensor->src[1];
+    } else if (mul_tensor->src[1] == rms_norm_tensor) {
+        src1 = mul_tensor->src[0];
+    } else {
+        GGML_ASSERT(false && "Invalid args for rms_norm and mul");
+    }
+    const ggml_tensor * dst = mul_tensor;
+
+    GGML_ASSERT(src0);
+    GGML_ASSERT(src0->extra);
+    GGML_ASSERT(src1);
+    GGML_ASSERT(src1->extra);
+    GGML_ASSERT(dst);
+    GGML_ASSERT(dst->extra);
+
+    ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
+    ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
+    ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
+
+    cl_ulong offset0 = extra0->offset + src0->view_offs;
+    cl_ulong offset1 = extra1->offset + src1->view_offs;
+    cl_ulong offsetd = extrad->offset + dst->view_offs;
+
+    ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
+
+    float eps;
+    memcpy(&eps, rms_norm_tensor->op_params, sizeof(float));
+
+    const int ne00 = src0->ne[0];
+    const int ne01 = src0->ne[1];
+    const int ne02 = src0->ne[2];
+    const int ne03 = src0->ne[3];
+
+    const cl_ulong nb01 = src0->nb[1];
+    const cl_ulong nb02 = src0->nb[2];
+    const cl_ulong nb03 = src0->nb[3];
+
+    const int ne10 = src1->ne[0];
+    const int ne11 = src1->ne[1];
+    const int ne12 = src1->ne[2];
+    const int ne13 = src1->ne[3];
+
+    const cl_ulong nb11 = src1->nb[1];
+    const cl_ulong nb12 = src1->nb[2];
+    const cl_ulong nb13 = src1->nb[3];
+
+    const cl_ulong nb1 = dst->nb[1];
+    const cl_ulong nb2 = dst->nb[2];
+    const cl_ulong nb3 = dst->nb[3];
+
+    GGML_ASSERT(ne00 % 4 == 0);
+
+    size_t sgs;
+    if (backend_ctx->gpu_family == ADRENO) {
+        sgs = 64;
+    } else if (backend_ctx->gpu_family == INTEL) {
+        sgs = 32;
+    } else {
+        GGML_ASSERT(false && "Unsupported GPU");
+    }
+
+    cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;
+
+    int nth = sgs;
+    int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
+    while (nth < ne00 && nth < max_workgroup_size) {
+        nth *= 2;
+    }
+    nth = MIN(nth, max_workgroup_size);
+    nth = MIN(nth, ne00);
+
+    size_t global_work_size[] =
{(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); diff --git a/ggml/src/ggml-opencl/kernels/rms_norm.cl b/ggml/src/ggml-opencl/kernels/rms_norm.cl index 9d21f3398ec38..ecd053cb4c1ce 100644 --- a/ggml/src/ggml-opencl/kernels/rms_norm.cl +++ b/ggml/src/ggml-opencl/kernels/rms_norm.cl @@ -94,3 +94,82 @@ kernel void kernel_rms_norm( } } } + +//------------------------------------------------------------------------------ +// rms_norm_mul +//------------------------------------------------------------------------------ +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_rms_norm_mul( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb11, + ulong nb12, + ulong nb13, + ulong nb1, + ulong nb2, + ulong nb3, + float eps, + local float * sum +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11); + + float sumf = 0; + + // parallel sum + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + sumf += dot(x[i00], x[i00]); + } + sumf = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + 
sum[get_sub_group_id()] = sumf;
+    }

+    barrier(CLK_LOCAL_MEM_FENCE);

+    for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
+        if (get_local_id(0) < i) {
+            sum[get_local_id(0)] += sum[get_local_id(0) + i];
+        }
+    }
+    if (get_local_id(0) == 0) {
+        sum[0] /= ne00;
+    }

+    barrier(CLK_LOCAL_MEM_FENCE);

+    float mean = sum[0];
+    float scale = 1.0f/sqrt(mean + eps);

+    global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
+    for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
+        y[i00] = (x[i00] * scale) * f[i00%(ne10/4)];
+    }
+}

From 793c0d7f46384001738c337d7afa46b45ae32745 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Fri, 25 Jul 2025 10:47:39 -0600
Subject: [PATCH 8/9] metal: SSM_SCAN performance (#14743)

* feat: Add s_off as a parameter in the args struct

This may not be necessary, but it more closely mirrors the CUDA kernel

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart

* perf: Parallelize mamba2 SSM_SCAN metal kernel over d_state

This is a first attempt at optimizing the metal kernel. The changes here
are:

- Launch the kernel with a thread group of size d_state
- Use simd groups and shared memory to do the summation for the y computation

When tested with G4 tiny preview, this shows roughly a 3x speedup on
prefill and 15% speedup on decode.

Signed-off-by: Gabe Goodhart

* fix: Update logic to correctly do the multi-layer parallel sum

Signed-off-by: Gabe Goodhart

* fix: Correctly size the shared memory buffer and assert expected size relationships

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart

* refactor: Compute block offsets once rather than once per token

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart

* feat: Use local variable for state recursion

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart

* feat: Use a secondary simd_sum instead of a for loop

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart

* feat: Add assertion and comment about relationship between simd size and num simd groups

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart

* feat: Parallelize over d_state for mamba-1

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart

* feat: Parallel sum in SSM_CONV

Branch: GraniteFourPerf

Signed-off-by: Gabe Goodhart

* Revert "feat: Parallel sum in SSM_CONV"

After discussion with @compilade, the size of the parallelism here is not worth the cost in complexity or overhead of the parallel for.

https://github.com/ggml-org/llama.cpp/pull/14743#discussion_r2223395357

This reverts commit 16bc059660c1c59e566628201c0ca2c20c9f4bc3.
Signed-off-by: Gabe Goodhart * refactor: Simplify shared memory sizing Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart Co-Authored-By: Georgi Gerganov --------- Signed-off-by: Gabe Goodhart Co-authored-by: Georgi Gerganov --- ggml/src/ggml-metal/ggml-metal-impl.h | 1 + ggml/src/ggml-metal/ggml-metal.m | 15 ++- ggml/src/ggml-metal/ggml-metal.metal | 182 ++++++++++++++++++++------ 3 files changed, 156 insertions(+), 42 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index b7b3fc49af35d..8424464d8cadc 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -528,6 +528,7 @@ typedef struct { int64_t n_group; int64_t n_seq_tokens; int64_t n_seqs; + int64_t s_off; uint64_t nb01; uint64_t nb02; uint64_t nb03; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 1a9999325fe27..337f7985badf3 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -3141,6 +3141,7 @@ static int ggml_metal_encode_node( /*.n_group =*/ n_group, /*.n_seq_tokens =*/ n_seq_tokens, /*.n_seqs =*/ n_seqs, + /*.s_off =*/ ggml_nelements(src1) * sizeof(float), /*.nb01 =*/ nb01, /*.nb02 =*/ nb02, /*.nb03 =*/ nb03, @@ -3169,12 +3170,22 @@ static int ggml_metal_encode_node( [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; [encoder setBytes:&args length:sizeof(args) atIndex:8]; + // One shared memory bucket for each simd group in the threadgroup + // NOTE: Metal kernels require the buffer size to be multiple of 16 bytes + // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength + if (d_state >= 32) { + GGML_ASSERT((int64_t)(d_state / 32) <= 32); + const int64_t shmem_size = 32; + GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup); + [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0]; + } + if (ne30 == 1) { // Mamba-2 - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; } else { GGML_ASSERT(d_inner == 1); - [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; } } break; case GGML_OP_RWKV_WKV6: diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f62b9ad548e69..99a453090f6b0 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1823,10 +1823,16 @@ kernel void kernel_ssm_scan_f32( device const void * src5, device const void * src6, device float * dst, + threadgroup float * shared [[threadgroup(0)]], constant ggml_metal_kargs_ssm_scan & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + + const int64_t i0 = tpitg.x; const int64_t i1 = 0; const int64_t ir = tgpig.x; // current head const int64_t i3 = tgpig.y; // current seq @@ -1841,41 +1847,88 @@ kernel void kernel_ssm_scan_f32( 
const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + const int64_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + const int64_t i = i0 + i1*nc; + float s0 = s0_buff[i]; + float s = s_buff[i]; + + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); + device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); + device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); + device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); + device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); + device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh} - device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} + device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? 
log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; - float sumf = 0.0f; - for (int64_t i0 = 0; i0 < nc; ++i0) { - const int64_t i = i0 + i1*nc; - const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); - sumf += state * C[i0]; - s[i] = state; - } + const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); + s = state; + + // Parallel sum: This relies on the fact that this kernel will be + // dispatched with each threadgroup having (d_state, 1, 1) threads which + // are subdivided into SIMD groups of size `sgptg`. The goal is to + // compute y = sum({state * C[i] for i in range(d_state)}). + // To parallelize this effectively, we first use simd_sum over each SIMD + // group to compute the sum of each SIMD group, then place the result in + // the SIMD group's indexed bucket in the shared memory. We then sum + // over the individual group sums to compute the final sum. + + // Computed for each thread + float sumf = state * C[i0]; - y[0] = sumf; + // Sum the threads in the simd group => simd sum + sumf = simd_sum(sumf); + + if (sgptg > 1) { + + // Once per simd group, place the group sum into the shared buffer + if (tiisg == 0) { + shared[sgitg] = sumf; + } + + // Wait for all threads in the threadgroup to reach this point. This + // ensures that all elements of the shared buffer are populated with the + // sum of the individual simd groups. + threadgroup_barrier(mem_flags::mem_threadgroup); + + // For simd group 0 at indices < num simd groups, extract the shared + // simd sum + sumf = 0.0f; + if (sgitg == 0) { + if (tiisg < sgptg) { + sumf = shared[tiisg]; + } + sumf = simd_sum(sumf); + if (tiisg == 0) { + y[0] = sumf; + } + } + } else if (tiisg == 0) { + y[0] = sumf; + } // recurse s0 = s; } + + // Assign the final state to the output buffer + s_buff[i] = s; } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part -// TODO: optimize (e.g. 
by parallelizing over d_state) kernel void kernel_ssm_scan_f32_group( device const void * src0, device const void * src1, @@ -1885,10 +1938,16 @@ kernel void kernel_ssm_scan_f32_group( device const void * src5, device const void * src6, device float * dst, + threadgroup float * shared [[threadgroup(0)]], constant ggml_metal_kargs_ssm_scan & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + + const int64_t i0 = tpitg.x; const int64_t i1 = tgpig.x; const int64_t ir = tgpig.y; // current head const int64_t i3 = tgpig.z; // current seq @@ -1903,38 +1962,81 @@ kernel void kernel_ssm_scan_f32_group( const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + const int64_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + const int64_t i = i0 + i1*nc; + float s0 = s0_buff[i]; + float s = s_buff[i]; + + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} + device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); + device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); + device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); + device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); + device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} - device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) dt_block + 
i2*args.nb21); // {nh, nt, ns} + device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; const float dA = exp(dt_soft_plus * A[0]); - float sumf = 0.0f; - for (int64_t i0 = 0; i0 < nc; ++i0) { - const int64_t i = i0 + i1*nc; - const float state = (s0[i] * dA) + (B[i0] * x_dt); - sumf += state * C[i0]; - s[i] = state; + const float state = (s0 * dA) + (B[i0] * x_dt); + s = state; + + // Parallel sum: This relies on the fact that this kernel will be + // dispatched with each threadgroup having (d_state, 1, 1) threads which + // are subdivided into SIMD groups of size `sgptg`. The goal is to + // compute y = sum({state * C[i] for i in range(d_state)}). + // To parallelize this effectively, we first use simd_sum over each SIMD + // group to compute the sum of each SIMD group, then place the result in + // the SIMD group's indexed bucket in the shared memory. We then sum + // over the individual group sums to compute the final sum. + + // Computed for each thread + float sumf = state * C[i0]; + + // Sum the threads in the simd group => simd sum + sumf = simd_sum(sumf); + + // Once per simd group, place the group sum into the shared buffer + if (tiisg == 0) { + shared[sgitg] = sumf; } - y[0] = sumf; + // Wait for all threads in the threadgroup to reach this point. This + // ensures that all elements of the shared buffer are populated with the + // sum of the individual simd groups. 
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // For simd group 0 at indices < num simd groups, extract the shared
+        // simd sum
+        sumf = 0.0f;
+        if (sgitg == 0) {
+            if (tiisg < sgptg) {
+                sumf = shared[tiisg];
+            }
+            sumf = simd_sum(sumf);
+            if (tiisg == 0) {
+                y[0] = sumf;
+            }
+        }

         // recurse
         s0 = s;
     }
+
+    // Assign the final state to the output buffer
+    s_buff[i] = s;
 }

 kernel void kernel_rwkv_wkv6_f32(

From c7f3169cd523140a288095f2d79befb20a0b73f4 Mon Sep 17 00:00:00 2001
From: Aaron Teo
Date: Sat, 26 Jul 2025 01:09:03 +0800
Subject: [PATCH 9/9] ggml-cpu : disable GGML_NNPA by default due to instability (#14880)

* docs: update s390x document for sentencepiece

Signed-off-by: Aaron Teo
(cherry picked from commit e086c5e3a7ab3463d8e0906efcfa39352db0a48d)

* docs: update huggingface links + reword

Signed-off-by: Aaron Teo
(cherry picked from commit 8410b085ea8c46e22be38266147a1e94757ef108)

* ggml-cpu: disable ggml-nnpa compile flag by default

fixes #14877

Signed-off-by: Aaron Teo
(cherry picked from commit 412f4c7c88894b8f55846b4719c76892a23cfe09)

* docs: update s390x build docs to reflect nnpa disable

Signed-off-by: Aaron Teo
(cherry picked from commit c1eeae1d0c2edc74ab9fbeff2707b0d357cf0b4d)

---------

Signed-off-by: Aaron Teo
---
 docs/build-s390x.md | 46 ++++++++++++++++++++++++++------
 ggml/CMakeLists.txt | 2 +-
 ggml/src/ggml-cpu/CMakeLists.txt | 1 +
 3 files changed, 40 insertions(+), 9 deletions(-)

diff --git a/docs/build-s390x.md b/docs/build-s390x.md
index 4c9ebb271cee2..4d5857753ae68 100644
--- a/docs/build-s390x.md
+++ b/docs/build-s390x.md
@@ -42,14 +42,14 @@ cmake --build build --config Release -j $(nproc)
 cmake --build build --config Release -j $(nproc)
 ```

-- By default, NNPA is enabled when available. To disable it (not recommended):
+- By default, NNPA is disabled. To enable it:

   ```bash
   cmake -S . -B build \
     -DCMAKE_BUILD_TYPE=Release \
     -DGGML_BLAS=ON \
     -DGGML_BLAS_VENDOR=OpenBLAS \
-    -DGGML_NNPA=OFF
+    -DGGML_NNPA=ON

   cmake --build build --config Release -j $(nproc)
   ```

@@ -84,9 +84,9 @@ All models need to be converted to Big-Endian. You can achieve this in three cas

    ![File Type - gguf](https://img.shields.io/badge/File_Type-gguf-fff)

-   You can find popular models pre-converted and verified at [s390x Ready Models](https://huggingface.co/collections/taronaeo/s390x-ready-models-672765393af438d0ccb72a08).
+   You can find popular models pre-converted and verified at [s390x Verified Models](https://huggingface.co/collections/taronaeo/s390x-verified-models-672765393af438d0ccb72a08) or [s390x Runnable Models](https://huggingface.co/collections/taronaeo/s390x-runnable-models-686e951824198df12416017e).

-   These models have already been converted from `safetensors` to `GGUF Big-Endian` and their respective tokenizers verified to run correctly on IBM z15 and later system.
+   These models have already been converted from `safetensors` to `GGUF` Big-Endian and their respective tokenizers verified to run correctly on IBM z15 and later systems.

 2. **Convert safetensors model to GGUF Big-Endian directly (recommended)**

    ![File Type - safetensors](https://img.shields.io/badge/File_Type-safetensors-da1e28)

    The model you are trying to convert must be in `safetensors` file format (for example [IBM Granite 3.3 2B](https://huggingface.co/ibm-granite/granite-3.3-2b-instruct)). Make sure you have downloaded the model repository for this case.
+   Ensure that you have installed the required packages in advance
+
+   ```bash
+   pip3 install -r requirements.txt
+   ```
+
+   Convert the `safetensors` model to `GGUF`
+
   ```bash
   python3 convert_hf_to_gguf.py \
     --outfile model-name-be.f16.gguf \
     --outtype f16 \
     --bigendian \
     model-directory/
   ```
@@ -116,7 +124,7 @@ All models need to be converted to Big-Endian. You can achieve this in three cas

    ![File Type - gguf](https://img.shields.io/badge/File_Type-gguf-fff)

-   The model you are trying to convert must be in `gguf` file format (for example [IBM Granite 3.3 2B](https://huggingface.co/ibm-granite/granite-3.3-2b-instruct-GGUF)). Make sure you have downloaded the model file for this case.
+   The model you are trying to convert must be in `gguf` file format (for example [IBM Granite 3.3 2B GGUF](https://huggingface.co/ibm-granite/granite-3.3-2b-instruct-GGUF)). Make sure you have downloaded the model file for this case.

    ```bash
    python3 gguf-py/gguf/scripts/gguf_convert_endian.py model-name.f16.gguf BIG
@@ -141,15 +149,15 @@ Only available in IBM z15 or later system with the `-DGGML_VXE=ON` (turned on by

 ### 2. NNPA Vector Intrinsics Acceleration

-Only available in IBM z16 or later system with the `-DGGML_NNPA=ON` (turned on when available) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs can still run but will use a scalar implementation.
+Only available in IBM z16 or later systems with the `-DGGML_NNPA=ON` (turned off by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs can still run but will use a scalar implementation.

 ### 3. zDNN Accelerator

-_Only available in IBM z16 or later system. No direction at the moment._
+_Only available in IBM z16 / LinuxONE 4 or later systems. No support currently available._

 ### 4. Spyre Accelerator

-_No direction at the moment._
+_Only available in IBM z17 / LinuxONE 5 or later systems. No support currently available._

 ## Performance Tuning
@@ -189,6 +197,26 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl

    Answer: Please ensure that your GCC compiler is of minimum GCC 15.1.0 version, and have `binutils` updated to the latest version. If this does not fix the problem, kindly open an issue.

+4. Failing to install the `sentencepiece` package using GCC 15+
+
+   Answer: The `sentencepiece` team is aware of this, as seen in [this issue](https://github.com/google/sentencepiece/issues/1108).
+
+   As a temporary workaround, please run the installation command with the following environment variable.
+
+   ```bash
+   export CXXFLAGS="-include cstdint"
+   ```
+
+   For example,
+
+   ```bash
+   CXXFLAGS="-include cstdint" pip3 install -r requirements.txt
+   ```
+
+5. `-DGGML_NNPA=ON` generates gibberish output
+
+   Answer: We are aware of this as detailed in [this issue](https://github.com/ggml-org/llama.cpp/issues/14877). Please either reduce the number of threads or disable the compile option using `-DGGML_NNPA=OFF`.
+
 ## Getting Help on IBM Z & LinuxONE

 1. **Bugs, Feature Requests**
@@ -244,3 +272,5 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
 - ✅ - acceleration available
 - 🚫 - acceleration unavailable, will still run using scalar implementation
 - ❓ - acceleration unknown, please contribute if you can test it yourself
+
+Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on July 25, 2025.
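As a quick sanity check after the conversion steps documented above, the byte order of a GGUF file can be read off its header. The sketch below is not one of the scripts shipped in the repo; it only mirrors the detection gguf-py uses (the magic is always b"GGUF", and a version whose low 16 bits are zero was written byte-swapped):

```python
#!/usr/bin/env python3
# Sketch: report whether a GGUF file is little- or big-endian.
import struct
import sys

def gguf_byte_order(path: str) -> str:
    with open(path, "rb") as f:
        if f.read(4) != b"GGUF":
            raise ValueError(f"{path} is not a GGUF file")
        # read the version field as little-endian; valid versions are small,
        # so a zero low half means the file was written in the other order
        (version,) = struct.unpack("<I", f.read(4))
    return "BIG" if (version & 0xFFFF) == 0 else "LITTLE"

if __name__ == "__main__":
    print(gguf_byte_order(sys.argv[1]))
```

For example, `python3 check_endian.py model-name-be.f16.gguf` should print `BIG` for a file converted with `--bigendian` (the script name here is hypothetical).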
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 8ca1053cab320..20467c54da102 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -131,7 +131,7 @@ option(GGML_RVV "ggml: enable rvv" ON) option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF) option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) option(GGML_VXE "ggml: enable vxe" ON) -option(GGML_NNPA "ggml: enable nnpa" ON) +option(GGML_NNPA "ggml: enable nnpa" OFF) # temp disabled by default, see: https://github.com/ggml-org/llama.cpp/issues/14877 option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 2cc42d4b02af9..f188d1638dc5d 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -458,6 +458,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND ARCH_FLAGS -march=z16) elseif (${S390X_M} MATCHES "9175|9176") # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version. + # binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15. message(STATUS "z17 target") list(APPEND ARCH_FLAGS -march=z17) else()