Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1348,9 +1348,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
));
add_opt(common_arg(
{"--prio"}, "N",
string_format("set process/thread priority : 0-normal, 1-medium, 2-high, 3-realtime (default: %d)\n", params.cpuparams.priority),
string_format("set process/thread priority : low(-1), normal(0), medium(1), high(2), realtime(3) (default: %d)\n", params.cpuparams.priority),
[](common_params & params, int prio) {
if (prio < 0 || prio > 3) {
if (prio < GGML_SCHED_PRIO_LOW || prio > GGML_SCHED_PRIO_REALTIME) {
throw std::invalid_argument("invalid value");
}
params.cpuparams.priority = (enum ggml_sched_priority) prio;
Expand Down
7 changes: 4 additions & 3 deletions common/chat-parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,10 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
if (!rest.empty()) {
handle_reasoning(rest, /* closed */ !is_partial());
}
if (!syntax_.thinking_forced_open) {
throw common_chat_msg_partial_exception(end_think);
}
// Allow unclosed thinking tags, for now (https://github.com/ggml-org/llama.cpp/issues/13812, https://github.com/ggml-org/llama.cpp/issues/13877)
// if (!syntax_.thinking_forced_open) {
// throw common_chat_msg_partial_exception(end_think);
// }
return true;
}
}
Expand Down
2 changes: 2 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {

DWORD p = NORMAL_PRIORITY_CLASS;
switch (prio) {
case GGML_SCHED_PRIO_LOW: p = BELOW_NORMAL_PRIORITY_CLASS; break;
case GGML_SCHED_PRIO_NORMAL: p = NORMAL_PRIORITY_CLASS; break;
case GGML_SCHED_PRIO_MEDIUM: p = ABOVE_NORMAL_PRIORITY_CLASS; break;
case GGML_SCHED_PRIO_HIGH: p = HIGH_PRIORITY_CLASS; break;
Expand All @@ -228,6 +229,7 @@ bool set_process_priority(enum ggml_sched_priority prio) {

int p = 0;
switch (prio) {
case GGML_SCHED_PRIO_LOW: p = 5; break;
case GGML_SCHED_PRIO_NORMAL: p = 0; break;
case GGML_SCHED_PRIO_MEDIUM: p = -5; break;
case GGML_SCHED_PRIO_HIGH: p = -10; break;
Expand Down
1 change: 1 addition & 0 deletions docs/build.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ cmake --build build --config Release
cmake --preset x64-windows-llvm-release
cmake --build build-x64-windows-llvm-release
```
- Curl usage is enabled by default and can be turned off with `-DLLAMA_CURL=OFF`. Otherwise you need to install development libraries for libcurl.
## BLAS Build
Expand Down
15 changes: 11 additions & 4 deletions examples/parallel/parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,15 +362,17 @@ int main(int argc, char ** argv) {
// process in chunks of params.n_batch
int32_t n_batch = params.n_batch;

for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
int32_t i_next = 0;

for (int32_t i = 0; i < batch.n_tokens; i = i_next) {
// experiment: process in powers of 2
//if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
// n_batch /= 2;
// i -= n_batch;
// continue;
//}

const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

llama_batch batch_view = {
n_tokens,
Expand All @@ -390,19 +392,24 @@ int main(int argc, char ** argv) {
return 1;
}

LOG_ERR("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
LOG_WRN("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);

n_cache_miss += 1;

// retry with half the batch size to try to find a free slot in the KV cache
n_batch /= 2;
i -= n_batch;

continue;
}

LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens);

// move the head of the batch forward with the number of tokens we just processed
i_next = i + n_tokens;

// on successful decode, restore the original batch size
n_batch = params.n_batch;

for (auto & client : clients) {
if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
continue;
Expand Down
9 changes: 2 additions & 7 deletions examples/passkey/passkey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,8 @@ int main(int argc, char ** argv) {
const int ib = i/n_batch - 1;
const int bd = n_batch_grp*(n_grp - 1);

llama_kv_self_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
llama_kv_self_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
llama_kv_self_update (ctx);
llama_kv_self_seq_add(ctx, 0, n_past - n_batch, n_past, ib*bd);
llama_kv_self_seq_div(ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);

n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
}
Expand Down Expand Up @@ -169,8 +168,6 @@ int main(int argc, char ** argv) {

llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_self_defrag (ctx);
llama_kv_self_update (ctx);

n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;

Expand Down Expand Up @@ -200,8 +197,6 @@ int main(int argc, char ** argv) {

llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
//llama_kv_self_defrag (ctx);
llama_kv_self_update (ctx);

n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
}
Expand Down
1 change: 1 addition & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -2181,6 +2181,7 @@ extern "C" {

// scheduling priorities
enum ggml_sched_priority {
GGML_SCHED_PRIO_LOW = -1,
GGML_SCHED_PRIO_NORMAL,
GGML_SCHED_PRIO_MEDIUM,
GGML_SCHED_PRIO_HIGH,
Expand Down
23 changes: 23 additions & 0 deletions ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -2418,12 +2418,32 @@ static bool ggml_thread_apply_priority(int32_t prio) {
// This is up to the applications.
DWORD p = THREAD_PRIORITY_NORMAL;
switch (prio) {
case GGML_SCHED_PRIO_LOW: p = THREAD_PRIORITY_BELOW_NORMAL; break;
case GGML_SCHED_PRIO_NORMAL: p = THREAD_PRIORITY_NORMAL; break;
case GGML_SCHED_PRIO_MEDIUM: p = THREAD_PRIORITY_ABOVE_NORMAL; break;
case GGML_SCHED_PRIO_HIGH: p = THREAD_PRIORITY_HIGHEST; break;
case GGML_SCHED_PRIO_REALTIME: p = THREAD_PRIORITY_TIME_CRITICAL; break;
}

if (prio != GGML_SCHED_PRIO_LOW) {
// Tell Windows that this thread should not be throttled (needs its own CPU core).
// Newer Windows 11 versions aggresively park (offline) CPU cores and often place
// all our threads onto the first 4 cores which results in terrible performance with
// n_threads > 4
#if _WIN32_WINNT >= 0x0602
THREAD_POWER_THROTTLING_STATE t;
ZeroMemory(&t, sizeof(t));
t.Version = THREAD_POWER_THROTTLING_CURRENT_VERSION;
t.ControlMask = THREAD_POWER_THROTTLING_EXECUTION_SPEED;
t.StateMask = 0;

if (!SetThreadInformation(GetCurrentThread(), ThreadPowerThrottling, &t, sizeof(t))) {
GGML_LOG_DEBUG("failed to disable thread power throttling %d : (%d)\n", prio, (int) GetLastError());
return false;
}
#endif
}

if (prio == GGML_SCHED_PRIO_NORMAL) {
// Keep inherited policy/priority
return true;
Expand Down Expand Up @@ -2451,6 +2471,8 @@ static bool ggml_thread_apply_priority(int32_t prio) {
struct sched_param p;
int32_t policy = SCHED_OTHER;
switch (prio) {
// TODO: there seems to be no way to set lower prio on Apple platforms
case GGML_SCHED_PRIO_LOW: policy = SCHED_OTHER; p.sched_priority = 0; break;
case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break;
case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break;
case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break;
Expand Down Expand Up @@ -2507,6 +2529,7 @@ static bool ggml_thread_apply_priority(int32_t prio) {
struct sched_param p;
int32_t policy = SCHED_OTHER;
switch (prio) {
case GGML_SCHED_PRIO_LOW: policy = SCHED_BATCH; p.sched_priority = 0; break;
case GGML_SCHED_PRIO_NORMAL: policy = SCHED_OTHER; p.sched_priority = 0; break;
case GGML_SCHED_PRIO_MEDIUM: policy = SCHED_FIFO; p.sched_priority = 40; break;
case GGML_SCHED_PRIO_HIGH: policy = SCHED_FIFO; p.sched_priority = 80; break;
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-cuda/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,7 @@ struct ggml_cuda_device_info {
int nsm; // number of streaming multiprocessors
size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
bool integrated; // Device is integrated as opposed to discrete
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
size_t total_vram;
Expand Down
20 changes: 14 additions & 6 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,10 @@ static ggml_cuda_device_info ggml_cuda_init() {

info.default_tensor_split[id] = total_vram;
total_vram += prop.totalGlobalMem;

info.devices[id].nsm = prop.multiProcessorCount;
info.devices[id].smpb = prop.sharedMemPerBlock;
info.devices[id].warp_size = prop.warpSize;
info.devices[id].integrated = prop.integrated;
info.devices[id].nsm = prop.multiProcessorCount;
info.devices[id].smpb = prop.sharedMemPerBlock;
info.devices[id].warp_size = prop.warpSize;
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
info.devices[id].smpbo = prop.sharedMemPerBlock;

Expand Down Expand Up @@ -1065,6 +1065,10 @@ static const char * ggml_backend_cuda_host_buffer_type_name(ggml_backend_buffer_
GGML_UNUSED(buft);
}

static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_cuda_host_buffer_type_name;
}

static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
CUDA_CHECK(cudaFreeHost(buffer->context));
}
Expand Down Expand Up @@ -2641,6 +2645,8 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {

static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
// flag used to determine whether it is an integrated_gpu
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;

while (!graph_evaluated_or_captured) {
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
Expand All @@ -2659,7 +2665,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
if (node->src[j] != nullptr) {
assert(node->src[j]->buffer);
assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft));
ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
}
}
#endif
Expand Down Expand Up @@ -3266,7 +3272,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
}

static bool ggml_backend_cuda_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
return (ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev;
ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context;
const bool integrated = ggml_cuda_info().devices[dev_ctx->device].integrated;
return (((ggml_backend_buft_is_cuda(buft) || ggml_backend_buft_is_cuda_split(buft)) && buft->device == dev) || (integrated && ggml_backend_buft_is_cuda_host(buft)));
}

static int64_t get_op_batch_size(const ggml_tensor * op) {
Expand Down
20 changes: 12 additions & 8 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,9 +259,9 @@ extern "C" {
llama_token * token;
float * embd;
llama_pos * pos;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * logits; // TODO: rename this to "output"
int32_t * n_seq_id; // TODO: remove, should belong to only 1 sequence
llama_seq_id ** seq_id; // TODO: become llama_seq_id * seq_id;
int8_t * logits; // TODO: rename this to "output"
} llama_batch;

enum llama_model_kv_override_type {
Expand Down Expand Up @@ -366,6 +366,8 @@ extern "C" {
bool no_perf; // measure performance timings
bool op_offload; // offload host tensor operations to device
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
// NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
};

// model quantization parameters
Expand Down Expand Up @@ -502,6 +504,7 @@ extern "C" {
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);

// Get the model's RoPE frequency scaling factor
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
Expand Down Expand Up @@ -652,7 +655,6 @@ extern "C" {
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode()
// - explicitly with llama_kv_self_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_self_seq_add(
Expand All @@ -665,7 +667,6 @@ extern "C" {
// Integer division of the positions by factor of `d > 1`
// If the KV cache is RoPEd, the KV data is updated accordingly:
// - lazily on next llama_decode()
// - explicitly with llama_kv_self_update()
// p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf)
LLAMA_API void llama_kv_self_seq_div(
Expand All @@ -677,12 +678,14 @@ extern "C" {

// Returns the smallest position present in the KV cache for the specified sequence
// This is typically non-zero only for SWA caches
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
// Return -1 if the sequence is empty
LLAMA_API llama_pos llama_kv_self_seq_pos_min(
struct llama_context * ctx,
llama_seq_id seq_id);

// Returns the largest position present in the KV cache for the specified sequence
// Note that all positions in the range [pos_min, pos_max] are guaranteed to be present in the KV cache
// Return -1 if the sequence is empty
LLAMA_API llama_pos llama_kv_self_seq_pos_max(
struct llama_context * ctx,
Expand All @@ -691,14 +694,15 @@ extern "C" {
// Defragment the KV cache
// This will be applied:
// - lazily on next llama_decode()
// - explicitly with llama_kv_self_update()
LLAMA_API void llama_kv_self_defrag(struct llama_context * ctx);
LLAMA_API DEPRECATED(void llama_kv_self_defrag(struct llama_context * ctx),
"simply remove this call, the context will automatically decide when to do a defragmentation based on 'defrag_thold'");

// Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_self_can_shift(const struct llama_context * ctx);

// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
LLAMA_API void llama_kv_self_update(struct llama_context * ctx);
LLAMA_API DEPRECATED(void llama_kv_self_update(struct llama_context * ctx),
"simply remove this call, updates are applied lazily on the next llama_decode()");

//
// State / sessions
Expand Down
31 changes: 19 additions & 12 deletions src/llama-batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,31 @@ llama_ubatch llama_sbatch::reserve_ubatch(size_t n_ubatch, bool has_embd) {
break;
}
}
ubatch_token.resize(!has_embd ? n_ubatch : 0);
ubatch_embd.resize(has_embd ? n_embd * n_ubatch : 0);
ubatch_pos.resize(n_ubatch);
ubatch_n_seq_id.resize(n_ubatch);
ubatch_seq_id.resize(n_ubatch);
ubatch_output.resize(n_ubatch);

udatas.push_back({});

auto & udata = udatas.back();

udata.token.resize(!has_embd ? n_ubatch : 0);
udata.embd.resize(has_embd ? n_embd * n_ubatch : 0);
udata.pos.resize(n_ubatch);
udata.n_seq_id.resize(n_ubatch);
udata.seq_id.resize(n_ubatch);
udata.output.resize(n_ubatch);

llama_ubatch ubatch = {
/*equal_seqs =*/ true,
/*n_tokens =*/ 0,
/*n_seq_tokens =*/ 0,
/*n_seqs =*/ 0,
/*token =*/ !has_embd ? ubatch_token.data() : nullptr,
/*embd =*/ has_embd ? ubatch_embd.data() : nullptr,
/*pos =*/ ubatch_pos.data(),
/*n_seq_id =*/ ubatch_n_seq_id.data(),
/*seq_id =*/ ubatch_seq_id.data(),
/*output =*/ ubatch_output.data(),
/*token =*/ !has_embd ? udata.token.data() : nullptr,
/*embd =*/ has_embd ? udata.embd.data() : nullptr,
/*pos =*/ udata.pos.data(),
/*n_seq_id =*/ udata.n_seq_id.data(),
/*seq_id =*/ udata.seq_id.data(),
/*output =*/ udata.output.data(),
};

return ubatch;
}

Expand Down
Loading
Loading