8 changes: 8 additions & 0 deletions examples/embedding/embedding.cpp
@@ -81,6 +81,14 @@ int main(int argc, char ** argv) {
 
     params.embedding = true;
 
+    // if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
+    // --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache
+    // in order to support any number of prompts
+    if (params.n_parallel == 1) {
+        LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
+        params.kv_unified = true;
+    }
+
     // utilize the full context
     if (params.n_batch < params.n_ctx) {
         LOG_WRN("%s: setting batch size to %d\n", __func__, params.n_ctx);
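Editor's note (not part of the patch): the fallback above boils down to two llama_context_params fields once the context is created. A minimal sketch, assuming the n_seq_max and kv_unified fields from current llama.h; the helper name is hypothetical:

#include "llama.h"

// hypothetical helper mirroring the embedding example's fallback
static llama_context_params embd_cparams(int n_parallel) {
    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings = true;
    cparams.n_seq_max  = n_parallel;  // one KV cache stream per parallel prompt
    if (n_parallel == 1) {
        // unified KV cache: all sequences share one buffer, so any number of
        // prompts can be encoded without sizing n_seq_max in advance
        cparams.kv_unified = true;
    }
    return cparams;
}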
6 changes: 6 additions & 0 deletions examples/save-load-state/save-load-state.cpp
@@ -15,6 +15,12 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (params.n_parallel == 1) {
+        // the example uses 2 sequences, so when n_parallel == 1, we need to enable unified kv cache
+        printf("%s: n_parallel == 1, enabling unified kv cache\n", __func__);
+        params.kv_unified = true;
+    }
+
     common_init();
 
     if (params.n_predict < 0) {
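Editor's note: the second sequence matters because the example round-trips per-sequence state. A minimal sketch of that step, assuming the llama_state_seq_* signatures from current llama.h (the exact calls in the example may differ):

#include <vector>
#include "llama.h"

// copy the cached state of sequence 0 into sequence 1; this is the step that
// needs room for two sequences, hence the unified-cache fallback above
static bool copy_seq0_to_seq1(llama_context * ctx) {
    std::vector<uint8_t> buf(llama_state_seq_get_size(ctx, 0));
    const size_t n_read = llama_state_seq_get_data(ctx, buf.data(), buf.size(), 0);
    return llama_state_seq_set_data(ctx, buf.data(), n_read, 1) != 0;
}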
2 changes: 1 addition & 1 deletion src/llama-batch.cpp
@@ -59,7 +59,7 @@ bool llama_batch_allocr::init(
     for (int32_t i = 0; i < batch.n_tokens; ++i) {
         for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
             if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
-                LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
+                LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d >= %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
                 return false;
             }
         }
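Editor's note: the one-character fix makes the log match the `>=` comparison actually performed. To make it concrete (an illustration, not part of the patch), a token tagged with a seq_id equal to n_seq_max is rejected, and the message now reads `= 2 >= 2` instead of the misleading `= 2 > 2`. A sketch using llama_batch_init from llama.h and common_batch_add from common.h:

#include "common.h"
#include "llama.h"

static void trigger_seq_id_check(llama_context * ctx) {
    // assume the context was created with n_seq_max == 2: valid seq_ids are 0 and 1
    llama_batch batch = llama_batch_init(/*n_tokens*/ 8, /*embd*/ 0, /*n_seq_max*/ 2);
    common_batch_add(batch, /*token*/ 0, /*pos*/ 0, /*seq_ids*/ { 2 }, /*logits*/ false);
    // llama_decode() fails here and logs: invalid seq_id[0][0] = 2 >= 2
    const int ret = llama_decode(ctx, batch);
    (void) ret;
    llama_batch_free(batch);
}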
2 changes: 1 addition & 1 deletion tests/CMakeLists.txt
@@ -185,7 +185,7 @@ llama_build_and_test(test-json-partial.cpp)
 llama_build_and_test(test-log.cpp)
 llama_build_and_test(test-regex-partial.cpp)
 
-llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4)
+llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4 -t 2)
 
 # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
 if (NOT WIN32)
3 changes: 3 additions & 0 deletions tests/test-thread-safety.cpp
@@ -34,6 +34,9 @@ int main(int argc, char ** argv) {
 
     auto cparams = common_context_params_to_llama(params);
 
+    // each context has a single sequence
+    cparams.n_seq_max = 1;
+
     int dev_count = ggml_backend_dev_count();
     int gpu_dev_count = 0;
     for (int i = 0; i < dev_count; ++i) {
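Editor's note (not from the patch): the test drives several contexts over one shared model from separate threads, each context pinned to a single sequence. A rough sketch of that pattern, assuming llama_init_from_model from current llama.h:

#include <thread>
#include <vector>
#include "llama.h"

static void run_parallel_contexts(llama_model * model, llama_context_params cparams, int n_ctx) {
    cparams.n_seq_max = 1;  // each context drives exactly one sequence
    std::vector<std::thread> workers;
    for (int i = 0; i < n_ctx; ++i) {
        workers.emplace_back([model, cparams]() {
            llama_context * ctx = llama_init_from_model(model, cparams);
            // ... decode only on this thread's context; contexts are not shared ...
            llama_free(ctx);
        });
    }
    for (auto & w : workers) {
        w.join();
    }
}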