diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 91719577564a9..95727ed16fb44 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -1,5 +1,20 @@
 llama_add_compile_flags()
 
+# ThreadSanitizer configuration for race condition detection
+option(LLAMA_SANITIZE_THREAD "Enable ThreadSanitizer for race condition detection" OFF)
+
+if (LLAMA_SANITIZE_THREAD)
+    if (CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
+        add_compile_options(-fsanitize=thread -g -O1)
+        add_link_options(-fsanitize=thread)
+        message(STATUS "ThreadSanitizer enabled for concurrent testing")
+
+        set(ENV{TSAN_OPTIONS} "halt_on_error=1:second_deadlock_stack=1")
+    else()
+        message(WARNING "ThreadSanitizer is only supported with GCC or Clang")
+    endif()
+endif()
+
 function(llama_build source)
     if (DEFINED LLAMA_TEST_NAME)
         set(TEST_TARGET ${LLAMA_TEST_NAME})
@@ -187,6 +202,10 @@ llama_build_and_test(test-regex-partial.cpp)
 
 llama_build_and_test(test-thread-safety.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 128 -c 256 -ub 32 -np 4 -t 2)
 
+llama_build_and_test(test-concurrent-stress.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 32 -c 512 -np 4 -t 2 LABEL "stress")
+
+llama_build_and_test(test-kv-cache-concurrent.cpp ARGS -hf ggml-org/models -hff tinyllamas/stories15M-q4_0.gguf -ngl 99 -p "The meaning of life is" -n 32 -c 1024 -np 4 -t 2 LABEL "stress")
+
 # this fails on windows (github hosted runner) due to curl DLL not found (exit code 0xc0000135)
 if (NOT WIN32)
     llama_build_and_test(test-arg-parser.cpp)
diff --git a/tests/test-concurrent-stress.cpp b/tests/test-concurrent-stress.cpp
new file mode 100644
index 0000000000000..0899b68cfd67b
--- /dev/null
+++ b/tests/test-concurrent-stress.cpp
@@ -0,0 +1,412 @@
+
+#include <atomic>
+#include <chrono>
+#include <memory>
+#include <random>
+#include <thread>
+#include <vector>
+#include "llama.h"
+#include "arg.h"
+#include "common.h"
+#include "log.h"
+#include "sampling.h"
+
+// counters shared by all worker threads
+static std::atomic<int> g_contexts_created{0};
+static std::atomic<int> g_contexts_destroyed{0};
+static std::atomic<int> g_decode_operations{0};
+static std::atomic<int> g_errors{0};
+
+struct stress_test_result {
+    int contexts_created = 0;
+    int contexts_destroyed = 0;
+    int decode_operations = 0;
+    int errors = 0;
+    double duration_seconds = 0.0;
+};
+
+// repeatedly create a context, decode the prompt once, then tear the context down
+static void rapid_context_lifecycle_test(
+    llama_model * model,
+    const llama_context_params & cparams,
+    const common_params & params,
+    int iterations) {
+
+    for (int i = 0; i < iterations; ++i) {
+        llama_context * ctx = llama_init_from_model(model, cparams);
+        if (ctx == NULL) {
+            LOG_ERR("failed to create context in rapid lifecycle test\n");
+            g_errors++;
+            continue;
+        }
+        g_contexts_created++;
+
+        std::unique_ptr<common_sampler, decltype(&common_sampler_free)> sampler {
+            common_sampler_init(model, params.sampling), common_sampler_free
+        };
+        if (sampler == NULL) {
+            LOG_ERR("failed to create sampler in rapid lifecycle test\n");
+            g_errors++;
+            llama_free(ctx);
+            continue;
+        }
+
+        auto prompt = common_tokenize(ctx, params.prompt, true);
+        if (!prompt.empty()) {
+            llama_batch batch = llama_batch_get_one(prompt.data(), prompt.size());
+            if (llama_decode(ctx, batch) == 0) {
+                g_decode_operations++;
+            } else {
+                g_errors++;
+            }
+        }
+
+        llama_free(ctx);
+        g_contexts_destroyed++;
+
+        std::this_thread::sleep_for(std::chrono::milliseconds(1));
+    }
+}
+
+static void sustained_inference_test(
+    llama_model * model,
+    const llama_context_params & cparams,
+    const common_params & params,
+    int num_iterations) {
+
+    llama_context * ctx =
llama_init_from_model(model, cparams); + if (ctx == NULL) { + LOG_ERR("failed to create context in sustained inference test\n"); + g_errors++; + return; + } + g_contexts_created++; + + std::unique_ptr sampler { + common_sampler_init(model, params.sampling), common_sampler_free + }; + if (sampler == NULL) { + LOG_ERR("failed to create sampler in sustained inference test\n"); + g_errors++; + llama_free(ctx); + return; + } + + const auto * vocab = llama_model_get_vocab(model); + + for (int iter = 0; iter < num_iterations; ++iter) { + auto prompt = common_tokenize(ctx, params.prompt, true); + if (prompt.empty()) { + g_errors++; + continue; + } + + llama_batch batch = llama_batch_get_one(prompt.data(), prompt.size()); + if (llama_decode(ctx, batch)) { + g_errors++; + continue; + } + g_decode_operations++; + + for (int i = 0; i < 10; i++) { + llama_token token; + if (batch.n_tokens > 0) { + token = common_sampler_sample(sampler.get(), ctx, batch.n_tokens - 1); + } else { + token = llama_vocab_bos(vocab); + } + + if (llama_vocab_is_eog(vocab, token)) { + break; + } + + batch = llama_batch_get_one(&token, 1); + if (llama_decode(ctx, batch)) { + g_errors++; + break; + } + g_decode_operations++; + } + + llama_memory_clear(llama_get_memory(ctx), false); + } + + llama_free(ctx); + g_contexts_destroyed++; +} + +static void concurrent_sequence_test( + llama_model * model, + const llama_context_params & cparams, + const common_params & params, + int num_sequences) { + + llama_context * ctx = llama_init_from_model(model, cparams); + if (ctx == NULL) { + LOG_ERR("failed to create context in concurrent sequence test\n"); + g_errors++; + return; + } + g_contexts_created++; + + std::unique_ptr sampler { + common_sampler_init(model, params.sampling), common_sampler_free + }; + if (sampler == NULL) { + LOG_ERR("failed to create sampler in concurrent sequence test\n"); + g_errors++; + llama_free(ctx); + return; + } + + const auto * vocab = llama_model_get_vocab(model); + + for (int seq_id = 0; seq_id < num_sequences; ++seq_id) { + auto prompt = common_tokenize(ctx, params.prompt, true); + if (prompt.empty()) { + g_errors++; + continue; + } + + llama_batch batch = llama_batch_init(prompt.size(), 0, 1); + for (size_t i = 0; i < prompt.size(); ++i) { + batch.token[i] = prompt[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = seq_id; + batch.logits[i] = (i == prompt.size() - 1); + } + batch.n_tokens = prompt.size(); + + if (llama_decode(ctx, batch)) { + g_errors++; + llama_batch_free(batch); + continue; + } + g_decode_operations++; + + for (int i = 0; i < 5; i++) { + llama_token token = common_sampler_sample(sampler.get(), ctx, batch.n_tokens - 1); + + if (llama_vocab_is_eog(vocab, token)) { + break; + } + + batch.n_tokens = 1; + batch.token[0] = token; + batch.pos[0] = prompt.size() + i; + batch.n_seq_id[0] = 1; + batch.seq_id[0][0] = seq_id; + batch.logits[0] = true; + + if (llama_decode(ctx, batch)) { + g_errors++; + break; + } + g_decode_operations++; + } + + llama_batch_free(batch); + llama_memory_seq_rm(llama_get_memory(ctx), seq_id, -1, -1); + } + + llama_free(ctx); + g_contexts_destroyed++; +} + +static void memory_stress_test( + llama_model * model, + const llama_context_params & cparams, + const common_params & params, + int num_operations) { + + llama_context * ctx = llama_init_from_model(model, cparams); + if (ctx == NULL) { + LOG_ERR("failed to create context in memory stress test\n"); + g_errors++; + return; + } + g_contexts_created++; + + std::unique_ptr sampler { + 
common_sampler_init(model, params.sampling), common_sampler_free + }; + if (sampler == NULL) { + LOG_ERR("failed to create sampler in memory stress test\n"); + g_errors++; + llama_free(ctx); + return; + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> seq_dist(0, 15); + + for (int op = 0; op < num_operations; ++op) { + int seq_id = seq_dist(gen); + + auto prompt = common_tokenize(ctx, params.prompt, true); + if (!prompt.empty()) { + llama_batch batch = llama_batch_init(prompt.size(), 0, 1); + for (size_t i = 0; i < prompt.size(); ++i) { + batch.token[i] = prompt[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = seq_id; + batch.logits[i] = (i == prompt.size() - 1); + } + batch.n_tokens = prompt.size(); + + if (llama_decode(ctx, batch) == 0) { + g_decode_operations++; + } else { + g_errors++; + } + + llama_batch_free(batch); + } + + if (op % 3 == 0) { + llama_memory_seq_rm(llama_get_memory(ctx), seq_id, -1, -1); + } else if (op % 3 == 1) { + int target_seq = (seq_id + 1) % 16; + llama_memory_seq_cp(llama_get_memory(ctx), seq_id, target_seq, -1, -1); + } else { + llama_memory_clear(llama_get_memory(ctx), false); + } + } + + llama_free(ctx); + g_contexts_destroyed++; +} + +int main(int argc, char ** argv) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { + return 1; + } + + common_init(); + + llama_backend_init(); + llama_numa_init(params.numa); + + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("Starting concurrent stress tests...\n"); + + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), common_model_params_to_llama(params)); + if (model == NULL) { + LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str()); + return 1; + } + + auto cparams = common_context_params_to_llama(params); + cparams.n_seq_max = std::max(16u, cparams.n_seq_max); + + const int num_threads = std::max(1, params.n_parallel); + const int iterations_per_thread = 5; + + g_contexts_created = 0; + g_contexts_destroyed = 0; + g_decode_operations = 0; + g_errors = 0; + + auto start_time = std::chrono::high_resolution_clock::now(); + + LOG_INF("\n=== Test 1: Rapid Context Lifecycle (%d threads, %d iterations each) ===\n", + num_threads, iterations_per_thread); + { + std::vector threads; + for (int t = 0; t < num_threads; ++t) { + threads.emplace_back(rapid_context_lifecycle_test, model, cparams, params, iterations_per_thread); + } + for (auto & thread : threads) { + thread.join(); + } + } + LOG_INF("Contexts created: %d, destroyed: %d, decode ops: %d, errors: %d\n", + g_contexts_created.load(), g_contexts_destroyed.load(), + g_decode_operations.load(), g_errors.load()); + + g_contexts_created = 0; + g_contexts_destroyed = 0; + g_decode_operations = 0; + int errors_after_test1 = g_errors.load(); + + LOG_INF("\n=== Test 2: Sustained Concurrent Inference (%d threads, %d iterations each) ===\n", + num_threads, iterations_per_thread * 2); + { + std::vector threads; + for (int t = 0; t < num_threads; ++t) { + threads.emplace_back(sustained_inference_test, model, cparams, params, iterations_per_thread * 2); + } + for (auto & thread : threads) { + thread.join(); + } + } + LOG_INF("Contexts created: %d, destroyed: %d, decode ops: %d, errors: %d\n", + g_contexts_created.load(), g_contexts_destroyed.load(), + g_decode_operations.load(), g_errors.load()); + + g_contexts_created = 0; + g_contexts_destroyed = 0; + g_decode_operations = 0; + int 
errors_after_test2 = g_errors.load(); + + LOG_INF("\n=== Test 3: Concurrent Sequence Operations (%d threads, %d sequences each) ===\n", + num_threads / 2, 8); + { + std::vector threads; + for (int t = 0; t < std::max(1, num_threads / 2); ++t) { + threads.emplace_back(concurrent_sequence_test, model, cparams, params, 8); + } + for (auto & thread : threads) { + thread.join(); + } + } + LOG_INF("Contexts created: %d, destroyed: %d, decode ops: %d, errors: %d\n", + g_contexts_created.load(), g_contexts_destroyed.load(), + g_decode_operations.load(), g_errors.load()); + + g_contexts_created = 0; + g_contexts_destroyed = 0; + g_decode_operations = 0; + int errors_after_test3 = g_errors.load(); + + LOG_INF("\n=== Test 4: Memory Operations Stress (%d threads, %d operations each) ===\n", + num_threads, iterations_per_thread * 3); + { + std::vector threads; + for (int t = 0; t < num_threads; ++t) { + threads.emplace_back(memory_stress_test, model, cparams, params, iterations_per_thread * 3); + } + for (auto & thread : threads) { + thread.join(); + } + } + LOG_INF("Contexts created: %d, destroyed: %d, decode ops: %d, errors: %d\n", + g_contexts_created.load(), g_contexts_destroyed.load(), + g_decode_operations.load(), g_errors.load()); + + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end_time - start_time); + + int total_errors = g_errors.load(); + + LOG_INF("\n=== Stress Test Summary ===\n"); + LOG_INF("Total duration: %.2f seconds\n", duration.count() / 1000.0); + LOG_INF("Total errors: %d\n", total_errors); + LOG_INF(" After test 1: %d\n", errors_after_test1); + LOG_INF(" After test 2: %d\n", errors_after_test2); + LOG_INF(" After test 3: %d\n", errors_after_test3); + LOG_INF(" After test 4: %d\n", total_errors); + + llama_model_free(model); + + if (total_errors > 0) { + LOG_ERR("Stress tests completed with %d errors\n", total_errors); + return 1; + } + + LOG_INF("All stress tests passed successfully!\n"); + return 0; +} diff --git a/tests/test-kv-cache-concurrent.cpp b/tests/test-kv-cache-concurrent.cpp new file mode 100644 index 0000000000000..f33009f3472ac --- /dev/null +++ b/tests/test-kv-cache-concurrent.cpp @@ -0,0 +1,492 @@ + +#include +#include +#include +#include +#include +#include "llama.h" +#include "arg.h" +#include "common.h" +#include "log.h" +#include "sampling.h" + +static std::atomic g_cache_operations{0}; +static std::atomic g_slot_allocations{0}; +static std::atomic g_slot_deallocations{0}; +static std::atomic g_errors{0}; +static std::atomic g_stop_flag{false}; + +static void concurrent_cache_alloc_dealloc_test( + llama_model * model, + const llama_context_params & cparams, + const common_params & params, + int num_iterations, + int thread_id) { + + llama_context * ctx = llama_init_from_model(model, cparams); + if (ctx == NULL) { + LOG_ERR("Thread %d: failed to create context\n", thread_id); + g_errors++; + return; + } + + std::unique_ptr sampler { + common_sampler_init(model, params.sampling), common_sampler_free + }; + if (sampler == NULL) { + LOG_ERR("Thread %d: failed to create sampler\n", thread_id); + g_errors++; + llama_free(ctx); + return; + } + + const auto * vocab = llama_model_get_vocab(model); + std::random_device rd; + std::mt19937 gen(rd() + thread_id); + std::uniform_int_distribution<> seq_dist(thread_id * 4, thread_id * 4 + 3); + + for (int iter = 0; iter < num_iterations && !g_stop_flag; ++iter) { + int seq_id = seq_dist(gen); + + auto prompt = common_tokenize(ctx, params.prompt, true); + if 
(prompt.empty()) { + g_errors++; + continue; + } + + llama_batch batch = llama_batch_init(prompt.size(), 0, 1); + for (size_t i = 0; i < prompt.size(); ++i) { + batch.token[i] = prompt[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = seq_id; + batch.logits[i] = (i == prompt.size() - 1); + } + batch.n_tokens = prompt.size(); + + if (llama_decode(ctx, batch)) { + g_errors++; + llama_batch_free(batch); + continue; + } + g_cache_operations++; + g_slot_allocations++; + + for (int i = 0; i < 5; i++) { + llama_token token = common_sampler_sample(sampler.get(), ctx, batch.n_tokens - 1); + + if (llama_vocab_is_eog(vocab, token)) { + break; + } + + batch.n_tokens = 1; + batch.token[0] = token; + batch.pos[0] = prompt.size() + i; + batch.n_seq_id[0] = 1; + batch.seq_id[0][0] = seq_id; + batch.logits[0] = true; + + if (llama_decode(ctx, batch)) { + g_errors++; + break; + } + g_cache_operations++; + } + + llama_batch_free(batch); + + llama_memory_seq_rm(llama_get_memory(ctx), seq_id, -1, -1); + g_slot_deallocations++; + + if (iter % 10 == 0) { + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } + } + + llama_free(ctx); +} + +static void concurrent_sequence_copy_test( + llama_model * model, + const llama_context_params & cparams, + const common_params & params, + int num_iterations, + int thread_id) { + + llama_context * ctx = llama_init_from_model(model, cparams); + if (ctx == NULL) { + LOG_ERR("Thread %d: failed to create context for sequence copy test\n", thread_id); + g_errors++; + return; + } + + std::unique_ptr sampler { + common_sampler_init(model, params.sampling), common_sampler_free + }; + if (sampler == NULL) { + LOG_ERR("Thread %d: failed to create sampler for sequence copy test\n", thread_id); + g_errors++; + llama_free(ctx); + return; + } + + const auto * vocab = llama_model_get_vocab(model); + std::random_device rd; + std::mt19937 gen(rd() + thread_id + 1000); + std::uniform_int_distribution<> seq_dist(thread_id * 3, thread_id * 3 + 2); + + for (int iter = 0; iter < num_iterations && !g_stop_flag; ++iter) { + int src_seq = seq_dist(gen); + int dst_seq = (src_seq + 1) % (thread_id * 3 + 3); + + auto prompt = common_tokenize(ctx, params.prompt, true); + if (prompt.empty()) { + g_errors++; + continue; + } + + llama_batch batch = llama_batch_init(prompt.size(), 0, 1); + for (size_t i = 0; i < prompt.size(); ++i) { + batch.token[i] = prompt[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = src_seq; + batch.logits[i] = (i == prompt.size() - 1); + } + batch.n_tokens = prompt.size(); + + if (llama_decode(ctx, batch)) { + g_errors++; + llama_batch_free(batch); + continue; + } + g_cache_operations++; + + llama_token token = common_sampler_sample(sampler.get(), ctx, batch.n_tokens - 1); + if (!llama_vocab_is_eog(vocab, token)) { + batch.n_tokens = 1; + batch.token[0] = token; + batch.pos[0] = prompt.size(); + batch.n_seq_id[0] = 1; + batch.seq_id[0][0] = src_seq; + batch.logits[0] = true; + + if (llama_decode(ctx, batch)) { + g_errors++; + } else { + g_cache_operations++; + } + } + + llama_batch_free(batch); + + llama_memory_seq_cp(llama_get_memory(ctx), src_seq, dst_seq, -1, -1); + g_cache_operations++; + + batch = llama_batch_init(1, 0, 1); + batch.n_tokens = 1; + batch.token[0] = token; + batch.pos[0] = prompt.size() + 1; + batch.n_seq_id[0] = 1; + batch.seq_id[0][0] = dst_seq; + batch.logits[0] = true; + + if (llama_decode(ctx, batch)) { + g_errors++; + } else { + g_cache_operations++; + } + + llama_batch_free(batch); + + 
llama_memory_seq_rm(llama_get_memory(ctx), src_seq, -1, -1); + llama_memory_seq_rm(llama_get_memory(ctx), dst_seq, -1, -1); + + if (iter % 5 == 0) { + std::this_thread::sleep_for(std::chrono::microseconds(50)); + } + } + + llama_free(ctx); +} + +static void concurrent_cache_clear_test( + llama_model * model, + const llama_context_params & cparams, + const common_params & params, + int num_iterations, + int thread_id) { + + llama_context * ctx = llama_init_from_model(model, cparams); + if (ctx == NULL) { + LOG_ERR("Thread %d: failed to create context for cache clear test\n", thread_id); + g_errors++; + return; + } + + std::unique_ptr sampler { + common_sampler_init(model, params.sampling), common_sampler_free + }; + if (sampler == NULL) { + LOG_ERR("Thread %d: failed to create sampler for cache clear test\n", thread_id); + g_errors++; + llama_free(ctx); + return; + } + + for (int iter = 0; iter < num_iterations && !g_stop_flag; ++iter) { + for (int seq_id = thread_id * 2; seq_id < thread_id * 2 + 2; ++seq_id) { + auto prompt = common_tokenize(ctx, params.prompt, true); + if (prompt.empty()) { + g_errors++; + continue; + } + + llama_batch batch = llama_batch_init(prompt.size(), 0, 1); + for (size_t i = 0; i < prompt.size(); ++i) { + batch.token[i] = prompt[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = seq_id; + batch.logits[i] = (i == prompt.size() - 1); + } + batch.n_tokens = prompt.size(); + + if (llama_decode(ctx, batch)) { + g_errors++; + llama_batch_free(batch); + continue; + } + g_cache_operations++; + + llama_batch_free(batch); + } + + llama_memory_clear(llama_get_memory(ctx), false); + g_cache_operations++; + + if (iter % 3 == 0) { + std::this_thread::sleep_for(std::chrono::microseconds(200)); + } + } + + llama_free(ctx); +} + +static void concurrent_mixed_operations_test( + llama_model * model, + const llama_context_params & cparams, + const common_params & params, + int num_iterations, + int thread_id) { + + llama_context * ctx = llama_init_from_model(model, cparams); + if (ctx == NULL) { + LOG_ERR("Thread %d: failed to create context for mixed operations test\n", thread_id); + g_errors++; + return; + } + + std::unique_ptr sampler { + common_sampler_init(model, params.sampling), common_sampler_free + }; + if (sampler == NULL) { + LOG_ERR("Thread %d: failed to create sampler for mixed operations test\n", thread_id); + g_errors++; + llama_free(ctx); + return; + } + + std::random_device rd; + std::mt19937 gen(rd() + thread_id + 2000); + std::uniform_int_distribution<> op_dist(0, 3); + std::uniform_int_distribution<> seq_dist(thread_id * 2, thread_id * 2 + 1); + + for (int iter = 0; iter < num_iterations && !g_stop_flag; ++iter) { + int operation = op_dist(gen); + int seq_id = seq_dist(gen); + + switch (operation) { + case 0: { + auto prompt = common_tokenize(ctx, params.prompt, true); + if (!prompt.empty()) { + llama_batch batch = llama_batch_init(prompt.size(), 0, 1); + for (size_t i = 0; i < prompt.size(); ++i) { + batch.token[i] = prompt[i]; + batch.pos[i] = i; + batch.n_seq_id[i] = 1; + batch.seq_id[i][0] = seq_id; + batch.logits[i] = (i == prompt.size() - 1); + } + batch.n_tokens = prompt.size(); + + if (llama_decode(ctx, batch) == 0) { + g_cache_operations++; + } else { + g_errors++; + } + + llama_batch_free(batch); + } + break; + } + case 1: { + int target_seq = (seq_id + 1) % (thread_id * 2 + 2); + llama_memory_seq_cp(llama_get_memory(ctx), seq_id, target_seq, -1, -1); + g_cache_operations++; + break; + } + case 2: { + 
llama_memory_seq_rm(llama_get_memory(ctx), seq_id, -1, -1); + g_cache_operations++; + break; + } + case 3: { + llama_pos min_pos = llama_memory_seq_pos_min(llama_get_memory(ctx), seq_id); + llama_pos max_pos = llama_memory_seq_pos_max(llama_get_memory(ctx), seq_id); + if (min_pos >= 0 && max_pos >= min_pos) { + llama_memory_seq_rm(llama_get_memory(ctx), seq_id, min_pos, max_pos / 2); + g_cache_operations++; + } + break; + } + } + + if (iter % 20 == 0) { + std::this_thread::sleep_for(std::chrono::microseconds(150)); + } + } + + llama_free(ctx); +} + +int main(int argc, char ** argv) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { + return 1; + } + + common_init(); + + llama_backend_init(); + llama_numa_init(params.numa); + + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("Starting KV cache concurrent tests...\n"); + + llama_model * model = llama_model_load_from_file(params.model.path.c_str(), common_model_params_to_llama(params)); + if (model == NULL) { + LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str()); + return 1; + } + + auto cparams = common_context_params_to_llama(params); + cparams.n_seq_max = std::max(32u, cparams.n_seq_max); + + const int num_threads = std::max(2, params.n_parallel); + const int iterations_per_test = 20; + + g_cache_operations = 0; + g_slot_allocations = 0; + g_slot_deallocations = 0; + g_errors = 0; + g_stop_flag = false; + + auto start_time = std::chrono::high_resolution_clock::now(); + + LOG_INF("\n=== Test 1: Concurrent Cache Allocation/Deallocation (%d threads, %d iterations) ===\n", + num_threads, iterations_per_test); + { + std::vector threads; + for (int t = 0; t < num_threads; ++t) { + threads.emplace_back(concurrent_cache_alloc_dealloc_test, model, cparams, params, + iterations_per_test, t); + } + for (auto & thread : threads) { + thread.join(); + } + } + LOG_INF("Cache operations: %d, allocations: %d, deallocations: %d, errors: %d\n", + g_cache_operations.load(), g_slot_allocations.load(), + g_slot_deallocations.load(), g_errors.load()); + + int errors_after_test1 = g_errors.load(); + g_cache_operations = 0; + g_slot_allocations = 0; + g_slot_deallocations = 0; + + LOG_INF("\n=== Test 2: Concurrent Sequence Copy Operations (%d threads, %d iterations) ===\n", + num_threads, iterations_per_test); + { + std::vector threads; + for (int t = 0; t < num_threads; ++t) { + threads.emplace_back(concurrent_sequence_copy_test, model, cparams, params, + iterations_per_test, t); + } + for (auto & thread : threads) { + thread.join(); + } + } + LOG_INF("Cache operations: %d, errors: %d\n", + g_cache_operations.load(), g_errors.load()); + + int errors_after_test2 = g_errors.load(); + g_cache_operations = 0; + + LOG_INF("\n=== Test 3: Concurrent Cache Clear Operations (%d threads, %d iterations) ===\n", + num_threads, iterations_per_test * 2); + { + std::vector threads; + for (int t = 0; t < num_threads; ++t) { + threads.emplace_back(concurrent_cache_clear_test, model, cparams, params, + iterations_per_test * 2, t); + } + for (auto & thread : threads) { + thread.join(); + } + } + LOG_INF("Cache operations: %d, errors: %d\n", + g_cache_operations.load(), g_errors.load()); + + int errors_after_test3 = g_errors.load(); + g_cache_operations = 0; + + LOG_INF("\n=== Test 4: Mixed Concurrent Operations (%d threads, %d iterations) ===\n", + num_threads, iterations_per_test * 3); + { + std::vector threads; + for (int t = 0; t < num_threads; ++t) { + 
threads.emplace_back(concurrent_mixed_operations_test, model, cparams, params, + iterations_per_test * 3, t); + } + for (auto & thread : threads) { + thread.join(); + } + } + LOG_INF("Cache operations: %d, errors: %d\n", + g_cache_operations.load(), g_errors.load()); + + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end_time - start_time); + + int total_errors = g_errors.load(); + + LOG_INF("\n=== KV Cache Concurrent Test Summary ===\n"); + LOG_INF("Total duration: %.2f seconds\n", duration.count() / 1000.0); + LOG_INF("Total errors: %d\n", total_errors); + LOG_INF(" After test 1: %d\n", errors_after_test1); + LOG_INF(" After test 2: %d\n", errors_after_test2); + LOG_INF(" After test 3: %d\n", errors_after_test3); + LOG_INF(" After test 4: %d\n", total_errors); + + llama_model_free(model); + + if (total_errors > 0) { + LOG_ERR("KV cache concurrent tests completed with %d errors\n", total_errors); + return 1; + } + + LOG_INF("All KV cache concurrent tests passed successfully!\n"); + return 0; +} diff --git a/tests/test-thread-safety.cpp b/tests/test-thread-safety.cpp index 853495b00d9d2..09ffebe7d2671 100644 --- a/tests/test-thread-safety.cpp +++ b/tests/test-thread-safety.cpp @@ -6,12 +6,22 @@ #include #include #include +#include +#include +#include #include "llama.h" #include "arg.h" #include "common.h" #include "log.h" #include "sampling.h" +static std::atomic g_context_init_count{0}; +static std::atomic g_decode_count{0}; +static std::atomic g_model_access_count{0}; +static std::mutex g_barrier_mutex; +static std::condition_variable g_barrier_cv; +static int g_threads_ready = 0; + int main(int argc, char ** argv) { common_params params; @@ -75,18 +85,42 @@ int main(int argc, char ** argv) { models.emplace_back(model); } + const int total_threads = num_models * num_contexts; + for (int m = 0; m < num_models; ++m) { auto * model = models[m].get(); for (int c = 0; c < num_contexts; ++c) { - threads.emplace_back([&, m, c, model]() { + threads.emplace_back([&, m, c, model, total_threads]() { LOG_INF("Creating context %d/%d for model %d/%d\n", c + 1, num_contexts, m + 1, num_models); + g_model_access_count++; + + { + std::unique_lock lock(g_barrier_mutex); + g_threads_ready++; + if (g_threads_ready == total_threads) { + g_barrier_cv.notify_all(); + } else { + g_barrier_cv.wait(lock, [&]{ return g_threads_ready == total_threads; }); + } + } + + auto start_time = std::chrono::steady_clock::now(); + llama_context_ptr ctx { llama_init_from_model(model, cparams) }; if (ctx == NULL) { LOG_ERR("failed to create context\n"); failed.store(true); return; } + g_context_init_count++; + + auto init_time = std::chrono::steady_clock::now(); + auto init_duration = std::chrono::duration_cast(init_time - start_time).count(); + if (init_duration > 5000) { + LOG_WRN("Model %d/%d, Context %d/%d: slow context initialization (%ld ms)\n", + m + 1, num_models, c + 1, num_contexts, (long)init_duration); + } std::unique_ptr sampler { common_sampler_init(model, params.sampling), common_sampler_free }; if (sampler == NULL) { @@ -109,6 +143,7 @@ int main(int argc, char ** argv) { failed.store(true); return; } + g_decode_count++; } const auto * vocab = llama_model_get_vocab(model); @@ -134,6 +169,11 @@ int main(int argc, char ** argv) { failed.store(true); return; } + g_decode_count++; + + if (i % 32 == 31) { + std::this_thread::yield(); + } } LOG_INF("Model %d/%d, Context %d/%d: %s\n\n", m + 1, num_models, c + 1, num_contexts, result.c_str()); @@ -145,6 +185,16 @@ 
int main(int argc, char ** argv) { thread.join(); } + LOG_INF("\n=== Thread Safety Test Statistics ===\n"); + LOG_INF("Total threads: %d\n", total_threads); + LOG_INF("Model access count: %d\n", g_model_access_count.load()); + LOG_INF("Context init count: %d\n", g_context_init_count.load()); + LOG_INF("Decode operation count: %d\n", g_decode_count.load()); + + if (g_context_init_count != total_threads) { + LOG_WRN("Warning: expected %d context inits, got %d\n", total_threads, g_context_init_count.load()); + } + if (failed) { LOG_ERR("One or more threads failed.\n"); return 1; diff --git a/tools/server/tests/unit/test_completion.py b/tools/server/tests/unit/test_completion.py index 11483e679a505..1783881d880d9 100644 --- a/tools/server/tests/unit/test_completion.py +++ b/tools/server/tests/unit/test_completion.py @@ -533,3 +533,148 @@ def test_cancel_request(): time.sleep(1) # wait for HTTP_POLLING_SECONDS res = server.make_request("GET", "/slots") assert res.body[0]["is_processing"] == False + + +@pytest.mark.parametrize("n_slots,n_requests", [ + (8, 32), + (4, 64), + (8, 128), +]) +def test_high_volume_concurrent_requests(n_slots: int, n_requests: int): + global server + server.n_slots = n_slots + server.n_ctx = 512 + server.temperature = 0.8 + server.start() + + prompts = [ + "Write a short story about", + "Explain the concept of", + "What is the best way to", + "Tell me about", + "How can I improve my", + "Describe a day in the life of", + "What are the benefits of", + "List three reasons why", + ] + + tasks = [] + for i in range(n_requests): + prompt = prompts[i % len(prompts)] + tasks.append((server.make_request, ("POST", "/completion", { + "prompt": prompt, + "seed": 42 + i, + "n_predict": 16, + "temperature": 0.8, + }))) + + start_time = time.time() + results = parallel_function_calls(tasks) + duration = time.time() - start_time + + successful_requests = 0 + for res in results: + if res.status_code == 200 and "content" in res.body: + assert type(res.body["content"]) == str + successful_requests += 1 + + assert successful_requests == n_requests + throughput = n_requests / duration + print(f"High volume test: {n_requests} requests on {n_slots} slots in {duration:.2f}s ({throughput:.2f} req/s)") + + +@pytest.mark.parametrize("n_slots", [4, 8]) +def test_concurrent_streaming_requests(n_slots: int): + global server + server.n_slots = n_slots + server.n_ctx = 512 + server.start() + + def make_streaming_completion(prompt: str, seed: int): + res = server.make_stream_request("POST", "/completion", data={ + "prompt": prompt, + "seed": seed, + "n_predict": 24, + "stream": True, + }) + content = "" + for chunk in res: + if "content" in chunk: + content += chunk["content"] + return content + + prompts = [ + ("Write something interesting", 100 + i) + for i in range(n_slots * 2) + ] + + tasks = [(make_streaming_completion, (prompt, seed)) for prompt, seed in prompts] + results = parallel_function_calls(tasks) + + for result in results: + assert isinstance(result, str) + assert len(result) > 0 + + +def test_concurrent_cache_consistency(): + global server + server.n_slots = 8 + server.n_ctx = 1024 + server.cache_prompt = True + server.start() + + shared_prompt_prefix = "In the beginning there was nothing but darkness and void. 
Then suddenly" + + tasks = [] + for i in range(32): + full_prompt = shared_prompt_prefix + f" variation {i % 4}" + tasks.append((server.make_request, ("POST", "/completion", { + "prompt": full_prompt, + "seed": 42, + "n_predict": 16, + "cache_prompt": True, + }))) + + results = parallel_function_calls(tasks) + + for res in results: + assert res.status_code == 200 + assert "content" in res.body + assert type(res.body["content"]) == str + assert len(res.body["content"]) > 0 + + +@pytest.mark.parametrize("n_slots,n_sequences_per_slot", [ + (4, 2), + (8, 2), +]) +def test_parallel_sequence_processing(n_slots: int, n_sequences_per_slot: int): + global server + server.n_slots = n_slots + server.n_ctx = 512 + server.start() + + n_total_requests = n_slots * n_sequences_per_slot + prompts = [f"Tell me about topic number {i}" for i in range(n_total_requests)] + + tasks = [] + for i, prompt in enumerate(prompts): + tasks.append((server.make_request, ("POST", "/completion", { + "prompt": prompt, + "seed": 42 + i, + "n_predict": 20, + "temperature": 0.9, + }))) + + results = parallel_function_calls(tasks) + + unique_contents = set() + for res in results: + assert res.status_code == 200 + assert "content" in res.body + content = res.body["content"] + assert type(content) == str + assert len(content) > 0 + unique_contents.add(content) + + assert len(unique_contents) >= n_total_requests * 0.5