
Commit 40b9ba1

Merge branch 'ggerganov:master' into master
2 parents: 770e435 + 098f6d7

File tree

14 files changed: +466, -317 lines changed

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -809,9 +809,9 @@ if (LLAMA_CCACHE)
     if (LLAMA_CCACHE_FOUND)
         set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE ccache)
         set(ENV{CCACHE_SLOPPINESS} time_macros)
-        message(STATUS "Using ccache")
+        message(STATUS "ccache found, compilation results will be cached. Disable with LLAMA_CCACHE=OFF.")
     else()
-        message(STATUS "Warning: ccache not found - consider installing it or use LLAMA_CCACHE=OFF")
+        message(STATUS "Warning: ccache not found - consider installing it for faster compilation or disable this warning with LLAMA_CCACHE=OFF")
     endif ()
 endif()

Makefile

Lines changed: 119 additions & 50 deletions
Large diffs are not rendered by default.

README.md

Lines changed: 31 additions & 19 deletions
@@ -6,7 +6,7 @@

 [Roadmap](https://github.com/users/ggerganov/projects/7) / [Project status](https://github.com/ggerganov/llama.cpp/discussions/3471) / [Manifesto](https://github.com/ggerganov/llama.cpp/discussions/205) / [ggml](https://github.com/ggerganov/ggml)

-Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
+Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) in pure C/C++

 ### Hot topics

@@ -58,30 +58,35 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++

 ## Description

-The main goal of `llama.cpp` is to run the LLaMA model using 4-bit integer quantization on a MacBook
+The main goal of `llama.cpp` is to enable LLM inference with minimal setup and state-of-the-art performance on a wide
+variety of hardware - locally and in the cloud.

-- Plain C/C++ implementation without dependencies
-- Apple silicon first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
+- Plain C/C++ implementation without any dependencies
+- Apple silicon is a first-class citizen - optimized via ARM NEON, Accelerate and Metal frameworks
 - AVX, AVX2 and AVX512 support for x86 architectures
-- Mixed F16 / F32 precision
-- 2-bit, 3-bit, 4-bit, 5-bit, 6-bit and 8-bit integer quantization support
-- CUDA, Metal, OpenCL, SYCL GPU backend support
+- 2-bit, 3-bit, 4-bit, 5-bit, 6-bit, and 8-bit integer quantization for faster inference and reduced memory use
+- Custom CUDA kernels for running LLMs on NVIDIA GPUs (support for AMD GPUs via HIP)
+- Vulkan, SYCL, and (partial) OpenCL backend support
+- CPU+GPU hybrid inference to partially accelerate models larger than the total VRAM capacity

-The original implementation of `llama.cpp` was [hacked in an evening](https://github.com/ggerganov/llama.cpp/issues/33#issuecomment-1465108022).
-Since then, the project has improved significantly thanks to many contributions. This project is mainly for educational purposes and serves
-as the main playground for developing new features for the [ggml](https://github.com/ggerganov/ggml) library.
+Since its [inception](https://github.com/ggerganov/llama.cpp/issues/33#issuecomment-1465108022), the project has
+improved significantly thanks to many contributions. It is the main playground for developing new features for the
+[ggml](https://github.com/ggerganov/ggml) library.

 **Supported platforms:**

 - [X] Mac OS
 - [X] Linux
 - [X] Windows (via CMake)
 - [X] Docker
+- [X] FreeBSD

 **Supported models:**

 - [X] LLaMA 🦙
 - [x] LLaMA 2 🦙🦙
+- [X] [Mistral AI v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
+- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
 - [X] Falcon
 - [X] [Alpaca](https://github.com/ggerganov/llama.cpp#instruction-mode-with-alpaca)
 - [X] [GPT4All](https://github.com/ggerganov/llama.cpp#using-gpt4all)
@@ -95,7 +100,6 @@ as the main playground for developing new features for the [ggml](https://github.com/ggerganov/ggml) library
 - [X] [Baichuan 1 & 2](https://huggingface.co/models?search=baichuan-inc/Baichuan) + [derivations](https://huggingface.co/hiyouga/baichuan-7b-sft)
 - [X] [Aquila 1 & 2](https://huggingface.co/models?search=BAAI/Aquila)
 - [X] [Starcoder models](https://github.com/ggerganov/llama.cpp/pull/3187)
-- [X] [Mistral AI v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
 - [X] [Refact](https://huggingface.co/smallcloudai/Refact-1_6B-fim)
 - [X] [Persimmon 8B](https://github.com/ggerganov/llama.cpp/pull/3410)
 - [X] [MPT](https://github.com/ggerganov/llama.cpp/pull/3417)
@@ -104,14 +108,14 @@ as the main playground for developing new features for the [ggml](https://github.com/ggerganov/ggml) library
 - [X] [StableLM-3b-4e1t](https://github.com/ggerganov/llama.cpp/pull/3586)
 - [x] [Deepseek models](https://huggingface.co/models?search=deepseek-ai/deepseek)
 - [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
-- [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
 - [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557)
 - [x] [GPT-2](https://huggingface.co/gpt2)
+- [x] [CodeShell](https://github.com/WisdomShell/codeshell)

 **Multimodal models:**

-- [x] [Llava 1.5 models](https://huggingface.co/collections/liuhaotian/llava-15-653aac15d994e992e2677a7e)
-- [x] [Bakllava](https://huggingface.co/models?search=SkunkworksAI/Bakllava)
+- [x] [LLaVA 1.5 models](https://huggingface.co/collections/liuhaotian/llava-15-653aac15d994e992e2677a7e)
+- [x] [BakLLaVA](https://huggingface.co/models?search=SkunkworksAI/Bakllava)
 - [x] [Obsidian](https://huggingface.co/NousResearch/Obsidian-3B-V0.5)
 - [x] [ShareGPT4V](https://huggingface.co/models?search=Lin-Chen/ShareGPT4V)
 - [x] [MobileVLM 1.7B/3B models](https://huggingface.co/models?search=mobileVLM)
@@ -136,14 +140,22 @@ as the main playground for developing new features for the [ggml](https://github.com/ggerganov/ggml) library

 **UI:**

+Unless otherwise noted these projects are open-source with permissive licensing:
+
+- [iohub/collama](https://github.com/iohub/coLLaMA)
+- [janhq/jan](https://github.com/janhq/jan) (AGPL)
 - [nat/openplayground](https://github.com/nat/openplayground)
-- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui)
-- [withcatai/catai](https://github.com/withcatai/catai)
-- [semperai/amica](https://github.com/semperai/amica)
+- [LMStudio](https://lmstudio.ai/) (proprietary)
+- [LostRuins/koboldcpp](https://github.com/LostRuins/koboldcpp) (AGPL)
+- [Mozilla-Ocho/llamafile](https://github.com/Mozilla-Ocho/llamafile)
+- [nomic-ai/gpt4all](https://github.com/nomic-ai/gpt4all)
+- [ollama/ollama](https://github.com/ollama/ollama)
+- [oobabooga/text-generation-webui](https://github.com/oobabooga/text-generation-webui) (AGPL)
 - [psugihara/FreeChat](https://github.com/psugihara/FreeChat)
 - [ptsochantaris/emeltal](https://github.com/ptsochantaris/emeltal)
-- [iohub/collama](https://github.com/iohub/coLLaMA)
-- [pythops/tenere](https://github.com/pythops/tenere)
+- [pythops/tenere](https://github.com/pythops/tenere) (AGPL)
+- [semperai/amica](https://github.com/semperai/amica)
+- [withcatai/catai](https://github.com/withcatai/catai)

 ---

common/common.cpp

Lines changed: 14 additions & 0 deletions
@@ -399,6 +399,18 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             sparams.penalty_present = std::stof(argv[i]);
+        } else if (arg == "--dynatemp-range") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.dynatemp_range = std::stof(argv[i]);
+        } else if (arg == "--dynatemp-exp") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            sparams.dynatemp_exponent = std::stof(argv[i]);
         } else if (arg == "--mirostat") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -942,6 +954,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --repeat-penalty N    penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
     printf("  --presence-penalty N  repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
     printf("  --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
+    printf("  --dynatemp-range N    dynamic temperature range (default: %.1f, 0.0 = disabled)\n", (double)sparams.dynatemp_range);
+    printf("  --dynatemp-exp N      dynamic temperature exponent (default: %.1f)\n", (double)sparams.dynatemp_exponent);
     printf("  --mirostat N          use Mirostat sampling.\n");
     printf("                        Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
     printf("                        (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);

convert-hf-to-gguf.py

Lines changed: 27 additions & 2 deletions
@@ -1416,8 +1416,32 @@ def set_vocab(self):
         self.gguf_writer.add_add_space_prefix(add_prefix)

         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
+        old_eos = special_vocab.special_token_ids["eos"]
+        if "chat" in os.path.basename(self.dir_model.absolute()):
+            # For the chat model, we replace the eos with '<|im_end|>'.
+            special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer)
+            print(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \
+in chat mode so that the conversation can end normally.")
+
         special_vocab.add_to_gguf(self.gguf_writer)

+    def _try_get_sft_eos(self, tokenizer):
+        unused_145_list = tokenizer.encode('[UNUSED_TOKEN_145]')
+        im_end_list = tokenizer.encode('<|im_end|>')
+        assert (len(unused_145_list) == 1) ^ (len(im_end_list) == 1)
+        if len(unused_145_list) == 1:
+            eos_token = unused_145_list[0]
+        if len(im_end_list) == 1:
+            eos_token = im_end_list[0]
+        return eos_token
+
+    def _hf_permute_qk(self, weights, n_head: int, n_head_kv: int):
+        if n_head_kv is not None and n_head != n_head_kv:
+            n_head = n_head_kv
+        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+                .swapaxes(1, 2)
+                .reshape(weights.shape))
+
     def set_gguf_parameters(self):
         self.gguf_writer.add_name("InternLM2")
         self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
@@ -1486,8 +1510,9 @@ def write_tensors(self):
             qkv = data_torch
             qkv = rearrange(qkv.T, " o (g n i) ->o g n i", g=num_groups, n=q_per_kv + 2, i=head_dim)
             q, k, v = qkv[..., : q_per_kv, :], qkv[..., q_per_kv: q_per_kv + 1, :], qkv[..., q_per_kv + 1: q_per_kv + 2, :]
-            q = rearrange(q, " o g n i -> o (g n i)").T
-            k = rearrange(k, " o g n i -> o (g n i)").T
+            # The model weights of q and k require an additional reshape.
+            q = self._hf_permute_qk(rearrange(q, " o g n i -> o (g n i)").T, num_heads, num_heads)
+            k = self._hf_permute_qk(rearrange(k, " o g n i -> o (g n i)").T, num_heads, num_kv_heads)
             v = rearrange(v, " o g n i -> o (g n i)").T
             self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wq.weight", q)
             self.post_write_tensors(tensor_map, f"model.layers.{bid}.attention.wk.weight", k)
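For reference, the `_hf_permute_qk` helper added above interleaves rows from the first and second half of each attention head: the reshape/swapaxes/reshape sequence maps output row `j` within a head to input row `(j % 2) * head_dim/2 + j / 2`. A standalone C++ sketch of that same index mapping, assuming a flat row-major layout and an illustrative function name:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch of the row permutation performed by _hf_permute_qk for a row-major
// matrix of shape [n_head * head_dim, n_cols]: within each head, rows from
// the two halves are interleaved.
static std::vector<float> permute_qk_rows(const std::vector<float> & w,
                                          size_t n_head, size_t head_dim, size_t n_cols) {
    std::vector<float> out(w.size());
    const size_t half = head_dim / 2;
    for (size_t h = 0; h < n_head; ++h) {
        for (size_t j = 0; j < head_dim; ++j) {
            const size_t src = h * head_dim + (j % 2) * half + j / 2;  // source row
            const size_t dst = h * head_dim + j;                       // destination row
            std::copy(w.begin() + src * n_cols,
                      w.begin() + (src + 1) * n_cols,
                      out.begin() + dst * n_cols);
        }
    }
    return out;
}
```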

examples/server/README.md

Lines changed: 15 additions & 1 deletion
@@ -264,7 +264,21 @@ Notice that each `probs` is an array of length `n_probs`.

 It also accepts all the options of `/completion` except `stream` and `prompt`.

-- **GET** `/props`: Return the required assistant name and anti-prompt to generate the prompt in case you have specified a system prompt for all slots.
+- **GET** `/props`: Return current server settings.
+
+### Result JSON
+
+```json
+{
+    "assistant_name": "",
+    "user_name": "",
+    "default_generation_settings": { ... }
+}
+```
+
+- `assistant_name` - the required assistant name to generate the prompt in case you have specified a system prompt for all slots.
+- `user_name` - the required anti-prompt to generate the prompt in case you have specified a system prompt for all slots.
+- `default_generation_settings` - the default generation settings for the `/completion` endpoint, has the same fields as the `generation_settings` response object from the `/completion` endpoint.

 - **POST** `/v1/chat/completions`: OpenAI-compatible Chat Completions API. Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only ChatML-tuned models, such as Dolphin, OpenOrca, OpenHermes, OpenChat-3.5, etc can be used with this endpoint. Compared to `api_like_OAI.py` this API implementation does not require a wrapper to be served.
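As a usage illustration (not part of this commit), the extended `/props` response can be fetched with the cpp-httplib header the server example already uses; the host and port below are assumptions:

```cpp
#include <iostream>
#include "httplib.h"  // cpp-httplib, bundled with the server example

int main() {
    // Assumes llama.cpp's server is listening on localhost:8080.
    httplib::Client cli("localhost", 8080);
    auto res = cli.Get("/props");
    if (res && res->status == 200) {
        // JSON body with assistant_name, user_name and
        // default_generation_settings, as documented above.
        std::cout << res->body << std::endl;
    } else {
        std::cerr << "GET /props failed" << std::endl;
    }
    return 0;
}
```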

examples/server/server.cpp

Lines changed: 6 additions & 1 deletion
@@ -334,6 +334,7 @@ struct llama_server_context

     // slots / clients
     std::vector<llama_client_slot> slots;
+    json default_generation_settings_for_props;

     llama_server_queue queue_tasks;
     llama_server_response queue_results;
@@ -430,6 +431,9 @@ struct llama_server_context
             slots.push_back(slot);
         }

+        default_generation_settings_for_props = get_formated_generation(slots.front());
+        default_generation_settings_for_props["seed"] = -1;
+
         batch = llama_batch_init(n_ctx, 0, params.n_parallel);

         // empty system prompt
@@ -2614,7 +2618,8 @@ int main(int argc, char **argv)
             res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
             json data = {
                 { "user_name", llama.name_user.c_str() },
-                { "assistant_name", llama.name_assistant.c_str() }
+                { "assistant_name", llama.name_assistant.c_str() },
+                { "default_generation_settings", llama.default_generation_settings_for_props }
             };
             res.set_content(data.dump(), "application/json; charset=utf-8");
         });

ggml-impl.h

Lines changed: 2 additions & 0 deletions
@@ -19,13 +19,15 @@ extern "C" {
 // fall back to the _Static_assert C11 keyword.
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
+#ifndef __cplusplus
 #ifndef static_assert
 #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
 #define static_assert(cond, msg) _Static_assert(cond, msg)
 #else
 #define static_assert(cond, msg) struct global_scope_noop_trick
 #endif
 #endif
+#endif

 // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
 #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
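The extra guard matters because `static_assert` is a keyword in C++11 and later, so the macro fallback should only ever be defined for C translation units. A small illustration (not from the commit) of the two paths a source file including the header can take:

```cpp
#include <stdint.h>

#ifdef __cplusplus
// C++: the built-in keyword is available directly, so the guarded fallback
// in ggml-impl.h is now skipped entirely.
static_assert(sizeof(uint32_t) == 4, "uint32_t must be 4 bytes");
#else
// C11: _Static_assert provides the same compile-time check (this is what
// the guarded macro in ggml-impl.h expands to).
_Static_assert(sizeof(uint32_t) == 4, "uint32_t must be 4 bytes");
#endif
```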

ggml-quants.c

Lines changed: 6 additions & 52 deletions
@@ -9048,8 +9048,6 @@ static void quantize_row_iq2_xxs_impl(const float * restrict x, void * restrict
     int8_t L[32];
     int8_t Laux[32];
     float  waux[32];
-    bool   is_on_grid[4];
-    bool   is_on_grid_aux[4];
     uint8_t block_signs[4];
     uint32_t q2[2*(QK_K/32)];

@@ -9099,10 +9097,11 @@
                 memset(L, 0, 32);
                 continue;
             }
+            float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight);
+            float eff_max = scale*kMaxQ;
             float best = 0;
-            float scale = max/(2*kMaxQ-1);
-            for (int is = -9; is <= 9; ++is) {
-                float id = (2*kMaxQ-1+is*0.1f)/max;
+            for (int is = -6; is <= 6; ++is) {
+                float id = (2*kMaxQ-1+is*0.1f)/eff_max;
                 float this_scale = 1/id;
                 for (int k = 0; k < 4; ++k) {
                     for (int i = 0; i < 8; ++i) {
@@ -9112,9 +9111,7 @@
                     uint16_t u = 0;
                     for (int i = 0; i < 8; ++i) u |= (Laux[8*k+i] << 2*i);
                     int grid_index = kmap_q2xs[u];
-                    is_on_grid_aux[k] = true;
                     if (grid_index < 0) {
-                        is_on_grid_aux[k] = false;
                         const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
                         grid_index = iq2_find_best_neighbour(neighbours, kgrid_q2xs, xval + 8*k, waux + 8*k, this_scale, Laux + 8*k);
                     }
@@ -9128,16 +9125,12 @@
                 }
                 if (sumq2 > 0 && sumqx*sumqx > best*sumq2) {
                     scale = sumqx/sumq2; best = scale*sumqx;
-                    for (int i = 0; i < 32; ++i) L[i] = Laux[i];
-                    for (int k = 0; k < 4; ++k) is_on_grid[k] = is_on_grid_aux[k];
+                    memcpy(L, Laux, 32);
                 }
             }
-            int n_not_ongrid = 0;
-            for (int k = 0; k < 4; ++k) if (!is_on_grid[k]) ++n_not_ongrid;
-            if (n_not_ongrid > 0 && scale > 0) {
+            if (scale > 0) {
                 float id = 1/scale;
                 for (int k = 0; k < 4; ++k) {
-                    if (is_on_grid[k]) continue;
                     uint16_t u = 0;
                     for (int i = 0; i < 8; ++i) {
                         int l = nearest_int(0.5f*(id*xval[8*k+i]-1));
@@ -9193,49 +9186,10 @@
         float d = max_scale/31;
         y[ibl].d = GGML_FP32_TO_FP16(d);
         float id = 1/d;
-        float sumqx = 0, sumq2 = 0;
         for (int ib = 0; ib < QK_K/32; ++ib) {
             int l = nearest_int(0.5f*(id*scales[ib]-1));
             l = MAX(0, MIN(15, l));
             q2[2*ib+1] |= ((uint32_t)l << 28);
-            const float * xb = xbl + 32*ib;
-            const float * qw = quant_weights + QK_K*ibl + 32*ib;
-            for (int i = 0; i < 32; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
-            const uint8_t * aux8 = (const uint8_t *)(q2 + 2*ib);
-            const float db = d * (1 + 2*l);
-            uint32_t u = 0;
-            for (int k = 0; k < 4; ++k) {
-                const int8_t * signs = keven_signs_q2xs + 8*((q2[2*ib+1] >> 7*k) & 127);
-                const float * xk = xb + 8*k;
-                const float * wk = weight + 8*k;
-                const uint8_t * grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
-                float best_mse = 0; int best_index = aux8[k];
-                for (int j = 0; j < 8; ++j) {
-                    float diff = db * grid[j] * signs[j] - xk[j];
-                    best_mse += wk[j] * diff * diff;
-                }
-                for (int idx = 0; idx < 256; ++idx) {
-                    grid = (const uint8_t *)(kgrid_q2xs + idx);
-                    float mse = 0;
-                    for (int j = 0; j < 8; ++j) {
-                        float diff = db * grid[j] * signs[j] - xk[j];
-                        mse += wk[j] * diff * diff;
-                    }
-                    if (mse < best_mse) {
-                        best_mse = mse; best_index = idx;
-                    }
-                }
-                u |= (best_index << 8*k);
-                grid = (const uint8_t *)(kgrid_q2xs + best_index);
-                //grid = (const uint8_t *)(kgrid_q2xs + aux8[k]);
-                for (int j = 0; j < 8; ++j) {
-                    float q = db * grid[j] * signs[j];
-                    sumqx += wk[j] * q * xk[j];
-                    sumq2 += wk[j] * q * q;
-                }
-            }
-            q2[2*ib] = u;
-            if (sumq2 > 0) y[ibl].d = GGML_FP32_TO_FP16(d*sumqx/sumq2);
         }
         memcpy(y[ibl].qs, q2, QK_K/4);
     }
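The deleted refinement pass re-fitted the block scale with the same `sumqx/sumq2` expression that still appears earlier in this function. For reference, that expression is the closed-form weighted least-squares solution: given target values `x`, integer quants `q` and importance weights `w`, the scale `d` minimizing `sum_i w_i*(x_i - d*q_i)^2` is `sum_i w_i*q_i*x_i / sum_i w_i*q_i*q_i`. A standalone sketch (the helper name is illustrative):

```cpp
#include <cstddef>
#include <cstdint>

// d = (sum_i w_i*q_i*x_i) / (sum_i w_i*q_i*q_i) minimizes the weighted
// squared error sum_i w_i*(x_i - d*q_i)^2, the same sumqx/sumq2 form used
// in the quantization loop above.
static float fit_block_scale(const float * x, const int8_t * q,
                             const float * w, size_t n) {
    float sumqx = 0.0f;
    float sumq2 = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        sumqx += w[i] * q[i] * x[i];
        sumq2 += w[i] * q[i] * q[i];
    }
    return sumq2 > 0.0f ? sumqx / sumq2 : 0.0f;
}
```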
