46 changes: 38 additions & 8 deletions docs/build-s390x.md
@@ -42,14 +42,14 @@ cmake --build build --config Release -j $(nproc)
cmake --build build --config Release -j $(nproc)
```

- - By default, NNPA is enabled when available. To disable it (not recommended):
+ - By default, NNPA is disabled. To enable it:

```bash
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_BLAS=ON \
-DGGML_BLAS_VENDOR=OpenBLAS \
- -DGGML_NNPA=OFF
+ -DGGML_NNPA=ON

cmake --build build --config Release -j $(nproc)
```
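As a quick smoke test after either build, you can run the resulting binary (a sketch assuming the default CMake layout, where binaries land in `build/bin`; the binary name may differ in older checkouts):

```bash
# print version/build info to confirm the binary was built correctly
./build/bin/llama-cli --version
```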
@@ -84,16 +84,24 @@ All models need to be converted to Big-Endian. You can achieve this in three cas

![File Type - gguf](https://img.shields.io/badge/File_Type-gguf-fff)

- You can find popular models pre-converted and verified at [s390x Ready Models](https://huggingface.co/collections/taronaeo/s390x-ready-models-672765393af438d0ccb72a08).
+ You can find popular models pre-converted and verified at [s390x Verified Models](https://huggingface.co/collections/taronaeo/s390x-verified-models-672765393af438d0ccb72a08) or [s390x Runnable Models](https://huggingface.co/collections/taronaeo/s390x-runnable-models-686e951824198df12416017e).

- These models have already been converted from `safetensors` to `GGUF Big-Endian` and their respective tokenizers verified to run correctly on IBM z15 and later system.
+ These models have already been converted from `safetensors` to `GGUF` Big-Endian and their respective tokenizers verified to run correctly on IBM z15 and later systems.
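A minimal sketch of pulling one of these pre-converted files with the Hugging Face CLI (the repository and file names below are placeholders, not real artifacts; check the collections above for the exact names):

```bash
# download a single Big-Endian GGUF file into the current directory
huggingface-cli download taronaeo/granite-3.3-2b-instruct-be-GGUF \
  granite-3.3-2b-instruct-be.F16.gguf --local-dir .
```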

2. **Convert safetensors model to GGUF Big-Endian directly (recommended)**

![File Type - safetensors](https://img.shields.io/badge/File_Type-safetensors-da1e28)

The model you are trying to convert must be in `safetensors` file format (for example [IBM Granite 3.3 2B](https://huggingface.co/ibm-granite/granite-3.3-2b-instruct)). Make sure you have downloaded the model repository for this case.

+ Ensure that you have installed the required packages in advance
+
+ ```bash
+ pip3 install -r requirements.txt
+ ```
+
+ Convert the `safetensors` model to `GGUF`

```bash
python3 convert_hf_to_gguf.py \
--outfile model-name-be.f16.gguf \
@@ -116,7 +124,7 @@ All models need to be converted to Big-Endian. You can achieve this in three cas

![File Type - gguf](https://img.shields.io/badge/File_Type-gguf-fff)

- The model you are trying to convert must be in `gguf` file format (for example [IBM Granite 3.3 2B](https://huggingface.co/ibm-granite/granite-3.3-2b-instruct-GGUF)). Make sure you have downloaded the model file for this case.
+ The model you are trying to convert must be in `gguf` file format (for example [IBM Granite 3.3 2B GGUF](https://huggingface.co/ibm-granite/granite-3.3-2b-instruct-GGUF)). Make sure you have downloaded the model file for this case.

```bash
python3 gguf-py/gguf/scripts/gguf_convert_endian.py model-name.f16.gguf BIG
@@ -141,15 +149,15 @@ Only available in IBM z15 or later system with the `-DGGML_VXE=ON` (turned on by

### 2. NNPA Vector Intrinsics Acceleration

- Only available in IBM z16 or later system with the `-DGGML_NNPA=ON` (turned on when available) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs can still run but will use a scalar implementation.
+ Only available on IBM z16 or later systems with the `-DGGML_NNPA=ON` compile flag (turned off by default). No hardware acceleration is possible with llama.cpp on older systems such as IBM z15/arch13; on those systems, the APIs can still run but will use a scalar implementation.

### 3. zDNN Accelerator

- _Only available in IBM z16 or later system. No direction at the moment._
+ _Only available on IBM z16 / LinuxONE 4 or later systems. No support is currently available._

### 4. Spyre Accelerator

- _No direction at the moment._
+ _Only available on IBM z17 / LinuxONE 5 or later systems. No support is currently available._

## Performance Tuning

@@ -189,6 +197,26 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl

Answer: Please ensure that your GCC compiler is at least version 15.1.0 and that `binutils` is updated to the latest version. If this does not fix the problem, kindly open an issue.
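To check both versions quickly (a simple sketch; `ld` ships with `binutils`):

```bash
# GCC must report 15.1.0 or newer for the z17 target
gcc --version
# ld is part of binutils; it should report a recent release
ld --version
```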

+ 4. Failing to install the `sentencepiece` package using GCC 15+
+
+ Answer: The `sentencepiece` team is aware of this, as seen in [this issue](https://github.com/google/sentencepiece/issues/1108).
+
+ As a temporary workaround, please run the installation command with the following environment variable.
+
+ ```bash
+ export CXXFLAGS="-include cstdint"
+ ```
+
+ For example,
+
+ ```bash
+ CXXFLAGS="-include cstdint" pip3 install -r requirements.txt
+ ```
+
+ 5. `-DGGML_NNPA=ON` generates gibberish output
+
+ Answer: We are aware of this as detailed in [this issue](https://github.com/ggml-org/llama.cpp/issues/14877). Please either reduce the number of threads or disable the compile option using `-DGGML_NNPA=OFF`.
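As a sketch, both workarounds look like this (assuming the build layout from the instructions above, with BLAS flags omitted for brevity; `-t` sets the thread count):

```bash
# workaround 1: run inference with fewer threads
./build/bin/llama-cli -m model-name-be.f16.gguf -t 2 -p "Hello"

# workaround 2: rebuild with NNPA disabled
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NNPA=OFF
cmake --build build --config Release -j $(nproc)
```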

## Getting Help on IBM Z & LinuxONE

1. **Bugs, Feature Requests**
@@ -244,3 +272,5 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
- ✅ - acceleration available
- 🚫 - acceleration unavailable, will still run using scalar implementation
- ❓ - acceleration unknown, please contribute if you can test it yourself

+ Last Updated by **Aaron Teo ([email protected])** on July 25, 2025.
21 changes: 15 additions & 6 deletions docs/development/HOWTO-add-model.md
@@ -23,11 +23,19 @@ The convert script reads the model configuration, tokenizer, tensor names+data a

The required steps to implement for an HF model are:

- 1. Define the model `Model.register` annotation in a new `Model` subclass, example:
+ 1. Define the model `ModelBase.register` annotation in a new `TextModel` or `MmprojModel` subclass, for example:

```python
- @Model.register("MyModelForCausalLM")
- class MyModel(Model):
+ @ModelBase.register("MyModelForCausalLM")
+ class MyModel(TextModel):
    model_arch = gguf.MODEL_ARCH.MYMODEL
```

+ or
+
+ ```python
+ @ModelBase.register("MyModelForConditionalGeneration")
+ class MyModel(MmprojModel):
+     model_arch = gguf.MODEL_ARCH.MYMODEL
+ ```

@@ -75,9 +83,10 @@ block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
`transformer.blocks.{bid}.norm_1` will be mapped to `blk.{bid}.attn_norm` in GGUF.

Depending on the model configuration, tokenizer, code and tensors layout, you will have to override:
- - `Model#set_gguf_parameters`
- - `Model#set_vocab`
- - `Model#write_tensors`
+ - `TextModel#set_gguf_parameters`
+ - `MmprojModel#set_gguf_parameters`
+ - `ModelBase#set_vocab`
+ - `ModelBase#modify_tensors`
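As a minimal sketch of what such an override can look like (the class name is the doc's running example; the hyperparameter key is illustrative and depends on the model's `config.json`):

```python
@ModelBase.register("MyModelForCausalLM")
class MyModel(TextModel):
    model_arch = gguf.MODEL_ARCH.MYMODEL

    def set_gguf_parameters(self):
        # let the base class write the common hyperparameters first
        super().set_gguf_parameters()
        # then add values it does not cover; the key name below is
        # illustrative and must match the model's config.json
        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
```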

NOTE: Tensor names must end with a `.weight` or `.bias` suffix; that is the convention, and several tools like `quantize` rely on it when processing the weights.

2 changes: 1 addition & 1 deletion ggml/CMakeLists.txt
@@ -131,7 +131,7 @@ option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF)
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
option(GGML_VXE "ggml: enable vxe" ON)
- option(GGML_NNPA "ggml: enable nnpa" ON)
+ option(GGML_NNPA "ggml: enable nnpa" OFF) # temp disabled by default, see: https://github.com/ggml-org/llama.cpp/issues/14877

option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")
13 changes: 8 additions & 5 deletions ggml/src/ggml-backend.cpp
@@ -647,6 +647,7 @@ struct ggml_backend_sched {
// pipeline parallelism support
int n_copies;
int cur_copy;
+ int next_copy;
ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES];
struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS];
int n_graph_inputs;
@@ -1433,8 +1434,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
}
}

- sched->cur_copy = (sched->cur_copy + 1) % sched->n_copies;
-
return GGML_STATUS_SUCCESS;
}

@@ -1535,10 +1534,10 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);

- ggml_backend_sched_split_graph(sched, measure_graph);
-
ggml_backend_sched_synchronize(sched);

+ ggml_backend_sched_split_graph(sched, measure_graph);
+
if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) {
return false;
}
@@ -1550,6 +1549,10 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *

bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
+ GGML_ASSERT(!sched->is_alloc);
+
+ sched->cur_copy = sched->next_copy;
+ sched->next_copy = (sched->next_copy + 1) % sched->n_copies;

ggml_backend_sched_split_graph(sched, graph);

@@ -1590,7 +1593,7 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
// if the graph is not already allocated, always use copy 0 after a synchronization
// this ensures that during generation the same copy is used every time,
// which avoids changes in the graph that could cause CUDA or other graphs to be disabled
- sched->cur_copy = 0;
+ sched->next_copy = 0;
}
}

1 change: 1 addition & 0 deletions ggml/src/ggml-cpu/CMakeLists.txt
@@ -458,6 +458,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
list(APPEND ARCH_FLAGS -march=z16)
elseif (${S390X_M} MATCHES "9175|9176")
# NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version.
+ # binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15.
message(STATUS "z17 target")
list(APPEND ARCH_FLAGS -march=z17)
else()
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
@@ -528,6 +528,7 @@ typedef struct {
int64_t n_group;
int64_t n_seq_tokens;
int64_t n_seqs;
+ int64_t s_off;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
15 changes: 13 additions & 2 deletions ggml/src/ggml-metal/ggml-metal.m
@@ -3141,6 +3141,7 @@ static int ggml_metal_encode_node(
/*.n_group =*/ n_group,
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
+ /*.s_off =*/ ggml_nelements(src1) * sizeof(float),
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
@@ -3169,12 +3170,22 @@ static int ggml_metal_encode_node(
[encoder setBuffer:id_dst offset:offs_dst atIndex:7];
[encoder setBytes:&args length:sizeof(args) atIndex:8];

+ // One shared memory bucket for each simd group in the threadgroup
+ // NOTE: Metal kernels require the buffer size to be a multiple of 16 bytes
+ // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+ if (d_state >= 32) {
+     GGML_ASSERT((int64_t)(d_state / 32) <= 32);
+     const int64_t shmem_size = 32;
+     GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
+     [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
+ }

if (ne30 == 1) {
// Mamba-2
- [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
} else {
GGML_ASSERT(d_inner == 1);
- [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)];
}
} break;
case GGML_OP_RWKV_WKV6:
Expand Down