Skip to content

Commit e62d5be

Browse files
authored
Merge branch 'ggerganov:master' into master
2 parents c2deb89 + b0cefea commit e62d5be

File tree

3 files changed

+34
-23
lines changed

3 files changed

+34
-23
lines changed

common/common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ struct common_params {
178178
float yarn_beta_fast = 32.0f; // YaRN low correction dim
179179
float yarn_beta_slow = 1.0f; // YaRN high correction dim
180180
int32_t yarn_orig_ctx = 0; // YaRN original context length
181-
float defrag_thold = -1.0f; // KV cache defragmentation threshold
181+
float defrag_thold = 0.1f; // KV cache defragmentation threshold
182182

183183
struct cpu_params cpuparams;
184184
struct cpu_params cpuparams_batch;

examples/server/README.md

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ The project is under active development, and we are [looking for feedback and co
3939
| `--cpu-strict-batch <0\|1>` | use strict CPU placement (default: same as --cpu-strict) |
4040
| `--prio-batch N` | set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: 0)<br/> |
4141
| `--poll-batch <0\|1>` | use polling to wait for work (default: same as --poll) |
42-
| `-c, --ctx-size N` | size of the prompt context (default: 0, 0 = loaded from model)<br/>(env: LLAMA_ARG_CTX_SIZE) |
42+
| `-c, --ctx-size N` | size of the prompt context (default: 4096, 0 = loaded from model)<br/>(env: LLAMA_ARG_CTX_SIZE) |
4343
| `-n, --predict, --n-predict N` | number of tokens to predict (default: -1, -1 = infinity, -2 = until context filled)<br/>(env: LLAMA_ARG_N_PREDICT) |
4444
| `-b, --batch-size N` | logical maximum batch size (default: 2048)<br/>(env: LLAMA_ARG_BATCH) |
4545
| `-ub, --ubatch-size N` | physical maximum batch size (default: 512)<br/>(env: LLAMA_ARG_UBATCH) |
@@ -64,7 +64,7 @@ The project is under active development, and we are [looking for feedback and co
6464
| `-nkvo, --no-kv-offload` | disable KV offload<br/>(env: LLAMA_ARG_NO_KV_OFFLOAD) |
6565
| `-ctk, --cache-type-k TYPE` | KV cache data type for K (default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K) |
6666
| `-ctv, --cache-type-v TYPE` | KV cache data type for V (default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
67-
| `-dt, --defrag-thold N` | KV cache defragmentation threshold (default: -1.0, < 0 - disabled)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
67+
| `-dt, --defrag-thold N` | KV cache defragmentation threshold (default: 0.1, < 0 - disabled)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
6868
| `-np, --parallel N` | number of parallel sequences to decode (default: 1)<br/>(env: LLAMA_ARG_N_PARALLEL) |
6969
| `--mlock` | force system to keep model in RAM rather than swapping or compressing<br/>(env: LLAMA_ARG_MLOCK) |
7070
| `--no-mmap` | do not memory-map model (slower load but may reduce pageouts if not using mlock)<br/>(env: LLAMA_ARG_NO_MMAP) |
@@ -99,25 +99,27 @@ The project is under active development, and we are [looking for feedback and co
9999

100100
| Argument | Explanation |
101101
| -------- | ----------- |
102-
| `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'<br/>(default: top_k;typ_p;top_p;min_p;temperature) |
102+
| `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'<br/>(default: dry;top_k;typ_p;top_p;min_p;xtc;temperature) |
103103
| `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) |
104-
| `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: kfypmt) |
104+
| `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: dkypmxt) |
105105
| `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) |
106106
| `--penalize-nl` | penalize newline tokens (default: false) |
107107
| `--temp N` | temperature (default: 0.8) |
108108
| `--top-k N` | top-k sampling (default: 40, 0 = disabled) |
109109
| `--top-p N` | top-p sampling (default: 0.9, 1.0 = disabled) |
110110
| `--min-p N` | min-p sampling (default: 0.1, 0.0 = disabled) |
111+
| `--xtc-probability N` | xtc probability (default: 0.0, 0.0 = disabled) |
112+
| `--xtc-threshold N` | xtc threshold (default: 0.1, 1.0 = disabled) |
111113
| `--typical N` | locally typical sampling, parameter p (default: 1.0, 1.0 = disabled) |
112114
| `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) |
113115
| `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.0, 1.0 = disabled) |
114116
| `--presence-penalty N` | repeat alpha presence penalty (default: 0.0, 0.0 = disabled) |
115117
| `--frequency-penalty N` | repeat alpha frequency penalty (default: 0.0, 0.0 = disabled) |
116-
| `--dry-multiplier N` | DRY sampling multiplier (default: 0.0, 0.0 = disabled) |
117-
| `--dry-base N` | DRY sampling base value (default: 1.75) |
118-
| `--dry-allowed-length N` | allowed length for DRY sampling (default: 2) |
119-
| `--dry-penalty-last-n N` | DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size) |
120-
| `--dry-sequence-breaker STRING` | add sequence breaker for DRY sampling, clearing out default breakers (`['\n', ':', '"', '*']`) in the process; use `"none"` to not use any sequence breakers
118+
| `--dry-multiplier N` | set DRY sampling multiplier (default: 0.0, 0.0 = disabled) |
119+
| `--dry-base N` | set DRY sampling base value (default: 1.75) |
120+
| `--dry-allowed-length N` | set allowed length for DRY sampling (default: 2) |
121+
| `--dry-penalty-last-n N` | set DRY penalty for the last n tokens (default: -1, 0 = disable, -1 = context size) |
122+
| `--dry-sequence-breaker STRING` | add sequence breaker for DRY sampling, clearing out default breakers ('\n', ':', '"', '*') in the process; use "none" to not use any sequence breakers<br/> |
121123
| `--dynatemp-range N` | dynamic temperature range (default: 0.0, 0.0 = disabled) |
122124
| `--dynatemp-exp N` | dynamic temperature exponent (default: 1.0) |
123125
| `--mirostat N` | use Mirostat sampling.<br/>Top K, Nucleus and Locally Typical samplers are ignored if used.<br/>(default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) |

ggml/src/ggml-metal.metal

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2942,6 +2942,7 @@ kernel void kernel_flash_attn_ext(
29422942
half smax = -INFINITY;
29432943

29442944
// load the mask in shared memory
2945+
#pragma unroll(Q)
29452946
for (short j = 0; j < Q; ++j) {
29462947
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
29472948

@@ -2968,7 +2969,7 @@ kernel void kernel_flash_attn_ext(
29682969
// we can read directly from global memory
29692970
device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
29702971

2971-
#pragma unroll
2972+
#pragma unroll(D8)
29722973
for (short i = 0; i < D8; ++i) {
29732974
k8x8_t mk;
29742975
simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
@@ -2989,7 +2990,7 @@ kernel void kernel_flash_attn_ext(
29892990

29902991
simdgroup_barrier(mem_flags::mem_threadgroup);
29912992

2992-
#pragma unroll
2993+
#pragma unroll(4)
29932994
for (short k = 0; k < 4; ++k) {
29942995
k8x8_t mk;
29952996

@@ -3067,7 +3068,7 @@ kernel void kernel_flash_attn_ext(
30673068
s8x8_t mm;
30683069
simdgroup_load(mm, ss + 2*C, TS, 0, false);
30693070

3070-
#pragma unroll
3071+
#pragma unroll(D8)
30713072
for (short i = 0; i < D8; ++i) {
30723073
simdgroup_multiply(lo[i], mm, lo[i]);
30733074
}
@@ -3082,7 +3083,8 @@ kernel void kernel_flash_attn_ext(
30823083
if (is_same<vd4x4_t, v4x4_t>::value) {
30833084
// we can read directly from global memory
30843085
device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3085-
#pragma unroll
3086+
3087+
#pragma unroll(D8)
30863088
for (short i = 0; i < D8; ++i) {
30873089
v8x8_t mv;
30883090
simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
@@ -3103,7 +3105,7 @@ kernel void kernel_flash_attn_ext(
31033105

31043106
simdgroup_barrier(mem_flags::mem_threadgroup);
31053107

3106-
#pragma unroll
3108+
#pragma unroll(4)
31073109
for (short k = 0; k < 4; ++k) {
31083110
v8x8_t mv;
31093111

@@ -3196,6 +3198,7 @@ kernel void kernel_flash_attn_ext(
31963198
simdgroup_load(ms0, ss + 2*C, TS, 0, false);
31973199
simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false);
31983200

3201+
#pragma unroll(D8)
31993202
for (short i = 0; i < D8; ++i) {
32003203
o8x8_t t;
32013204

@@ -3413,6 +3416,7 @@ kernel void kernel_flash_attn_ext_vec(
34133416
// load the queries from shared memory into local memory
34143417
q4x4_t mq[D16/NL];
34153418

3419+
#pragma unroll(D16/NL)
34163420
for (short ii = 0; ii < D16; ii += NL) {
34173421
mq[ii/NL] = sq4x4[ii + tx];
34183422
}
@@ -3454,17 +3458,23 @@ kernel void kernel_flash_attn_ext_vec(
34543458

34553459
device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
34563460

3457-
#pragma unroll
3461+
#pragma unroll(D16/NL)
34583462
for (short ii = 0; ii < D16; ii += NL) {
34593463
const short i = ii + tx;
34603464

34613465
k4x4_t mk;
34623466
deq_k(pk + i/nl_k, i%nl_k, mk);
34633467

3464-
mqka[0] += dot(mq[ii/NL][0], mk[0]);
3465-
mqka[1] += dot(mq[ii/NL][1], mk[1]);
3466-
mqka[2] += dot(mq[ii/NL][2], mk[2]);
3467-
mqka[3] += dot(mq[ii/NL][3], mk[3]);
3468+
// note: this is less precise than the version below
3469+
//mqka[0] += dot(mq[ii/NL][0], mk[0]);
3470+
//mqka[1] += dot(mq[ii/NL][1], mk[1]);
3471+
//mqka[2] += dot(mq[ii/NL][2], mk[2]);
3472+
//mqka[3] += dot(mq[ii/NL][3], mk[3]);
3473+
3474+
mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]);
3475+
mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]);
3476+
mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]);
3477+
mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]);
34683478
}
34693479

34703480
qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3];
@@ -3513,7 +3523,7 @@ kernel void kernel_flash_attn_ext_vec(
35133523
ss[tiisg] = vs;
35143524

35153525
// O = diag(ms)*O
3516-
#pragma unroll
3526+
#pragma unroll(D16/NL)
35173527
for (short ii = 0; ii < D16; ii += NL) {
35183528
lo[ii/NL] *= ms;
35193529
}
@@ -3523,13 +3533,12 @@ kernel void kernel_flash_attn_ext_vec(
35233533

35243534
// O = O + (Q*K^T)*V
35253535
{
3526-
#pragma unroll
35273536
for (short cc = 0; cc < C/4; ++cc) {
35283537
device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
35293538

35303539
const s4x4_t ms(ss[4*cc + ty]);
35313540

3532-
#pragma unroll
3541+
#pragma unroll(D16/NL)
35333542
for (short ii = 0; ii < D16; ii += NL) {
35343543
const short i = ii + tx;
35353544

0 commit comments

Comments (0)