Commit 633e743

Merge branch 'ggml-org:master' into mradermacher

2 parents: 06e77d5 + 2bc9693

11 files changed: +75 -52 lines

.github/workflows/build-linux-cross.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -291,6 +291,7 @@ jobs:
           -DGGML_RVV=ON \
           -DGGML_RV_ZFH=ON \
           -DGGML_RV_ZICBOP=ON \
+          -DGGML_RV_ZIHINTPAUSE=ON \
           -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \
           -DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake
```

README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -61,7 +61,7 @@ range of hardware - locally and in the cloud.
 - Plain C/C++ implementation without any dependencies
 - Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
 - AVX, AVX2, AVX512 and AMX support for x86 architectures
-- RVV, ZVFH, ZFH and ZICBOP support for RISC-V architectures
+- RVV, ZVFH, ZFH, ZICBOP and ZIHINTPAUSE support for RISC-V architectures
 - 1.5-bit, 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
 - Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP and Moore Threads GPUs via MUSA)
 - Vulkan and SYCL backend support
```

docs/build-riscv64-spacemit.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -19,6 +19,7 @@ cmake -B build \
     -DGGML_RVV=ON \
     -DGGML_RV_ZFH=ON \
     -DGGML_RV_ZICBOP=ON \
+    -DGGML_RV_ZIHINTPAUSE=ON \
     -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \
     -DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake \
     -DCMAKE_INSTALL_PREFIX=build/installed
```

ggml/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -168,6 +168,7 @@ option(GGML_RVV "ggml: enable rvv" ON)
 option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
 option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
 option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
+option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause" ON)
 option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
 option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE})
```

ggml/src/ggml-cpu/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
```diff
@@ -469,6 +469,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
             if (GGML_RV_ZICBOP)
                 string(APPEND MARCH_STR "_zicbop")
             endif()
+            if (GGML_RV_ZIHINTPAUSE)
+                string(APPEND MARCH_STR "_zihintpause")
+            endif()
             list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
         else()
             # Begin with the lowest baseline
```

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 9 additions & 0 deletions
```diff
@@ -490,6 +490,15 @@ static inline void ggml_thread_cpu_relax(void) {
 static inline void ggml_thread_cpu_relax(void) {
     _mm_pause();
 }
+#elif defined(__riscv)
+static inline void ggml_thread_cpu_relax(void) {
+#ifdef __riscv_zihintpause
+    __asm__ __volatile__ ("pause");
+#else
+    /* Encoding of the pause instruction */
+    __asm__ __volatile__ (".4byte 0x100000F");
+#endif
+}
 #else
 static inline void ggml_thread_cpu_relax(void) {;}
 #endif
```
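
On RISC-V, `ggml_thread_cpu_relax()` now issues a real spin-wait hint: the `pause` mnemonic when the toolchain defines `__riscv_zihintpause`, and the raw word `0x100000F` otherwise, which is the PAUSE encoding and executes as a harmless no-op hint on cores without the Zihintpause extension. A minimal standalone sketch of how such a relax hint is used in a polling loop (illustrative C++; `cpu_relax` and `wait_for` are hypothetical names, not ggml's scheduler):

```cpp
#include <atomic>

// Same selection logic as the patch: prefer the `pause` mnemonic when the
// compiler advertises Zihintpause, otherwise emit the raw PAUSE encoding,
// which older cores execute as a plain hint (no-op).
static inline void cpu_relax(void) {
#if defined(__riscv)
#ifdef __riscv_zihintpause
    __asm__ __volatile__ ("pause");
#else
    __asm__ __volatile__ (".4byte 0x100000F"); // encoding of `pause`
#endif
#elif defined(__x86_64__) || defined(__i386__)
    __builtin_ia32_pause();
#endif
}

// Hypothetical call site: spin until another thread publishes `ready`.
static void wait_for(const std::atomic<bool> & ready) {
    while (!ready.load(std::memory_order_acquire)) {
        cpu_relax(); // reduce power and pipeline contention while polling
    }
}
```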

ggml/src/ggml-cuda/solve_tri.cu

Lines changed: 28 additions & 36 deletions
```diff
@@ -3,7 +3,6 @@
 #include "solve_tri.cuh"
 
 #define MAX_N_FAST 64
-#define MAX_K_FAST 32
 
 // ======================
 // Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
@@ -48,65 +47,58 @@ static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
     float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
 
     __shared__ float sA[MAX_N_FAST * MAX_N_FAST];
-    __shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
 
     const int offset = threadIdx.x + threadIdx.y * blockDim.x;
 
 #pragma unroll
     for (int i = 0; i < n * n; i += k * WARP_SIZE) {
-        int i0 = i + offset;
+        const int i0 = i + offset;
         if (i0 < n * n) {
             sA[i0] = A_batch[i0];
         }
     }
 
-    const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
+    __syncthreads();
 
-#pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
-        }
-    }
+    float x_low  = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f;
+    float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f;
 
-    __syncthreads();
+    const int half      = WARP_SIZE;
+    const int nrows_low = (n < half) ? n : half;
 
 #pragma unroll
-    for (int row = 0; row < n; ++row) {
+    for (int row = 0; row < nrows_low; ++row) {
         float sum = 0.0f;
-
-        {
-            int j = lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
+        if (lane < row) {
+            sum += sA[row * n + lane] * x_low;
         }
-        if (row >= WARP_SIZE) {
-            int j = WARP_SIZE + lane;
-            if (j < row) {
-                sum += sA[row * n + j] * sXt[col_idx * n + j];
-            }
+        sum = warp_reduce_sum(sum);
+
+        if (lane == row) {
+            x_low = (x_low - sum) / sA[row * n + row];
         }
+    }
 
+#pragma unroll
+    for (int row = half; row < n; ++row) {
+        float sum = sA[row * n + lane] * x_low;
+        const int j = half + lane;
+        if (j < row) {
+            sum += sA[row * n + j] * x_high;
+        }
         sum = warp_reduce_sum(sum);
 
-        if (lane == 0) {
-            const float b_val = sXt[col_idx * n + row];
-            const float a_diag = sA[row * n + row];
-            // no safeguards for division by zero because that indicates corrupt
-            // data anyway
-            sXt[col_idx * n + row] = (b_val - sum) / a_diag;
+        if (lane == row - half) {
+            x_high = (x_high - sum) / sA[row * n + row];
         }
     }
 
-    __syncthreads();
-
 #pragma unroll
-    for (int i = 0; i < rows_per_warp; i++) {
-        const int i0 = lane + i * WARP_SIZE;
-        if (i0 < n) {
-            X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
+    for (int rr = 0; rr < 2; ++rr) {
+        const int row = rr * WARP_SIZE + lane;
+        if (row < n) {
+            const float val = (row < half) ? x_low : x_high;
+            X_batch[row * k + col_idx] = val;
         }
     }
 }
```
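
The rewrite drops the per-column `sXt` shared-memory tile (and `MAX_K_FAST`, which only sized it): each warp now keeps its column of X in two registers per lane, `x_low` for rows [0, 32) and `x_high` for rows [32, 64), and the forward substitution is split into two loops so each step reduces only over already-solved rows. For orientation, a plain host-side version of the recurrence the warp computes (illustrative C++, not the kernel):

```cpp
#include <vector>

// Forward substitution for one column `col` of A * X = B, where A is an
// n x n lower-triangular matrix (row-major) and B, X are n x k (row-major).
// Like the kernel, there is no safeguard against a zero diagonal, since
// that indicates corrupt data anyway.
static void solve_tri_column_ref(const float * A, const float * B, float * X,
                                 int n, int k, int col) {
    std::vector<float> x(n);
    for (int row = 0; row < n; ++row) {
        float sum = 0.0f;
        for (int j = 0; j < row; ++j) {
            sum += A[row * n + j] * x[j]; // dot with already-solved entries
        }
        x[row] = (B[row * k + col] - sum) / A[row * n + row];
    }
    for (int row = 0; row < n; ++row) {
        X[row * k + col] = x[row]; // write the solved column back
    }
}
```

In the kernel, the inner loop over `j` becomes a `warp_reduce_sum` across the 32 lanes, and the `lane == row` (or `lane == row - half`) check selects the single lane whose register holds the entry being solved.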

tools/server/README.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -495,6 +495,8 @@ By default, this value is set to `0`, meaning no tokens are kept. Use `-1` to re
 
 `n_cmpl`: Number of completions to generate from the current prompt. If input has multiple prompts, the output will have N prompts times `n_cmpl` entries.
 
+`n_cache_reuse`: Min chunk size to attempt reusing from the cache via KV shifting. For more info, see the `--cache-reuse` arg. Default: `0` (disabled).
+
 `stream`: Allows receiving each predicted token in real-time instead of waiting for the completion to finish (uses a different response format). To enable this, set to `true`.
 
 `stop`: Specify a JSON array of stopping strings.
```

tools/server/server-context.cpp

Lines changed: 13 additions & 4 deletions
```diff
@@ -1880,8 +1880,18 @@ struct server_context_impl {
             n_past = std::min(n_past, slot.alora_invocation_start - 1);
         }
 
+        const auto n_cache_reuse = slot.task->params.n_cache_reuse;
+
+        const bool can_cache_reuse =
+            llama_memory_can_shift(llama_get_memory(ctx)) &&
+            !slot.prompt.tokens.has_mtmd;
+
+        if (!can_cache_reuse && n_cache_reuse > 0) {
+            SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse);
+        }
+
         // reuse chunks from the cached prompt by shifting their KV cache in the new position
-        if (params_base.n_cache_reuse > 0) {
+        if (can_cache_reuse && n_cache_reuse > 0) {
             GGML_ASSERT(!slot.prompt.tokens.has_mtmd);
 
             size_t head_c = n_past; // cache
@@ -1892,7 +1902,7 @@ struct server_context_impl {
                 GGML_ABORT("not supported by multimodal");
             }
 
-            SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", params_base.n_cache_reuse, n_past);
+            SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past);
 
             while (head_c < slot.prompt.tokens.size() &&
                    head_p < input_tokens.size()) {
@@ -1901,11 +1911,10 @@ struct server_context_impl {
                 while (head_c + n_match < slot.prompt.tokens.size() &&
                        head_p + n_match < input_tokens.size() &&
                        slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) {
-
                     n_match++;
                 }
 
-                if (n_match >= (size_t) params_base.n_cache_reuse) {
+                if (n_match >= (size_t) n_cache_reuse) {
                     SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
                     //for (size_t i = head_p; i < head_p + n_match; i++) {
                     //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
```
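
Promoting `n_cache_reuse` from a server-global setting to a per-request parameter leaves the matching loop itself unchanged: it still walks the cached prompt and the new prompt in parallel and shifts any sufficiently long run of identical tokens instead of re-evaluating it. A simplified standalone sketch of that scan (hypothetical helper; the actual KV shift and the advance on a failed match are elided or guessed here, as this hunk does not show them):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using llama_token = int32_t;

// Count tokens that could be reused via KV shifting: accept only runs of
// identical tokens of length >= n_cache_reuse. Assumes n_cache_reuse >= 1,
// as the server only enters this path when it is positive.
static size_t count_reusable(const std::vector<llama_token> & cached,
                             const std::vector<llama_token> & prompt,
                             size_t head_c, size_t head_p, size_t n_cache_reuse) {
    size_t reused = 0;
    while (head_c < cached.size() && head_p < prompt.size()) {
        size_t n_match = 0;
        while (head_c + n_match < cached.size() &&
               head_p + n_match < prompt.size() &&
               cached[head_c + n_match] == prompt[head_p + n_match]) {
            n_match++;
        }
        if (n_match >= n_cache_reuse) {
            // the server shifts KV cells [head_c, head_c + n_match) to
            // [head_p, head_p + n_match) at this point
            reused += n_match;
            head_c += n_match;
            head_p += n_match;
        } else {
            head_c += 1; // assumed: skip one cached token and retry
        }
    }
    return reused;
}
```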

tools/server/server-task.cpp

Lines changed: 7 additions & 5 deletions
```diff
@@ -155,11 +155,12 @@ task_params server_task::params_from_json_cmpl(
 
     // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
     task_params defaults;
-    defaults.sampling    = params_base.sampling;
-    defaults.speculative = params_base.speculative;
-    defaults.n_keep      = params_base.n_keep;
-    defaults.n_predict   = params_base.n_predict;
-    defaults.antiprompt  = params_base.antiprompt;
+    defaults.sampling      = params_base.sampling;
+    defaults.speculative   = params_base.speculative;
+    defaults.n_keep        = params_base.n_keep;
+    defaults.n_predict     = params_base.n_predict;
+    defaults.n_cache_reuse = params_base.n_cache_reuse;
+    defaults.antiprompt    = params_base.antiprompt;
 
     // enabling this will output extra debug information in the HTTP responses from the server
     params.verbose = params_base.verbosity > 9;
@@ -176,6 +177,7 @@ task_params server_task::params_from_json_cmpl(
     params.n_keep    = json_value(data, "n_keep", defaults.n_keep);
     params.n_discard = json_value(data, "n_discard", defaults.n_discard);
     params.n_cmpl    = json_value(data, "n_cmpl", json_value(data, "n", 1));
+    params.n_cache_reuse = json_value(data, "n_cache_reuse", defaults.n_cache_reuse);
     //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
     params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
     params.response_fields  = json_value(data, "response_fields", std::vector<std::string>());
```
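
The `json_value(data, key, default)` helper is what makes each field optional per request. A minimal re-implementation of the pattern (illustrative only; the server's actual helper also guards against type mismatches):

```cpp
#include <string>
#include <nlohmann/json.hpp>

// Return data[key] when the request supplies a usable value, otherwise the
// server-wide default seeded from params_base.
template <typename T>
static T json_value(const nlohmann::json & data, const std::string & key, const T & default_value) {
    if (data.contains(key) && !data.at(key).is_null()) {
        return data.at(key).get<T>();
    }
    return default_value;
}
```

With `defaults.n_cache_reuse` seeded from `params_base.n_cache_reuse`, a request that omits `n_cache_reuse` behaves exactly as before, while one that sets it can tune or disable cache reuse per call.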
