From ccabf5713d9a778c87067b7afb0ae5ab725c3c15 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Thu, 6 Nov 2025 15:55:44 +0200
Subject: [PATCH 1/2] cuda: set compute parameters via command line arguments

---
 common/common.cpp                  |  11 ++-
 common/common.h                    |   1 +
 examples/cvector-generator/pca.hpp |   2 +-
 ggml/include/ggml-cuda.h           |   2 +-
 ggml/src/ggml-cuda.cu              | 119 ++++++++++++++++++++++++-----
 ggml/src/ggml-cuda/common.cuh      |   4 +
 include/llama.h                    |   1 +
 src/llama-cparams.h                |   1 +
 src/llama-model-loader.cpp         |   2 +-
 src/llama.cpp                      |  13 +++-
 10 files changed, 127 insertions(+), 29 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index aef07f3e7..252b53e8c 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1249,6 +1249,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         }
         return true;
     }
+    if (arg == "-cuda" || arg == "--cuda-params") {
+        CHECK_ARG
+        params.cuda_params = argv[i];
+        return true;
+    }
     if (arg == "--cpu-moe" || arg == "-cmoe") {
         params.tensor_buft_overrides.push_back({strdup("\\.ffn_(up|down|gate)_exps\\.weight"), ggml_backend_cpu_buffer_type()});
         return true;
     }
@@ -2076,6 +2081,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     options.push_back({ "*", " --no-context-shift", "disable context-shift." });

     options.push_back({ "backend" });
     options.push_back({ "*", " --rpc SERVERS", "comma separated list of RPC servers" });
+    options.push_back({ "*", "-cuda, --cuda-params", "comma separated list of cuda parameters" });

     if (llama_supports_mlock()) {
         options.push_back({ "*", " --mlock", "force system to keep model in RAM rather than swapping or compressing" });
@@ -2676,7 +2682,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     auto mparams = llama_model_params_from_gpt_params(params);

     llama_model * model = nullptr;
-    
+
     if (!params.hf_repo.empty() && !params.hf_file.empty()) {
         model = llama_load_model_from_hf(params.hf_repo.c_str(), params.hf_file.c_str(), params.model.c_str(), params.hf_token.c_str(), mparams);
     } else if (!params.model_url.empty()) {
@@ -2684,7 +2690,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
     } else {
         model = llama_load_model_from_file(params.model.c_str(), mparams);
     }
-    
+
     if (model == NULL) {
         fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str());
         return iparams;
     }
@@ -2914,6 +2920,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.type_v = kv_cache_type_from_str(params.cache_type_v);

     if (!params.offload_policy.empty()) cparams.offload_policy = (void *)&params.offload_policy;
+    if (!params.cuda_params.empty())    cparams.cuda_params    = (void *)params.cuda_params.data();

     return cparams;
 }
diff --git a/common/common.h b/common/common.h
index 02043fc2f..4ad5908d9 100644
--- a/common/common.h
+++ b/common/common.h
@@ -198,6 +198,7 @@ struct gpt_params {
     std::string logits_file = ""; // file for saving *all* logits
     std::string rpc_servers = ""; // comma separated list of RPC servers
+    std::string cuda_params = ""; // comma separated list of cuda parameters key1=value1,key2=value2

     std::vector<std::string> in_files;   // all input files
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
diff --git a/examples/cvector-generator/pca.hpp b/examples/cvector-generator/pca.hpp
index 85be07dd2..8d9204536 100644
--- a/examples/cvector-generator/pca.hpp
+++ b/examples/cvector-generator/pca.hpp
@@ -66,7 +66,7 @@ struct pca_model {
     pca_model(struct ggml_tensor * t_input) {
 #ifdef GGML_USE_CUDA
         fprintf(stderr, "%s: using CUDA backend\n", __func__);
-        backend = ggml_backend_cuda_init(0); // init device 0
+        backend = ggml_backend_cuda_init(0, nullptr); // init device 0
         if (!backend) {
             fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
         }
diff --git a/ggml/include/ggml-cuda.h b/ggml/include/ggml-cuda.h
index 71bb6dcf0..d17b4ce34 100644
--- a/ggml/include/ggml-cuda.h
+++ b/ggml/include/ggml-cuda.h
@@ -21,7 +21,7 @@ extern "C" {
 #define GGML_CUDA_MAX_DEVICES 16

 // backend API
-GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device);
+GGML_API GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device, const void * params);

 GGML_API GGML_CALL bool ggml_backend_is_cuda(ggml_backend_t backend);
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 88b5abe17..e548edce4 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -66,6 +66,7 @@
 #include
 #include
 #include
+#include <sstream>

 #define IK_PRINT_TIMING 0

@@ -2420,7 +2421,8 @@ static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         }
     }

-    if (ggml_is_quantized(src0->type) && ggml_cuda_can_use_mmq_id(src0->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
+    if (src1->ne[2] <= ctx.mmq_id_thresh*src0->ne[2] &&
+        ggml_is_quantized(src0->type) && ggml_cuda_can_use_mmq_id(src0->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
         ggml_cuda_mul_mat_q_id(ctx, src0, src1, ids, dst, nullptr, nullptr);
         return false;
     }
@@ -2685,7 +2687,7 @@ static int ggml_cuda_moe_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_ten
     // My original hypothesis was that it is dependent on the total/active experts ratio, but from these 3 it
     // looks like it really depends just on the total number of experts.
     // TODO: verify with more models, or perhaps make the magic constant '32' to be defined via a compile time define.
-    if (src1->ne[2] <= 32*src0->ne[2] &&
+    if (src1->ne[2] <= ctx.mmq_id_thresh*src0->ne[2] &&
         ggml_is_quantized(src0_1->type) && src0_1->type == src0_2->type &&
         src1->ne[1] == 1 && src1->ne[3] == 1 &&
         ggml_cuda_can_use_mmq_id(src0_1->type, ggml_cuda_info().devices[ctx.device].cc, src1->ne[2])) {
@@ -3060,6 +3062,8 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg

     auto next = i < cgraph->n_nodes - 1 ? cgraph->nodes[i+1] : nullptr;
+    auto fusion = ctx.fusion;
+
     //printf("%4d %s(%s)\n", i, ggml_op_name(dst->op), dst->name);
     switch (dst->op) {
         case GGML_OP_ARGMAX:
@@ -3084,7 +3088,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_dup(ctx, dst);
             break;
         case GGML_OP_ADD:
-            if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
+            if (fusion && i + 2 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_ADD &&
                 cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
                 ggml_is_contiguous(dst->src[0]) &&
@@ -3098,7 +3102,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 ggml_cuda_op_fused_add_add_rms_norm(ctx, dst, cgraph->nodes[i+1], cgraph->nodes[i+2]);
                 i += 2;
             }
-            else if (GGML_CUDA_FUSION && i + 1 < cgraph->n_nodes &&
+            else if (fusion && i + 1 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_FUSED_RMS_NORM &&
                 ggml_is_contiguous(dst->src[0]) &&
                 ggml_is_contiguous(dst->src[1]) &&
@@ -3155,7 +3159,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                     ggml_cuda_op_relu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_SIGMOID:
-                    if (GGML_CUDA_FUSION && i + 5 < cgraph->n_nodes &&
+                    if (fusion && i + 5 < cgraph->n_nodes &&
                         cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                         cgraph->nodes[i+2]->op == GGML_OP_ADD &&
                         cgraph->nodes[i+3]->op == GGML_OP_ARGSORT &&
@@ -3164,14 +3168,14 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                         cuda_glm45moe_experts(ctx, cgraph->nodes[i+5], cgraph->nodes[i+4]);
                         i += 5;
                     }
-                    else if (GGML_CUDA_FUSION && i + 4 < cgraph->n_nodes &&
+                    else if (fusion && i + 4 < cgraph->n_nodes &&
                         cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ADD &&
                         cgraph->nodes[i+3]->op == GGML_OP_GROUPED_TOPK && cgraph->nodes[i+4]->op == GGML_OP_GET_ROWS &&
                         ops_are_same_device(cgraph, i, i+4)) {
                         cuda_bailingmoev2_experts(ctx, cgraph->nodes[i+4], cgraph->nodes[i+3]);
                         i += 4;
-                    } else if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
+                    } else if (fusion && i + 2 < cgraph->n_nodes &&
                         cgraph->nodes[i+1]->op == GGML_OP_RESHAPE && cgraph->nodes[i+2]->op == GGML_OP_ADD &&
                         ops_are_same_device(cgraph, i, i+2)) {
                         ggml_cuda_op_biased_sigmoid(ctx, cgraph->nodes[i+2]);
@@ -3242,7 +3246,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_rms_norm(ctx, dst);
             break;
         case GGML_OP_FUSED_RMS_NORM:
-            if (false && GGML_CUDA_FUSION && i + 4 < cgraph->n_nodes &&
+            if (false && fusion && i + 4 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
                 cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM &&
                 cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST &&
@@ -3250,7 +3254,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 ggml_cuda_op_fused_rms_rope_fast(ctx, cgraph->nodes[i+3], cgraph->nodes[i+4])) {
                 i += 4;
             }
-            else if (false && GGML_CUDA_FUSION && i + 4 < cgraph->n_nodes &&
+            else if (false && fusion && i + 4 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_ROPE_FAST &&
                 cgraph->nodes[i+2]->op == GGML_OP_RESHAPE &&
                 cgraph->nodes[i+3]->op == GGML_OP_FUSED_RMS_NORM &&
@@ -3258,7 +3262,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 ggml_cuda_op_fused_rms_rope_fast(ctx, cgraph->nodes[i+1], cgraph->nodes[i+4])) {
                 i += 4;
             }
-            else if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
+            else if (fusion && i + 2 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_VIEW && cgraph->nodes[i+2]->op == GGML_OP_FUSED_RMS_NORM && dst->ne[2] == 1 &&
                 cgraph->nodes[i+2]->ne[2] == 1) {
@@ -3310,7 +3314,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_diag_mask_inf(ctx, dst);
             break;
         case GGML_OP_SOFT_MAX:
-            if (GGML_CUDA_FUSION && i + 4 < cgraph->n_nodes &&
+            if (fusion && i + 4 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_RESHAPE &&
                 cgraph->nodes[i+2]->op == GGML_OP_ARGSORT &&
                 cgraph->nodes[i+3]->op == GGML_OP_VIEW &&
@@ -3333,20 +3337,20 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_rope_back(ctx, dst);
             break;
         case GGML_OP_ROPE_FAST:
-            if (GGML_CUDA_FUSION && i + 3 < cgraph->n_nodes &&
+            if (fusion && i + 3 < cgraph->n_nodes &&
                 (cgraph->nodes[i+1]->op == GGML_OP_RESHAPE || cgraph->nodes[i+1]->op == GGML_OP_VIEW) &&
                 (cgraph->nodes[i+2]->op == GGML_OP_RESHAPE || cgraph->nodes[i+2]->op == GGML_OP_VIEW) &&
                 cgraph->nodes[i+3]->op == GGML_OP_ROPE_FAST &&
                 ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+3])) {
                 i += 3;
             }
-            else if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
+            else if (fusion && i + 2 < cgraph->n_nodes &&
                 (cgraph->nodes[i+1]->op == GGML_OP_RESHAPE || cgraph->nodes[i+1]->op == GGML_OP_VIEW) &&
                 cgraph->nodes[i+2]->op == GGML_OP_ROPE_FAST &&
                 ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+2])) {
                 i += 2;
             }
-            else if (GGML_CUDA_FUSION && i + 1 < cgraph->n_nodes &&
+            else if (fusion && i + 1 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_ROPE_FAST &&
                 ggml_cuda_op_fused_rope_fast(ctx, dst, cgraph->nodes[i+1])) {
                 i += 1;
@@ -3374,7 +3378,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_pool2d(ctx, dst);
             break;
         case GGML_OP_SUM_ROWS:
-            if (GGML_CUDA_FUSION && i + 2 < cgraph->n_nodes &&
+            if (fusion && i + 2 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_SCALE &&
                 cgraph->nodes[i+2]->op == GGML_OP_DIV &&
                 cgraph->nodes[i+1]->src[0] == dst &&
@@ -3383,7 +3387,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 ggml_cuda_op_sum_rows_div(ctx, cgraph->nodes[i+2]);
                 i += 2;
             }
-            else if (GGML_CUDA_FUSION && i + 1 < cgraph->n_nodes &&
+            else if (fusion && i + 1 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_DIV &&
                 cgraph->nodes[i+1]->src[1] == dst && cgraph->nodes[i+1]->src[0] == dst->src[0] &&
                 ops_are_same_device(cgraph, i, i+1)) {
@@ -3394,7 +3398,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             }
             break;
         case GGML_OP_ARGSORT:
-            if (GGML_CUDA_FUSION && i + 5 < cgraph->n_nodes &&
+            if (fusion && i + 5 < cgraph->n_nodes &&
                 cgraph->nodes[i+1]->op == GGML_OP_VIEW &&
                 cgraph->nodes[i+2]->op == GGML_OP_GET_ROWS &&
                 cgraph->nodes[i+3]->op == GGML_OP_RESHAPE &&
@@ -4462,7 +4466,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_buft(ggml_backend_t backend, gg
 }

 GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
-    constexpr int min_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
+    auto ctx = (const ggml_backend_cuda_context *)backend->context;
+    int min_batch_size = ctx->offload_batch_size; //originally: GGML_CUDA_MIN_BATCH_OFFLOAD;
     // Why do we want to do this? The heuristics that the batch must have more than min_batch_size tokens to be worth it
     // offloading the required model weights comes from dense models. For MoE models, the average number of tokens
@@ -4575,7 +4580,65 @@ static ggml_guid_t ggml_backend_cuda_guid() {
     return &guid;
 }

-GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
+struct cuda_params {
+    int fusion = GGML_CUDA_FUSION;
+    int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
+    int mmq_id_thresh = 32;
+};
+
+static std::vector<std::string> string_split(const std::string& str, const std::string& delimiter) {
+    std::vector<std::string> parts;
+    size_t start = 0;
+    size_t end = str.find(delimiter);
+
+    while (end != std::string::npos) {
+        parts.push_back(str.substr(start, end - start));
+        start = end + delimiter.length();
+        end = str.find(delimiter, start);
+    }
+
+    parts.push_back(str.substr(start));
+
+    return parts;
+}
+
+template <typename T> bool read_value(const std::string& val, T& result) {
+    std::istringstream str(val);
+    T tmp; str >> tmp;
+    if (!str.fail()) {
+        result = tmp;
+        return true;
+    }
+    return false;
+}
+
+static cuda_params ggml_cuda_parse_params(const char * params_string) {
+    cuda_params params{};
+    if (!params_string) return params;
+    auto values = string_split(std::string{params_string}, ",");
+    if (values.empty()) return params;
+    for (auto& value : values) {
+        auto parsed = string_split(value, "=");
+        bool is_good = false;
+        if (parsed.size() == 2) {
+            if (parsed[0] == "fusion") {
+                is_good = read_value(parsed[1], params.fusion);
+            }
+            else if (parsed[0] == "offload-batch-size") {
+                is_good = read_value(parsed[1], params.offload_batch_size);
+            }
+            else if (parsed[0] == "mmq-id-size") {
+                is_good = read_value(parsed[1], params.mmq_id_thresh);
+            }
+        }
+        if (!is_good) {
+            GGML_CUDA_LOG_WARN("%s: invalid parameter %s (%d) -> ignored\n", __func__, value.c_str(), (int)parsed.size());
+        }
+    }
+    return params;
+}
+
+GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device, [[maybe_unused]] const void * param_string) {
     if (device < 0 || device >= ggml_backend_cuda_get_device_count()) {
         GGML_CUDA_LOG_ERROR("%s: invalid device %d\n", __func__, device);
         return nullptr;
@@ -4593,6 +4656,22 @@ GGML_CALL ggml_backend_t ggml_backend_cuda_init(int device) {
         /* .context   = */ ctx
     };

+    if (param_string) {
+        [[maybe_unused]] auto params = ggml_cuda_parse_params((const char *)param_string);
+        if (params.fusion != ctx->fusion) {
+            GGML_CUDA_LOG_INFO(" =========================== %s: setting fusion to %d\n", __func__, params.fusion);
+            ctx->fusion = params.fusion;
+        }
+        if (params.offload_batch_size != ctx->offload_batch_size) {
+            GGML_CUDA_LOG_INFO(" =========================== %s: setting offload_batch_size to %d\n", __func__, params.offload_batch_size);
+            ctx->offload_batch_size = params.offload_batch_size;
+        }
+        if (params.mmq_id_thresh != ctx->mmq_id_thresh) {
+            GGML_CUDA_LOG_INFO(" =========================== %s: setting mmq_id_thresh to %d\n", __func__, params.mmq_id_thresh);
+            ctx->mmq_id_thresh = params.mmq_id_thresh;
+        }
+    }
+
     return cuda_backend;
 }

@@ -4651,7 +4730,7 @@ GGML_CALL void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
 // backend registry
 GGML_CALL static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * user_data) {
-    ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data);
+    ggml_backend_t cuda_backend = ggml_backend_cuda_init((int) (intptr_t) user_data, nullptr);
     return cuda_backend;

     GGML_UNUSED(params);
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index b24f7fba7..4a55eed4c 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -837,6 +837,10 @@ struct ggml_backend_cuda_context {
     std::unique_ptr<ggml_cuda_graph> cuda_graph;

+    int fusion = GGML_CUDA_FUSION;
+    int offload_batch_size = GGML_CUDA_MIN_BATCH_OFFLOAD;
+    int mmq_id_thresh = 32;
+
     explicit ggml_backend_cuda_context(int device);
     ~ggml_backend_cuda_context();
diff --git a/include/llama.h b/include/llama.h
index 31d82949f..fe7bca4d5 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -439,6 +439,7 @@ extern "C" {
         ggml_abort_callback abort_callback;
         void * abort_callback_data;
         void * offload_policy;
+        void * cuda_params;
     };

     // model quantization parameters
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 3c32e404f..fcc107a35 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -45,4 +45,5 @@ struct llama_cparams {
     ggml_backend_sched_eval_callback cb_eval;
     void * cb_eval_user_data;
+    void * cuda_params;
 };
diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp
index 3f1ded18c..d7c68b336 100644
--- a/src/llama-model-loader.cpp
+++ b/src/llama-model-loader.cpp
@@ -907,7 +907,7 @@ bool llama_model_loader::load_all_data(
         for (int i = 0; i < ggml_backend_cuda_get_device_count(); ++i) {
             auto * cuda_buffer_type = ggml_backend_cuda_buffer_type(i);
             if (buffer_type == cuda_buffer_type) {
-                cuda_backend = ggml_backend_cuda_init(i);
+                cuda_backend = ggml_backend_cuda_init(i, nullptr);
                 break;
             }
         }
diff --git a/src/llama.cpp b/src/llama.cpp
index b78adf90d..8f6f9e0bb 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -3844,6 +3844,7 @@ struct llama_context_params llama_context_default_params() {
         /*.abort_callback      =*/ nullptr,
         /*.abort_callback_data =*/ nullptr,
         /*.offload_policy      =*/ nullptr,
+        /*.cuda_params         =*/ nullptr,
     };

     return result;
@@ -4122,7 +4123,7 @@ struct llama_context * llama_new_context_with_model(
     const auto & hparams = model->hparams;
     auto       & cparams = ctx->cparams;
-    
+
     cparams.n_seq_max        = std::max(1u, params.n_seq_max);
     cparams.n_threads        = params.n_threads;
     cparams.n_threads_batch  = params.n_threads_batch;
@@ -4143,6 +4144,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_cache       = params.rope_cache;
     cparams.min_experts      = params.min_experts;
     cparams.thresh_experts   = params.thresh_experts;
+    cparams.cuda_params      = params.cuda_params;

     cparams.pooling_type     = params.pooling_type;

@@ -4227,6 +4229,9 @@ struct llama_context * llama_new_context_with_model(
     LLAMA_LOG_INFO("%s: ser        = %d, %g\n",   __func__, cparams.min_experts, cparams.thresh_experts);
     LLAMA_LOG_INFO("%s: freq_base  = %.1f\n",     __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale = %g\n",       __func__, cparams.rope_freq_scale);
+    if (cparams.cuda_params) {
+        LLAMA_LOG_INFO("%s: cuda_params = %s\n", __func__, (const char *)cparams.cuda_params);
+    }

     ctx->abort_callback      = params.abort_callback;
     ctx->abort_callback_data = params.abort_callback_data;
@@ -4266,7 +4271,7 @@ struct llama_context * llama_new_context_with_model(
 #elif defined(GGML_USE_CUDA)
         if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
             // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
-            ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu);
+            ggml_backend_t backend = ggml_backend_cuda_init(model->main_gpu, cparams.cuda_params);
             if (backend == nullptr) {
                 LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, model->main_gpu);
                 llama_free(ctx);
@@ -4277,7 +4282,7 @@ struct llama_context * llama_new_context_with_model(
         } else {
             // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
            for (int device = 0; device < ggml_backend_cuda_get_device_count(); ++device) {
-                ggml_backend_t backend = ggml_backend_cuda_init(device);
+                ggml_backend_t backend = ggml_backend_cuda_init(device, cparams.cuda_params);
                 if (backend == nullptr) {
                     LLAMA_LOG_ERROR("%s: failed to initialize CUDA%d backend\n", __func__, device);
                     llama_free(ctx);
@@ -4404,7 +4409,7 @@ struct llama_context * llama_new_context_with_model(
         }
         ctx->backends = std::move(backends);
     }
-    
+
     ctx->backend_cpu = ggml_backend_cpu_init();
     if (ctx->backend_cpu == nullptr) {
         LLAMA_LOG_ERROR("%s: failed to initialize CPU backend\n", __func__);

From 06e9fcd4d88336a05f8a3f526f282ec8c3cd4adf Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Thu, 6 Nov 2025 18:08:03 +0200
Subject: [PATCH 2/2] Also llama-bench

---
 examples/llama-bench/llama-bench.cpp | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp
index afa46d2c6..b615ca919 100644
--- a/examples/llama-bench/llama-bench.cpp
+++ b/examples/llama-bench/llama-bench.cpp
@@ -256,6 +256,7 @@ struct cmd_params {
     std::vector<bool> embeddings;
     std::vector<llama_model_tensor_buft_override> buft_overrides;
     ggml_numa_strategy numa;
+    std::string cuda_params;
     int reps;
     bool verbose;
     bool warmup;
@@ -295,6 +296,7 @@ static const cmd_params cmd_params_defaults = {
     /* embeddings     */ {false},
    /* buft_overrides */ {},
     /* numa           */ GGML_NUMA_STRATEGY_DISABLED,
+    /* cuda_params    */ {},
     /* reps           */ 5,
     /* verbose        */ false,
     /* warmup         */ true,
@@ -344,6 +346,7 @@ static void print_usage(int /* argc */, char ** argv) {
     printf("  -v, --verbose                         (default: %s)\n", cmd_params_defaults.verbose ? "1" : "0");
     printf("  -w, --warmup <0|1>                    (default: %s)\n", cmd_params_defaults.warmup ? "1" : "0");
     printf("  -rtr, --run-time-repack <0|1>         (default: %s)\n", cmd_params_defaults.repack ? "1" : "0");
+    printf("  -cuda, --cuda-params                  (default: %s)\n", cmd_params_defaults.cuda_params.c_str());
     printf("  -mqkv, --merge-qkv                    (default: %s)\n", cmd_params_defaults.mqkv ? "1" : "0");
     printf("  -thp, --transparent-huge-pages <0|1>  (default: %s)\n", cmd_params_defaults.use_thp? "1" : "0");
     printf("  -ot, --override-tensor pattern        (default: none)\n");
@@ -736,6 +739,12 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
                 break;
             }
             params.repack = std::stoi(argv[i]);
+        } else if (arg == "-cuda" || arg == "--cuda-params") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.cuda_params = argv[i];
         } else if (arg == "-mqkv" || arg == "--merge-qkv") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
@@ -852,6 +861,7 @@ struct cmd_params_instance {
     int attn_max_batch;
     Ser ser;
     std::vector<float> tensor_split;
+    std::string cuda_params;
     bool use_mmap;
     bool embeddings;
     bool repack = false;
@@ -914,6 +924,7 @@ struct cmd_params_instance {
         cparams.min_experts    = ser.first;
         cparams.thresh_experts = ser.second;
         cparams.embeddings     = embeddings;
+        cparams.cuda_params    = (void *)cuda_params.data();

         return cparams;
     }
@@ -965,6 +976,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .attn_max_b   = */ amb,
                 /* .ser          = */ ser,
                 /* .tensor_split = */ ts,
+                /* .cuda_params  = */ params.cuda_params,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
                 /* .repack       = */ params.repack,
@@ -1003,6 +1015,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .attn_max_b   = */ amb,
                 /* .ser          = */ ser,
                 /* .tensor_split = */ ts,
+                /* .cuda_params  = */ params.cuda_params,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
                 /* .repack       = */ params.repack,
@@ -1041,6 +1054,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .attn_max_b   = */ amb,
                 /* .ser          = */ ser,
                 /* .tensor_split = */ ts,
+                /* .cuda_params  = */ params.cuda_params,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
                 /* .repack       = */ params.repack,
@@ -1079,6 +1093,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
                 /* .attn_max_b   = */ amb,
                 /* .ser          = */ ser,
                 /* .tensor_split = */ ts,
+                /* .cuda_params  = */ params.cuda_params,
                 /* .use_mmap     = */ mmp,
                 /* .embeddings   = */ embd,
                 /* .repack       = */ params.repack,
@@ -1128,6 +1143,7 @@ struct test {
     int attn_max_batch;
     Ser ser;
     std::vector<float> tensor_split;
+    std::string cuda_params;
     bool use_mmap;
     bool embeddings;
     bool repack = false;
@@ -1166,6 +1182,7 @@ struct test {
         attn_max_batch = inst.attn_max_batch;
         ser            = inst.ser;
         tensor_split   = inst.tensor_split;
+        cuda_params    = inst.cuda_params;
         use_mmap       = inst.use_mmap;
         embeddings     = inst.embeddings;
         repack         = inst.repack;