From 6dc8c79d8157fc6fbc28db2e0110ef33f0ba67d6 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 14 Oct 2025 16:24:16 +0300 Subject: [PATCH 1/5] Adding Ling/Ring (a.k.a., Bailing-MoE2) --- src/llama-arch.cpp | 9 ++- src/llama-arch.h | 3 + src/llama-build-context.cpp | 137 ++++++++++++++++++++++++++++++++++++ src/llama-build-context.h | 2 + src/llama-hparams.cpp | 25 +++++++ src/llama-hparams.h | 39 ++++++---- src/llama-load-tensors.cpp | 75 ++++++++++++++++++++ src/llama-model.cpp | 34 +++++++++ src/llama-model.h | 2 + src/llama-vocab.cpp | 4 +- src/llama.cpp | 70 ++++++++++++++++++ 11 files changed, 381 insertions(+), 19 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 6d72f9595..6bb16a045 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -58,11 +58,12 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GRANITE, "granite" }, { LLM_ARCH_GRANITE_MOE, "granitemoe" }, { LLM_ARCH_COHERE2, "cohere2" }, - { LLM_ARCH_DOTS1, "dots1" }, - { LLM_ARCH_ERNIE4_5, "ernie4_5" }, - { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, + { LLM_ARCH_DOTS1, "dots1" }, + { LLM_ARCH_ERNIE4_5, "ernie4_5" }, + { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, + { LLM_ARCH_BAILINGMOE2, "bailingmoe2" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -103,6 +104,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, { LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" }, { LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" }, + { LLM_KV_EXPERT_GROUP_COUNT, "%s.expert_group_count" }, + { LLM_KV_EXPERT_GROUP_USED_COUNT, "%s.expert_group_used_count" }, { LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" }, { LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" }, { LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 65e036ca8..0069a8ee6 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -62,6 +62,7 @@ enum llm_arch { LLM_ARCH_ERNIE4_5_MOE, LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_OPENAI_MOE, + LLM_ARCH_BAILINGMOE2, LLM_ARCH_UNKNOWN, }; @@ -92,6 +93,8 @@ enum llm_kv { LLM_KV_EXPERT_COUNT, LLM_KV_EXPERT_USED_COUNT, LLM_KV_EXPERT_SHARED_COUNT, + LLM_KV_EXPERT_GROUP_COUNT, + LLM_KV_EXPERT_GROUP_USED_COUNT, LLM_KV_EXPERT_WEIGHTS_SCALE, LLM_KV_EXPERT_WEIGHTS_NORM, LLM_KV_EXPERT_GATING_FUNC, diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 8b3cd9a05..04a8a2992 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -8192,6 +8192,139 @@ ggml_cgraph * llm_build_context::build_openai_moe() { return gf; } +ggml_cgraph * llm_build_context::build_bailingmoe2() { + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + //auto * inp_attn = build_attn_inp_kv(); + ggml_tensor * KQ_mask = build_inp_KQ_mask(); + //const int64_t n_embd_head = hparams.n_embd_head_v; + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + + for (int il = 0; il < n_transformer_layers; ++il) { + 
ggml_tensor * inpSA = inpL; + + // norm + cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + // self_attention + { + cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur); + cb(cur, "wqkv", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); + //ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + ggml_tensor * Vcur = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + + Qcur = llm_build_norm(ctx0, Qcur, hparams, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = llm_build_norm(ctx0, Kcur, hparams, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, cb, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, model.layers[il].wo, model.layers[il].bo, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il); + } + if (il == n_transformer_layers - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * sa_out = ggml_add(ctx0, cur, inpSA); + cb(sa_out, "sa_out", il); + + // MoE branch + cur = llm_build_norm(ctx0, sa_out, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + if (static_cast(il) < hparams.n_layer_dense_lead) { + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } else { + + ggml_tensor * moe_out = + llm_build_moe_ffn(ctx0, lctx, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + true, hparams.expert_weights_scale, + (llm_expert_gating_func_type) hparams.expert_gating_func, + cb, il, gf); + cb(moe_out, "ffn_moe_out", il); + + ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, sa_out); + + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1); + + cb(cur, "result_norm", -1); + + // lm_head + cur = llm_build_lora_mm(lctx, ctx0, model.output, 
cur); + + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + return gf; +} + ggml_cgraph * llm_build_context::llama_build_graph_defrag(llama_context & lctx, const std::vector & ids) { llama_batch dummy; dummy.n_tokens = 0; @@ -8513,6 +8646,10 @@ ggml_cgraph * llm_build_context::llama_build_graph( { result = llm.build_openai_moe(); } break; + case LLM_ARCH_BAILINGMOE2: + { + result = llm.build_bailingmoe2(); + } break; default: GGML_ABORT("fatal error"); } diff --git a/src/llama-build-context.h b/src/llama-build-context.h index c5789ad9a..a1f0b8ae9 100644 --- a/src/llama-build-context.h +++ b/src/llama-build-context.h @@ -251,6 +251,8 @@ struct llm_build_context { ggml_cgraph * build_openai_moe(); + ggml_cgraph * build_bailingmoe2(); + // static ggml_tensor * llm_build_lora_mm(llama_context & lctx, ggml_context * ctx0, ggml_tensor * w, ggml_tensor * cur); diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 7f5c80f4e..d706c0915 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -894,6 +894,31 @@ void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_BAILINGMOE2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups); + ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 20: model.type = MODEL_16B_A1B; break; + case 21: model.type = MODEL_16B_A1B; break; + case 32: model.type = MODEL_100B_A6B; break; + case 33: model.type = MODEL_100B_A6B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_DOTS1: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); diff --git a/src/llama-hparams.h b/src/llama-hparams.h index ea1a75628..4235714ae 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -24,6 +24,7 @@ struct llama_hparams { uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; uint32_t n_layer; + int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache uint32_t n_rot; uint32_t n_swa = 0; // sliding window attention (SWA) uint32_t n_swa_pattern = 1; // by default, all layers use non-sliding-window attention @@ -39,22 +40,30 @@ struct llama_hparams { std::array n_ff_arr; uint32_t n_layer_dense_lead = 0; - uint32_t n_lora_q = 0; - uint32_t n_lora_kv = 0; - uint32_t n_ff_exp = 0; - uint32_t n_ff_shexp = 0; - uint32_t n_expert_shared = 0; - float expert_weights_scale = 0.0; - bool expert_weights_norm = false; - uint32_t expert_gating_func = LLM_EXPERT_GATING_FUNC_SOFTMAX; + uint32_t n_lora_q = 0; + uint32_t n_lora_kv = 0; + uint32_t n_ff_exp = 0; + uint32_t n_ff_shexp = 0; + uint32_t n_expert_shared = 0; + 
uint32_t n_norm_groups = 0; + uint32_t n_expert_groups = 0; + uint32_t n_group_used = 0; + uint32_t n_group_experts = 0; + + float expert_group_scale = 0.05f; + float expert_weights_scale = 0.0f; + bool expert_weights_norm = false; + uint32_t expert_gating_func = LLM_EXPERT_GATING_FUNC_SOFTMAX; + uint32_t moe_every_n_layers = 0; uint32_t nextn_predict_layers = 0; float f_norm_eps; float f_norm_rms_eps; + float f_norm_group_eps; - float f_attn_logit_softcapping = 50.0f; + float f_attn_logit_softcapping = 50.0f; float f_router_logit_softcapping = 30.0f; - float f_final_logit_softcapping = 30.0f; + float f_final_logit_softcapping = 30.0f; float rope_attn_factor = 1.0f; float rope_freq_base_train; @@ -62,12 +71,12 @@ struct llama_hparams { float rope_freq_scale_train; float rope_freq_scale_train_swa; uint32_t n_ctx_orig_yarn; - float rope_yarn_log_mul; + float rope_yarn_log_mul = 0.0f; - float yarn_ext_factor = -1.0f; - float yarn_attn_factor = 1.0f; - float yarn_beta_fast = 32.0f; - float yarn_beta_slow = 1.0f; + float yarn_ext_factor = -1.0f; + float yarn_attn_factor = 1.0f; + float yarn_beta_fast = 32.0f; + float yarn_beta_slow = 1.0f; std::array rope_sections; diff --git a/src/llama-load-tensors.cpp b/src/llama-load-tensors.cpp index c13d95fc3..85190c57e 100644 --- a/src/llama-load-tensors.cpp +++ b/src/llama-load-tensors.cpp @@ -124,6 +124,8 @@ struct create_tensors_helper : public create_tensors_helper_interface { bool create_openai_moe_tensors(const LLM_TN & tn); + bool create_bailingmoe2_tensors(const LLM_TN & tn); + llama_model_loader & ml; llama_model & model; @@ -2205,6 +2207,77 @@ bool create_tensors_helper::create_dots1_tensors(const LLM_TN & tn) { return use_mmap_buffer; } +bool create_tensors_helper::create_bailingmoe2_tensors(const LLM_TN & tn) { + LOADING_PRELUDE + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + model.tok_embd = create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + model.output_norm = create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + model.output = create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2"); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = model.layers[i]; + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= llama_model_loader::TENSOR_SKIP; + } + + layer.attn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + + layer.wqkv = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags); + layer.wo = create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); + + layer.attn_q_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); + + layer.ffn_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + if (static_cast(i) >= hparams.n_layer_dense_lead) { // MoE layers + const int64_t n_ff_shexp = 
(hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared; + + layer.ffn_gate_inp = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, + llama_model_loader::TENSOR_NOT_REQUIRED | flags); + + layer.ffn_gate_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + layer.ffn_gate_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + layer.ffn_down_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + } else { // Dense layers + layer.ffn_gate = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.embed_tokens = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, + llama_model_loader::TENSOR_NOT_REQUIRED | flags); + layer.nextn.enorm = create_tensor(ctx_layer, tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(ctx_layer, tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + layer.nextn.shared_head_head = create_tensor(ctx_split, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, llama_model_loader::TENSOR_NOT_REQUIRED | flags); + layer.nextn.shared_head_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, llama_model_loader::TENSOR_NOT_REQUIRED | flags); + layer.layer_out_norm = create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, flags); + } + } + return use_mmap_buffer; +} + bool create_tensors_helper::create_ernie45_tensors(const LLM_TN & tn) { LOADING_PRELUDE @@ -2460,6 +2533,8 @@ bool create_tensors_helper::create_tensors() { use_mmap_buffer = create_hunyuan_tensors(tn); break; case LLM_ARCH_OPENAI_MOE: use_mmap_buffer = create_openai_moe_tensors(tn); break; + case LLM_ARCH_BAILINGMOE2: + use_mmap_buffer = create_bailingmoe2_tensors(tn); break; default: throw std::runtime_error("unknown architecture"); } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 57c7a6a32..d2b92e75d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1157,6 +1157,38 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, }, }, + { + LLM_ARCH_BAILINGMOE2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, 
"blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, + { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, + { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, + { LLM_TENSOR_NEXTN_HNORM, "blk.%d.nextn.hnorm" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "blk.%d.nextn.shared_head_head" }, + { LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "blk.%d.nextn.shared_head_norm" }, + { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1368,6 +1400,8 @@ const char * llama_model_type_name(e_model type) { case MODEL_80B_A13B: return "80B.A13B"; case MODEL_21B_A3B: return "21B.A3B"; case MODEL_300B_A47B: return "300B.A47B"; + case MODEL_16B_A1B: return "16B.A1B"; + case MODEL_100B_A6B: return "100B.A6B"; default: return "?B"; } } diff --git a/src/llama-model.h b/src/llama-model.h index 6b6fb47f5..8446be988 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -84,6 +84,8 @@ enum e_model { MODEL_17B_128E, MODEL_80B_A13B, MODEL_300B_A47B, // Ernie MoE big + MODEL_16B_A1B, + MODEL_100B_A6B, }; struct llama_layer_nextn { diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 372407a0e..9f4f8b084 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1961,7 +1961,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION; clean_spaces = false; } else if ( - tokenizer_pre == "bailingmoe") { + tokenizer_pre == "bailingmoe" || + tokenizer_pre == "bailingmoe2" || + tokenizer_pre == "llada-moe") { pre_type = LLAMA_VOCAB_PRE_TYPE_BAILINGMOE; clean_spaces = false; } else if ( diff --git a/src/llama.cpp b/src/llama.cpp index 53ebb0751..5dc11e477 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -194,6 +194,9 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_OPENAI_MOE, LLM_CHAT_TEMPLATE_GROK_2, + LLM_CHAT_TEMPLATE_BAILING, + LLM_CHAT_TEMPLATE_BAILING_THINK, + LLM_CHAT_TEMPLATE_BAILING2, LLM_CHAT_TEMPLATE_UNKNOWN, }; @@ -236,6 +239,10 @@ static const std::map LLM_CHAT_TEMPLATES = { { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE }, { "bitnet", LLM_CHAT_TEMPLATE_BITNET }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, + { "bailing", LLM_CHAT_TEMPLATE_BAILING }, + { "bailing-think", LLM_CHAT_TEMPLATE_BAILING_THINK }, + { "bailing2", LLM_CHAT_TEMPLATE_BAILING2 }, + }; // @@ -1200,6 +1207,19 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); } + if (model.arch == LLM_ARCH_BAILINGMOE2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + 
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); + LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llm_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); + } + vocab.print_info(); } @@ -4444,6 +4464,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: case LLM_ARCH_OPENAI_MOE: + case LLM_ARCH_BAILINGMOE2: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: @@ -6255,6 +6276,12 @@ static llm_chat_template llama_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_GIGACHAT; } else if (tmpl_contains("<|role_start|>")) { return LLM_CHAT_TEMPLATE_MEGREZ; + } else if (tmpl_contains("ASSISTANT") && tmpl_contains("'HUMAN'")) { + return LLM_CHAT_TEMPLATE_BAILING; + } else if (tmpl_contains("ASSISTANT") && tmpl_contains("\"HUMAN\"") && tmpl_contains("")) { + return LLM_CHAT_TEMPLATE_BAILING_THINK; + } else if (tmpl_contains("ASSISTANT") && tmpl_contains("HUMAN") && tmpl_contains("<|role_end|>")) { + return LLM_CHAT_TEMPLATE_BAILING2; } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) { return LLM_CHAT_TEMPLATE_LLAMA4; } else if (tmpl_contains("<|endofuserprompt|>")) { @@ -6657,6 +6684,49 @@ static int32_t llama_chat_apply_template_internal( if (add_ass) { ss << "<|role_start|>assistant<|role_end|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING || tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) { + // Bailing (Ling/Ring) template + for (auto message : chat) { + std::string role(message->role); + + if (role == "user") { + role = "HUMAN"; + } else { + std::transform(role.begin(), role.end(), role.begin(), ::toupper); + } + + ss << "" << role << "" << message->content; + } + + if (add_ass) { + ss << "ASSISTANT"; + + if (tmpl == LLM_CHAT_TEMPLATE_BAILING_THINK) { + ss << ""; + } + } + } else if (tmpl == LLM_CHAT_TEMPLATE_BAILING2) { + // Bailing2 (Ling 2.0) template + bool has_system = !chat.empty() && std::string(chat[0]->role) == "system"; + + if (!has_system) { + ss << "SYSTEMdetailed thinking off<|role_end|>"; + } + + for (auto message : chat) { + std::string role(message->role); + + if (role == "user") { + role = "HUMAN"; + } else { + std::transform(role.begin(), role.end(), role.begin(), ::toupper); + } + + ss << "" << role << "" << message->content << "<|role_end|>"; + } + if (add_ass) { + ss << "ASSISTANT"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_LLAMA4) { // Llama 4 for (auto message : chat) { From 8f014df83fe517748d93e972ce8203b81b4a418f Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 14 Oct 2025 17:41:18 +0300 Subject: [PATCH 2/5] Add expert group selection (not working, so turned off) --- ggml/include/ggml.h | 31 ++++++ ggml/src/ggml.c | 215 +++++++++++++++++++++++++++++++++--- src/llama-build-context.cpp | 37 +++++++ 3 files changed, 270 insertions(+), 13 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 562e6f584..a5fda8a4e 100644 --- a/ggml/include/ggml.h +++ 
b/ggml/include/ggml.h @@ -630,6 +630,7 @@ extern "C" { GGML_OP_TRANSPOSE, GGML_OP_GET_ROWS, GGML_OP_GET_ROWS_BACK, + GGML_OP_SET_ROWS, GGML_OP_DIAG, GGML_OP_DIAG_MASK_INF, GGML_OP_DIAG_MASK_ZERO, @@ -1559,6 +1560,19 @@ extern "C" { struct ggml_tensor * a, float s); + // x = s * a + b + GGML_API struct ggml_tensor * ggml_scale_bias( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b); + + GGML_API struct ggml_tensor * ggml_scale_bias_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b); + GGML_API struct ggml_tensor * ggml_softcap( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1781,6 +1795,23 @@ extern "C" { struct ggml_tensor * b, struct ggml_tensor * c); + // a TD [n_embd, ne1, ne2, ne3] + // b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3 + // c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1) + // + // undefined behavior if destination rows overlap + // + // broadcast: + // ne2 % ne11 == 0 + // ne3 % ne12 == 0 + // + // return view(a) + GGML_API struct ggml_tensor * ggml_set_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, // destination + struct ggml_tensor * b, // source + struct ggml_tensor * c); // row indices + GGML_API struct ggml_tensor * ggml_diag( struct ggml_context * ctx, struct ggml_tensor * a); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 81077f5fb..9b730d09a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3206,6 +3206,53 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #endif } +inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) { +#if defined(GGML_USE_ACCELERATE) + vDSP_vsmsa(x, 1, &s, &b, y, 1, n); +#elif defined(GGML_SIMD) + #if defined(__ARM_FEATURE_SVE) + // scalar ; TODO: Write SVE code + for (int i = 0; i < n; ++i) { + y[i] = x[i]*s + b; + } + #elif defined(__riscv_v_intrinsic) + for (int i = 0, avl; i < n; i += avl) { + avl = __riscv_vsetvl_e32m8(n - i); + vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl); + vfloat32m8_t vb = __riscv_vfmv_v_f_f32m8(b, avl); + vfloat32m8_t ny = __riscv_vfmadd_vf_f32m8(ax, s, vb, avl); + __riscv_vse32_v_f32m8(&y[i], ny, avl); + } + #else + const int np = (n & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC vs = GGML_F32_VEC_SET1(s); + GGML_F32_VEC vb = GGML_F32_VEC_SET1(b); + + GGML_F32_VEC ay[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); + ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs); + + GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = x[i]*s + b; + } + #endif +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = x[i]*s + b; + } +#endif +} + inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { #if defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); @@ -4185,6 +4232,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "TRANSPOSE", "GET_ROWS", "GET_ROWS_BACK", + "SET_ROWS", "DIAG", "DIAG_MASK_INF", "DIAG_MASK_ZERO", @@ -4239,7 +4287,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86"); +static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -4287,6 +4335,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "transpose(x)", "get_rows(x)", "get_rows_back(x)", + "set_rows(x)", 
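    // A minimal usage sketch for the op registered above, assuming the ggml_set_rows()
    // API declared in the ggml.h hunk earlier in this commit; the call below and the
    // tensor names (dst, src, idx) are illustrative only, not part of the patch:
    //
    //   // scatter the rows of `src` (F32) into `dst` at the row indices held in `idx`
    //   // (I64 or I32); the op returns a view of `dst`, and overlapping destination
    //   // rows are undefined behavior, as documented in the header comment.
    //   struct ggml_tensor * out = ggml_set_rows(ctx, dst, src, idx);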
"diag(x)", "diag_mask_inf(x)", "diag_mask_zero(x)", @@ -4341,7 +4390,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)," }; -static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86"); +static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -7544,6 +7593,7 @@ static struct ggml_tensor * ggml_scale_impl( struct ggml_context * ctx, struct ggml_tensor * a, float s, + float b, bool inplace) { GGML_ASSERT(ggml_is_padded_1d(a)); @@ -7555,7 +7605,8 @@ static struct ggml_tensor * ggml_scale_impl( struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); - ggml_set_op_params(result, &s, sizeof(s)); + float params[2] = {s, b}; + ggml_set_op_params(result, ¶ms, sizeof(params)); result->op = GGML_OP_SCALE; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -7568,14 +7619,30 @@ struct ggml_tensor * ggml_scale( struct ggml_context * ctx, struct ggml_tensor * a, float s) { - return ggml_scale_impl(ctx, a, s, false); + return ggml_scale_impl(ctx, a, s, 0.f, false); } struct ggml_tensor * ggml_scale_inplace( struct ggml_context * ctx, struct ggml_tensor * a, float s) { - return ggml_scale_impl(ctx, a, s, true); + return ggml_scale_impl(ctx, a, s, 0.f, true); +} + +struct ggml_tensor * ggml_scale_bias( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b) { + return ggml_scale_impl(ctx, a, s, b, false); +} + +struct ggml_tensor * ggml_scale_bias_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + float s, + float b) { + return ggml_scale_impl(ctx, a, s, b, true); } // ggml_softcap @@ -8294,6 +8361,36 @@ struct ggml_tensor * ggml_get_rows_back( return result; } +// ggml_set_rows + +struct ggml_tensor * ggml_set_rows( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_tensor * c) { + GGML_ASSERT(a->ne[0] == b->ne[0]); + GGML_ASSERT(a->ne[2] == b->ne[2]); + GGML_ASSERT(a->ne[3] == b->ne[3]); + GGML_ASSERT(b->ne[1] == c->ne[0]); + GGML_ASSERT(b->ne[2] % c->ne[1] == 0); + GGML_ASSERT(b->ne[3] % c->ne[2] == 0); + GGML_ASSERT(c->ne[3] == 1); + GGML_ASSERT(b->type == GGML_TYPE_F32); + GGML_ASSERT(c->type == GGML_TYPE_I64 || c->type == GGML_TYPE_I32); + + GGML_ASSERT(ggml_is_contiguous_rows(a)); + GGML_ASSERT(ggml_is_contiguous_rows(b)); + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->op = GGML_OP_SET_ROWS; + result->src[0] = b; + result->src[1] = c; + result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931) + + return result; +} + // ggml_diag struct ggml_tensor * ggml_diag( @@ -16594,8 +16691,9 @@ static void ggml_compute_forward_scale_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); // scale factor - float v; - memcpy(&v, dst->op_params, sizeof(float)); + float s, b; + memcpy(&s, (const float *)dst->op_params + 0, sizeof(float)); + memcpy(&b, (const float *)dst->op_params + 1, sizeof(float)); const int ith = params->ith; const int nth = params->nth; @@ -16614,12 +16712,21 @@ static void ggml_compute_forward_scale_f32( const size_t nb1 = dst->nb[1]; - for (int i1 = ir0; i1 < ir1; i1++) { - if (dst->data != src0->data) { - // src0 is same shape as dst => same indices - memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + if (b == 0.0f) { + for (int i1 = ir0; i1 < ir1; i1++) { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + 
memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float)); + } + ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s); + } + } else { + for (int i1 = ir0; i1 < ir1; i1++) { + ggml_vec_mad1_f32(nc, + (float *) ((char *) dst->data + i1*nb1), + (float *) ((char *) src0->data + i1*nb1), + s, b); } - ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v); } } @@ -17455,6 +17562,78 @@ static void ggml_compute_forward_get_rows_back( //} } +static void ggml_compute_forward_set_rows_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ne01; + + assert(ne0 == nc); + assert(ne2 == ne02); + assert(ne3 == ne03); + assert(src0->type == GGML_TYPE_F32); + assert(ne02 % ne11 == 0); + assert(ne03 % ne12 == 0); + + const int ith = params->ith; + const int nth = params->nth; + + // rows per thread + const int64_t dr = (nr + nth - 1)/nth; + + // row range for this thread + const int64_t ir0 = dr*ith; + const int64_t ir1 = MIN(ir0 + dr, nr); + + ggml_from_float_t const from_float = type_traits[dst->type].from_float; + + for (int64_t i03 = 0; i03 < ne03; ++i03) { + for (int64_t i02 = 0; i02 < ne02; ++i02) { + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i03%ne12; + const int64_t i11 = i02%ne11; + const int64_t i10 = i; + + const int64_t i1 = *(int64_t*) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + GGML_ASSERT(i1 >= 0 && i1 < ne1); + + from_float((const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03), + ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc); + } + } + } +} + +static void ggml_compute_forward_set_rows( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + if (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32) { + ggml_compute_forward_set_rows_f32(params, dst); + } else { + GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type)); + } + } break; + default: + { + GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type)); + } + } +} + // ggml_compute_forward_diag static void ggml_compute_forward_diag_f32( @@ -22208,10 +22387,15 @@ static int ggml_compute_forward(struct ggml_compute_params * params, struct ggml { ggml_compute_forward_get_rows(params, tensor); } break; + case GGML_OP_GET_ROWS_BACK: { ggml_compute_forward_get_rows_back(params, tensor); } break; + case GGML_OP_SET_ROWS: + { + ggml_compute_forward_set_rows(params, tensor); + } break; case GGML_OP_DIAG: { ggml_compute_forward_diag(params, tensor); @@ -22973,7 +23157,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, - ggml_scale_impl(ctx, tensor->grad, s, false), + ggml_scale_impl(ctx, tensor->grad, s, 0.0f, false), zero_table); } } break; @@ -23144,6 +23328,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // noop } } break; + case GGML_OP_SET_ROWS: + { + GGML_ABORT("fatal error"); // TODO: not implemented + } case GGML_OP_GET_ROWS_BACK: { GGML_ABORT("fatal error"); // TODO: not implemented @@ -24010,6 +24198,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int 
n_threads) { n_tasks = n_threads; } break; case GGML_OP_GET_ROWS: + case GGML_OP_SET_ROWS: { // FIXME: get_rows can use additional threads, but the cost of launching additional threads // decreases performance with GPU offloading diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 04a8a2992..819945fc4 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -820,6 +820,38 @@ llm_expert_gating_func_type gating_op, selection_probs = logits; } + if (false && lctx.model.arch == LLM_ARCH_BAILINGMOE2) { + auto& hparams = lctx.model.hparams; + const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups; + + // organize experts into n_expert_groups + ggml_tensor * selection_groups = ggml_view_2d(ctx, ggml_cont(ctx, ggml_transpose(ctx, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens * n_exp_per_group, n_expert_groups] +#if 0 + ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups] + group_scores = ggml_get_rows(ctx0, ggml_reshape_3d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups] + + // get top n_group_used expert groups + group_scores = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1] +#else + // Replace top_k(2) with argmax due to backend limitations, ideally we should use something like argmax2 instead + ggml_tensor * group_scores = ggml_reshape_2d(ctx, ggml_argmax(ctx, selection_groups), 1, selection_groups->ne[1]); // [1, n_expert_groups] + group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 1, n_expert_groups] + + // get top n_group_used expert groups + group_scores = ggml_transpose(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2])); // [n_expert_groups, 1] +#endif + ggml_tensor * expert_groups = ggml_top_k(ctx, ggml_cont(ctx, group_scores), hparams.n_group_used); // [n_group_used, 1] + cb(expert_groups->src[0], "ffn_moe_group_argsort", il); + cb(expert_groups, "ffn_moe_group_topk", il); + + // mask out the other groups + selection_probs = ggml_get_rows(ctx, selection_groups, expert_groups); // [n_tokens * n_exp_per_group, n_group_used] + selection_probs = ggml_set_rows(ctx, ggml_scale_bias(ctx, selection_groups, 0.0f, -INFINITY), selection_probs, expert_groups); // [n_tokens * n_exp_per_group, n_expert_groups] + selection_probs = ggml_view_2d(ctx, selection_probs, n_tokens, n_expert, n_tokens * sizeof(float), 0); // [n_tokens, n_expert] + selection_probs = ggml_cont(ctx, ggml_transpose(ctx, selection_probs)); // [n_expert, n_tokens] + cb(selection_probs, "ffn_moe_probs_masked", il); + } + // select experts ggml_tensor * selected_experts = ggml_top_k_thresh(ctx, selection_probs, n_expert_used, lctx.cparams.min_experts, lctx.cparams.thresh_experts); // [n_expert_used, n_tokens] @@ -846,6 +878,11 @@ llm_expert_gating_func_type gating_op, ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens] cb(weights_sum, "ffn_moe_weights_sum", il); + if (lctx.model.arch == LLM_ARCH_BAILINGMOE2) { + weights_sum = ggml_scale_bias(ctx, weights_sum, 1.0, 1e-20); + cb(weights_sum, "ffn_moe_weights_sum_biased", il); + } + weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens] cb(weights, "ffn_moe_weights_norm", 
il); From 2bc3fd32e72b4ce4a50dec4fc2368c02ab45baf5 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 14 Oct 2025 19:03:23 +0300 Subject: [PATCH 3/5] BailingMoE2 conversion --- convert_hf_to_gguf.py | 100 +++++++++++++++++++++++++++++++++ gguf-py/gguf/constants.py | 33 +++++++++++ gguf-py/gguf/gguf_writer.py | 6 ++ gguf-py/gguf/tensor_mapping.py | 7 +++ 4 files changed, 146 insertions(+) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7a5ae01e8..6b21c05fd 100644 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -654,6 +654,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890": # ref: https://huggingface.co/moonshotai/Kimi-K2-Base res = "kimi-k2" + if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206": + # ref: https://huggingface.co/inclusionAI/Ling-mini-base-2.0 + res = "bailingmoe2" if res is None: logger.warning("\n") @@ -4461,6 +4464,103 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter name = name.removeprefix("transformer.") return [(self.map_tensor_name(name), data_torch)] +@Model.register("BailingMoeV2ForCausalLM") +class BailingMoeV2Model(Model): + model_arch = gguf.MODEL_ARCH.BAILINGMOE2 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if nextn_layers := self.hparams.get("num_nextn_predict_layers", 0): + self.block_count = self.hparams["num_hidden_layers"] + nextn_layers + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_vocab(self): + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + if (rope_dim := hparams.get("head_dim")) is None: + rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] + + self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))) + rope_scaling = self.hparams.get("rope_scaling") or {} + if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN) + self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"]) + self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"]) + else: + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + self.gguf_writer.add_vocab_size(hparams["vocab_size"]) + self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) + self.gguf_writer.add_expert_shared_feed_forward_length(hparams["moe_shared_expert_intermediate_size"]) + self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) + self.gguf_writer.add_expert_count(hparams["num_experts"]) + self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"]) + self.gguf_writer.add_expert_group_count(hparams["n_group"]) + self.gguf_writer.add_expert_group_used_count(hparams["topk_group"]) + self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + + if hparams["score_function"] == "sigmoid": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) + elif hparams["score_function"] == "softmax": + self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX) + else: + raise ValueError(f"Unsupported score_function value: {hparams['score_function']}") + + if (nextn_layers := 
self.hparams.get("num_nextn_predict_layers")) is not None: + self.gguf_writer.add_nextn_predict_layers(nextn_layers) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if "mlp.experts" in name: + n_experts = self.hparams["num_experts"] + assert bid is not None + + tensors: list[tuple[str, Tensor]] = [] + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + # merge the experts into a single 3d tensor + for w_name in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight" + + new_name = self.map_tensor_name(merged_name) + + tensors.append((new_name, data_torch)) + + return tensors + + if name.endswith(".expert_bias"): + name = name.replace(".expert_bias", ".expert_bias.bias") + + return [(self.map_tensor_name(name), data_torch)] + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + ###### CONVERSION LOGIC ###### diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 30dc3dc6c..59bcba526 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -88,6 +88,8 @@ class LLM: EXPERT_COUNT = "{arch}.expert_count" EXPERT_USED_COUNT = "{arch}.expert_used_count" EXPERT_SHARED_COUNT = "{arch}.expert_shared_count" + EXPERT_GROUP_COUNT = "{arch}.expert_group_count" + EXPERT_GROUP_USED_COUNT = "{arch}.expert_group_used_count" EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale" EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm" EXPERT_GATING_FUNC = "{arch}.expert_gating_func" @@ -245,6 +247,7 @@ class MODEL_ARCH(IntEnum): DOTS1 = auto() ERNIE4_5 = auto() ERNIE4_5_MOE = auto() + BAILINGMOE2 = auto() class MODEL_TENSOR(IntEnum): TOKEN_EMBD = auto() @@ -390,6 +393,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.DOTS1: "dots1", MODEL_ARCH.ERNIE4_5: "ernie4_5", MODEL_ARCH.ERNIE4_5_MOE: "ernie4_5-moe", + MODEL_ARCH.BAILINGMOE2: "bailingmoe2", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -1291,6 +1295,35 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_UP_SHEXP, MODEL_TENSOR.FFN_EXP_PROBS_B, ], + MODEL_ARCH.BAILINGMOE2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, + MODEL_TENSOR.LAYER_OUT_NORM, + ], # TODO } diff --git 
a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 6a4092a63..8ad3d065a 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -671,6 +671,12 @@ def add_expert_used_count(self, count: int) -> None: def add_expert_shared_count(self, count: int) -> None: self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count) + def add_expert_group_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_GROUP_COUNT.format(arch=self.arch), count) + + def add_expert_group_used_count(self, count: int) -> None: + self.add_uint32(Keys.LLM.EXPERT_GROUP_USED_COUNT.format(arch=self.arch), count) + def add_expert_weights_scale(self, value: float) -> None: self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index e8eb4b5e8..d9af8e32c 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -27,6 +27,7 @@ class TensorNameMap: "model.layers.{bid}.pre_attn_norm", # grok-2 "embedding.word_embeddings", # chatglm "transformer.token_embeddings", # openelm + "model.word_embeddings", # bailingmoe "shared", # t5 ), @@ -129,6 +130,7 @@ class TensorNameMap: "h.{bid}.self_attention.query_key_value", # bloom "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon "model.layers.{bid}.self_attn.query_key_value", # persimmon + "model.layers.{bid}.attention.query_key_value", # bailingmoe2 "h.{bid}.attn.c_attn", # gpt2 "transformer.h.{bid}.mixer.Wqkv", # phi2 "encoder.layers.{bid}.attn.Wqkv", # nomic-bert @@ -187,6 +189,7 @@ class TensorNameMap: "transformer.h.{bid}.attn.out_proj", # gpt-j "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon "model.layers.{bid}.self_attn.dense", # persimmon + "model.layers.{bid}.attention.dense", # bailingmoe2 "h.{bid}.attn.c_proj", # gpt2 "transformer.h.{bid}.mixer.out_proj", # phi2 "model.layers.layers.{bid}.self_attn.o_proj", # plamo @@ -263,6 +266,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_EXP_PROBS_B: ( "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1 "model.layers.{bid}.mlp.moe_statics.e_score_correction", # ernie4.5-moe + "model.layers.{bid}.mlp.gate.expert_bias", # bailingmoe2 ), # Feed-forward up @@ -382,6 +386,7 @@ class TensorNameMap: "transformer.blocks.{bid}.attn.q_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2 "transformer.layers.{bid}.attn.q_norm", # openelm + "model.layers.{bid}.attention.query_layernorm", # bailingmoe2 ), MODEL_TENSOR.ATTN_K_NORM: ( @@ -391,6 +396,7 @@ class TensorNameMap: "transformer.blocks.{bid}.attn.k_ln", # sea-lion "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2 "transformer.layers.{bid}.attn.k_norm", # openelm + "model.layers.{bid}.attention.key_layernorm", # bailingmoe2 ), MODEL_TENSOR.ROPE_FREQS: ( @@ -403,6 +409,7 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.rms_norm_3", # Grok "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 "encoder.layer.{bid}.layer_norm_2" # jina-v2-code + "model.layers.{bid}.final_layernorm", # bailingmoe2 ), MODEL_TENSOR.SSM_IN: ( From 7705c5edca0441d5545da7a54f97442ba547feb9 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 15 Oct 2025 09:36:14 +0300 Subject: [PATCH 4/5] WIP --- ggml/src/ggml-cuda.cu | 12 ++ ggml/src/ggml-cuda/scale.cu | 28 ++-- ggml/src/ggml-cuda/set-rows.cu | 277 ++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/set-rows.cuh | 7 + ggml/src/ggml.c | 43 +++-- 
src/llama-build-context.cpp | 2 +- 6 files changed, 346 insertions(+), 23 deletions(-) create mode 100644 ggml/src/ggml-cuda/set-rows.cu create mode 100644 ggml/src/ggml-cuda/set-rows.cuh diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 4924c3bac..da5df81a5 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -44,6 +44,7 @@ #include "ggml-cuda/topk-moe.cuh" #include "ggml-cuda/conv2d.cuh" #include "ggml-cuda/conv2d-dw.cuh" +#include "ggml-cuda/set-rows.cuh" #include #include @@ -3111,6 +3112,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GET_ROWS: ggml_cuda_op_get_rows(ctx, dst); break; + case GGML_OP_SET_ROWS: + ggml_cuda_op_set_rows(ctx, dst); + break; case GGML_OP_DUP: ggml_cuda_dup(ctx, dst); break; @@ -4204,6 +4208,14 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons return false; } } break; + case GGML_OP_SET_ROWS: + { + return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 || + op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 || + op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) && + op->src[0]->type == GGML_TYPE_F32 && + (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32); + } break; case GGML_OP_CPY: { ggml_type src0_type = op->src[0]->type; diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu index 1405e066e..ac1384699 100644 --- a/ggml/src/ggml-cuda/scale.cu +++ b/ggml/src/ggml-cuda/scale.cu @@ -1,19 +1,21 @@ #include "scale.cuh" -static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; +#define MAX_GRIDDIM_X 0x7FFFFFFF - if (i >= k) { - return; - } +static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) { + int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x; + int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x; - dst[i] = scale * x[i]; + for (int64_t i = tid; i < nelements; i += stride) { + dst[i] = scale * x[i] + bias; + } } -static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<>>(x, dst, scale, k); -} +static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) { + const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; + // Whehn will we be scaling tensors with more than 2^39 elements? 
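    // The kernel above walks the tensor with a grid-stride loop, so the launch below can
    // clamp the grid to MAX_GRIDDIM_X blocks and still cover tensors with more elements
    // than gridDim.x * blockDim.x.
    // A minimal usage sketch from graph-building code, assuming the ggml_scale_bias() API
    // added to ggml.h earlier in this series; the call and the names ctx, x, y are
    // illustrative, not part of the patch:
    //
    //   struct ggml_tensor * y = ggml_scale_bias(ctx, x, 0.5f, 1.0f);  // y = 0.5f*x + 1.0f in one fused op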
+ //scale_f32<<>>(x, dst, scale, bias, nelements); + scale_f32<<>>(x, dst, scale, bias, nelements); } void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; @@ -25,7 +27,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { GGML_ASSERT( dst->type == GGML_TYPE_F32); float scale; - memcpy(&scale, dst->op_params, sizeof(float)); + float bias; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&bias, (float *) dst->op_params + 1, sizeof(float)); - scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream); + scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream); } diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu new file mode 100644 index 000000000..bb81c0416 --- /dev/null +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -0,0 +1,277 @@ +#include "set-rows.cuh" +#include "cpy-utils.cuh" +#include "convert.cuh" + +typedef void (*set_rows_kernel_t)(const char * src, char * dst); + +// Generic quantized set_rows kernel template +template +static __global__ void k_set_rows_quant( + const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t s10, const int64_t s11, const int64_t s12, + const int64_t s1, const int64_t s2, const int64_t s3) { + + const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; + const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk; + + if (i >= ne_total) { + return; + } + + const int64_t i_base = i * qk; + const int64_t i03 = i_base / (ne00 * ne01 * ne02); + const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00; + const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00; + + const int64_t i12 = i03 % ne12; + const int64_t i11 = i02 % ne11; + const int64_t i10 = i01; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + + const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; + block_type * dst_row_ptr = dst + (dst_row*s1 + i02*s2 + i03*s3) / sizeof(block_type); + + const float * src_block = src0_row + i00; + block_type * dst_block = dst_row_ptr + i00 / qk; + + quantize_func(src_block, dst_block); + + GGML_UNUSED(ne10); + GGML_UNUSED(ne13); +} + +// Template dispatch function for quantized set_rows +template +static void set_rows_cuda_quant( + const float * src0_d, const idx_t * src1_d, block_type * dst_d, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { + + GGML_ASSERT(ne00 % qk == 0); + const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk; + const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE; + const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE); + const dim3 grid_size(num_blocks); + + const int64_t s01 = nb01/sizeof(float); + const int64_t s02 = nb02/sizeof(float); + const int64_t s03 = nb03/sizeof(float); + const int64_t s10 = 
nb10/sizeof(idx_t); + const int64_t s11 = nb11/sizeof(idx_t); + const int64_t s12 = nb12/sizeof(idx_t); + const int64_t s1 = nb1; + const int64_t s2 = nb2; + const int64_t s3 = nb3; + + if (ne_total > 0) { + k_set_rows_quant<<>>( + src0_d, src1_d, dst_d, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + s01, s02, s03, + s10, s11, s12, + s1, s2, s3); + } +} + +template +static __global__ void k_set_rows( + const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t s10, const int64_t s11, const int64_t s12, + const int64_t s1, const int64_t s2, const int64_t s3) { + + const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; + const int64_t ne_total = ne00 * ne01 * ne02 * ne03; + + if (i >= ne_total) { + return; + } + + const int64_t i03 = i / (ne00 * ne01 * ne02); + const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01); + const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00; + const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00; + + const int64_t i12 = i03 % ne12; + const int64_t i11 = i02 % ne11; + const int64_t i10 = i01; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + + const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; + dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3; + + dst_row_ptr[i00] = ggml_cuda_cast(src0_row[i00]); + + GGML_UNUSED(ne10); + GGML_UNUSED(ne13); +} + +template +static void set_rows_cuda( + const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const size_t nb01, const size_t nb02, const size_t nb03, + const size_t nb10, const size_t nb11, const size_t nb12, + const size_t nb1, const size_t nb2, const size_t nb3, + cudaStream_t stream) { + + const int64_t ne_total = ne00 * ne01 * ne02 * ne03; + const int num_blocks = (ne_total + CUDA_SET_ROWS_BLOCK_SIZE - 1) / CUDA_SET_ROWS_BLOCK_SIZE; + const dim3 block_size(CUDA_SET_ROWS_BLOCK_SIZE); + const dim3 grid_size(num_blocks); + + + const int64_t s01 = nb01/sizeof(src_t); + const int64_t s02 = nb02/sizeof(src_t); + const int64_t s03 = nb03/sizeof(src_t); + const int64_t s10 = nb10/sizeof(idx_t); + const int64_t s11 = nb11/sizeof(idx_t); + const int64_t s12 = nb12/sizeof(idx_t); + const int64_t s1 = nb1/sizeof(dst_t); + const int64_t s2 = nb2/sizeof(dst_t); + const int64_t s3 = nb3/sizeof(dst_t); + + if (ne_total > 0) { + k_set_rows<<>>( + src0_d, src1_d, dst_d, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + s01, s02, s03, + s10, s11, s12, + s1, s2, s3); + } +} + +template +static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const src_t * src0_d = (const src_t *)src0->data; + const idx_t * src1_d = (const idx_t *)src1->data; + + GGML_TENSOR_BINARY_OP_LOCALS + + cudaStream_t stream = ctx.stream(); + + + if (dst->type == GGML_TYPE_F32) { + set_rows_cuda( + src0_d, src1_d, (float*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_F16) { + set_rows_cuda( + src0_d, src1_d, 
(half*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_BF16) { + set_rows_cuda( + src0_d, src1_d, (nv_bfloat16*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q4_0) { + set_rows_cuda_quant( + src0_d, src1_d, (block_q4_0*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q4_1) { + set_rows_cuda_quant( + src0_d, src1_d, (block_q4_1*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q5_0) { + set_rows_cuda_quant( + src0_d, src1_d, (block_q5_0*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q5_1) { + set_rows_cuda_quant( + src0_d, src1_d, (block_q5_1*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_Q8_0) { + set_rows_cuda_quant( + src0_d, src1_d, (block_q8_0*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else if (dst->type == GGML_TYPE_IQ4_NL) { + set_rows_cuda_quant( + src0_d, src1_d, (block_iq4_nl*)dst->data, + ne00, ne01, ne02, ne03, + ne10, ne11, ne12, ne13, + nb01, nb02, nb03, + nb10, nb11, nb12, + nb1, nb2, nb3, + stream + ); + } else { + GGML_ABORT("unsupported type %s", ggml_type_name(dst->type)); + } +} + + +void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32); + + if (src1->type == GGML_TYPE_I64) { + set_rows_cuda(ctx, src0, src1, dst); + } else { + set_rows_cuda(ctx, src0, src1, dst); + } +} diff --git a/ggml/src/ggml-cuda/set-rows.cuh b/ggml/src/ggml-cuda/set-rows.cuh new file mode 100644 index 000000000..c140c0873 --- /dev/null +++ b/ggml/src/ggml-cuda/set-rows.cuh @@ -0,0 +1,7 @@ +#pragma once + +#include "common.cuh" + +#define CUDA_SET_ROWS_BLOCK_SIZE 256 + +void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 9b730d09a..1869d19b5 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -17593,22 +17593,45 @@ static void ggml_compute_forward_set_rows_f32( ggml_from_float_t const from_float = type_traits[dst->type].from_float; - for (int64_t i03 = 0; i03 < ne03; ++i03) { - for (int64_t i02 = 0; i02 < ne02; ++i02) { - for (int64_t i = ir0; i < ir1; ++i) { - const int64_t i12 = i03%ne12; - const int64_t i11 = i02%ne11; - const int64_t i10 = i; + if (src1->type == GGML_TYPE_I64) { + for (int64_t i03 = 0; i03 < ne03; ++i03) { + for (int64_t i02 = 0; i02 < ne02; ++i02) { + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i03%ne12; + const int64_t i11 = i02%ne11; + const int64_t i10 = i; - const int64_t i1 = *(int64_t*) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + const int64_t i1 = *(int64_t*) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - 
GGML_ASSERT(i1 >= 0 && i1 < ne1); + GGML_ASSERT(i1 >= 0 && i1 < ne1); - from_float((const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03), - ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc); + from_float((const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03), + ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc); + } + } + } + } + else if (src1->type == GGML_TYPE_I32) { + for (int64_t i03 = 0; i03 < ne03; ++i03) { + for (int64_t i02 = 0; i02 < ne02; ++i02) { + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i03%ne12; + const int64_t i11 = i02%ne11; + const int64_t i10 = i; + + const int64_t i1 = *(int32_t*) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + + GGML_ASSERT(i1 >= 0 && i1 < ne1); + + from_float((const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03), + ((char *) dst->data + i1*nb1 + i02*nb2 + i03*nb3), nc); + } } } } + else { + GGML_ABORT("Fatal error"); + } } static void ggml_compute_forward_set_rows( diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp index 819945fc4..02e06b07a 100644 --- a/src/llama-build-context.cpp +++ b/src/llama-build-context.cpp @@ -820,7 +820,7 @@ llm_expert_gating_func_type gating_op, selection_probs = logits; } - if (false && lctx.model.arch == LLM_ARCH_BAILINGMOE2) { + if (lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) { auto& hparams = lctx.model.hparams; const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups; From e8a705a1538cadce092baa0acc72bb63b62f7088 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 15 Oct 2025 14:07:03 +0300 Subject: [PATCH 5/5] Bits and pieces --- ggml/src/ggml-cuda.cu | 6 +++ ggml/src/ggml-cuda/argmax.cu | 90 +++++++++++++++++++++++++++++++++++ ggml/src/ggml-cuda/argmax.cuh | 3 ++ ggml/src/ggml-cuda/argsort.cu | 54 +++++++++++++++++---- ggml/src/ggml-cuda/cpy.cu | 22 ++++++++- src/llama-build-context.cpp | 8 ++-- 6 files changed, 167 insertions(+), 16 deletions(-) create mode 100644 ggml/src/ggml-cuda/argmax.cu create mode 100644 ggml/src/ggml-cuda/argmax.cuh diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index da5df81a5..da43d4cdc 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -45,6 +45,7 @@ #include "ggml-cuda/conv2d.cuh" #include "ggml-cuda/conv2d-dw.cuh" #include "ggml-cuda/set-rows.cuh" +#include "ggml-cuda/argmax.cuh" #include #include @@ -3106,6 +3107,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg auto next = i < cgraph->n_nodes - 1 ? 
cgraph->nodes[i+1] : nullptr;
     switch (dst->op) {
+        case GGML_OP_ARGMAX:
+            ggml_cuda_argmax(ctx, dst);
+            break;
         case GGML_OP_REPEAT:
             ggml_cuda_op_repeat(ctx, dst);
             break;
@@ -4272,6 +4276,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
             }
             return false;
         } break;
+        case GGML_OP_ARGMAX:
+            return true;
         case GGML_OP_DUP:
         case GGML_OP_REPEAT:
         case GGML_OP_CONCAT:
diff --git a/ggml/src/ggml-cuda/argmax.cu b/ggml/src/ggml-cuda/argmax.cu
new file mode 100644
index 000000000..754e133db
--- /dev/null
+++ b/ggml/src/ggml-cuda/argmax.cu
@@ -0,0 +1,90 @@
+#include <algorithm>
+#include <cfloat>
+
+#include "argmax.cuh"
+#include "common.cuh"
+
+static __global__ void argmax_f32(const float * __restrict__ x, int32_t * __restrict__ dst, const int64_t ncols) {
+    const int64_t row = blockIdx.x;
+
+    float maxval = -FLT_MAX;
+    int   argmax = -1;
+    const float * rowx = x + row * ncols;
+
+    for (int32_t col = threadIdx.x; col < ncols; col += blockDim.x) {
+        const float val = rowx[col];
+        if (val > maxval) {
+            maxval = val;
+            argmax = col;
+        }
+    }
+
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
+        const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+        if (val > maxval) {
+            maxval = val;
+            argmax = col;
+        }
+    }
+
+    const int n_warps = blockDim.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    if (n_warps > 1) {
+        constexpr int max_warps = 1024 / WARP_SIZE;
+        __shared__ float shared_maxval[max_warps];
+        __shared__ int shared_argmax[max_warps];
+        if (lane_id == 0) {
+            shared_maxval[warp_id] = maxval;
+            shared_argmax[warp_id] = argmax;
+        }
+
+        __syncthreads();
+
+        if (warp_id == 0) {
+            if (lane_id < n_warps) {
+                maxval = shared_maxval[lane_id];
+                argmax = shared_argmax[lane_id];
+            }
+#pragma unroll
+            for (int offset = 16; offset > 0; offset >>= 1) {
+                const float val = __shfl_xor_sync(0xFFFFFFFF, maxval, offset, WARP_SIZE);
+                const int   col = __shfl_xor_sync(0xFFFFFFFF, argmax, offset, WARP_SIZE);
+                if (val > maxval) {
+                    maxval = val;
+                    argmax = col;
+                }
+            }
+        }
+    }
+
+    if (warp_id == 0 && lane_id == 0) {
+        dst[row] = argmax;
+    }
+}
+
+void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
+
+    const float * src0_d = (const float *) src0->data;
+    int32_t * dst_d = (int32_t *) dst->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t num_blocks = nrows;
+    const int64_t num_threads = std::min<int64_t>(1024, (ne00 + WARP_SIZE - 1) / WARP_SIZE * WARP_SIZE);
+    const dim3 blocks_dim(num_threads, 1, 1);
+    const dim3 blocks_num(num_blocks, 1, 1);
+
+    argmax_f32<<<blocks_num, blocks_dim, 0, stream>>>(src0_d, dst_d, ne00);
+}
diff --git a/ggml/src/ggml-cuda/argmax.cuh b/ggml/src/ggml-cuda/argmax.cuh
new file mode 100644
index 000000000..5b7223adc
--- /dev/null
+++ b/ggml/src/ggml-cuda/argmax.cuh
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_argmax(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
index df2140822..0442dc902 100644
--- a/ggml/src/ggml-cuda/argsort.cu
+++ b/ggml/src/ggml-cuda/argsort.cu
@@ -13,9 +13,20 @@ static inline __device__ void ggml_cuda_swap(T & a, T & b) {
     b = tmp;
 }
 
-template<ggml_sort_order order>
-static 
__global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, - int min_experts, float thresh_experts) { +struct store_ser { + constexpr static bool has_thresh = true; + int min_experts; + float thresh_experts; + store_ser(int min, float thresh) : min_experts(min), thresh_experts(thresh) {} +}; + +struct store { + constexpr static bool has_thresh = false; +}; + +template +static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols, int ncols_pad, Store s) { +// int min_experts, float thresh_experts) { // bitonic sort int col = threadIdx.x; int row = blockIdx.y; @@ -58,19 +69,30 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n } } - if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) { + if constexpr (Store::has_thresh) { __syncthreads(); float max_val = x_row[dst_row[0]]; if (col < ncols) { - dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1; + dst[row * ncols + col] = col < s.min_experts || x_row[dst_row[col]] >= s.thresh_experts*max_val ? dst_row[col] : -1; } - } - else { - // copy the result to dst without the padding + } else { if (col < ncols) { dst[row * ncols + col] = dst_row[col]; } } + //if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) { + // __syncthreads(); + // float max_val = x_row[dst_row[0]]; + // if (col < ncols) { + // dst[row * ncols + col] = col < min_experts || x_row[dst_row[col]] >= thresh_experts*max_val ? dst_row[col] : -1; + // } + //} + //else { + // // copy the result to dst without the padding + // if (col < ncols) { + // dst[row * ncols + col] = dst_row[col]; + // } + //} } static int next_power_of_2(int x) { @@ -94,9 +116,21 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co GGML_ASSERT(shared_mem <= ggml_cuda_info().devices[ggml_cuda_get_device()].smpb); if (order == GGML_SORT_ORDER_ASC) { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts); + if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) { + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, + {min_experts, thresh_experts}); + } else { + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, {}); + } + //k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts); } else if (order == GGML_SORT_ORDER_DESC) { - k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts); + if (min_experts >= 0 && min_experts < ncols && thresh_experts > 0) { + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, + {min_experts, thresh_experts}); + } else { + k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, {}); + } + //k_argsort_f32_i32<<>>(x, dst, ncols, ncols_pad, min_experts, thresh_experts); } else { GGML_ABORT("fatal error"); } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 04aa3e26c..68dec9608 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -422,7 +422,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY { - CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + if (src0->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else { + CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), 
cudaMemcpyDeviceToDevice, main_stream));
+        }
     }
 } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
     ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
@@ -473,6 +477,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
+        ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
         // This is needed for MLA with mla=2 when using q8_0 cache.
         transpose_q8_0(ctx, src0, src1);
@@ -498,7 +506,13 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
     if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) {
-        return nullptr;
+        // Prioritize CUDA graph compatibility over direct memory copy optimization.
+        // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs.
+        if (src0->type == GGML_TYPE_F32) {
+            return (void*) cpy_flt<cpy_1_flt<float, float>>;
+        } else {
+            return nullptr;
+        }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_flt<cpy_1_flt<float, float>>;
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
@@ -545,6 +559,10 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
         return (void*) cpy_flt<cpy_1_flt<half, float>>;
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         return (void*) cpy_flt<cpy_1_flt<nv_bfloat16, float>>;
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
+        return (void*) cpy_flt<cpy_1_flt<float, int32_t>>;
+    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
+        return (void*) cpy_flt<cpy_1_flt<int32_t, float>>;
     } else if (ggml_are_same_shape(src0, src1) && src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_Q8_0) {
         return (void *)transpose_q8_0;
     } else {
diff --git a/src/llama-build-context.cpp b/src/llama-build-context.cpp
index 02e06b07a..e1817bae0 100644
--- a/src/llama-build-context.cpp
+++ b/src/llama-build-context.cpp
@@ -820,18 +820,18 @@ llm_expert_gating_func_type gating_op,
         selection_probs = logits;
     }
 
-    if (lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) {
+    if (false && lctx.model.arch == LLM_ARCH_BAILINGMOE2 && n_tokens > 0) {
         auto& hparams = lctx.model.hparams;
         const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
 
         // organize experts into n_expert_groups
         ggml_tensor * selection_groups = ggml_view_2d(ctx, ggml_cont(ctx, ggml_transpose(ctx, selection_probs)), n_tokens * n_exp_per_group, hparams.n_expert_groups, n_tokens * n_exp_per_group * sizeof(float), 0); // [n_tokens * n_exp_per_group, n_expert_groups]
 #if 0
-        ggml_tensor * group_scores = ggml_top_k(ctx0, selection_groups, 2); // [2, n_expert_groups]
-        group_scores = ggml_get_rows(ctx0, ggml_reshape_3d(ctx0, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups]
+        ggml_tensor * group_scores = ggml_top_k(ctx, selection_groups, 2); // [2, n_expert_groups]
+        group_scores = ggml_get_rows(ctx, ggml_reshape_3d(ctx, selection_groups, 1, selection_groups->ne[0], selection_groups->ne[1]), group_scores); // [1, 2, n_expert_groups]
 
         // get top n_group_used expert groups
-        group_scores = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_reshape_2d(ctx0, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1]
+        group_scores = ggml_transpose(ctx, ggml_sum_rows(ctx, ggml_reshape_2d(ctx, group_scores, group_scores->ne[1], group_scores->ne[2]))); // [n_expert_groups, 1]
 #else
         // Replace top_k(2) with argmax due to backend limitations, ideally we should use something like argmax2 instead
         ggml_tensor * group_scores = ggml_reshape_2d(ctx, ggml_argmax(ctx, selection_groups), 1, selection_groups->ne[1]); // [1, n_expert_groups]
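For readers following the grouped routing above: the #else branch scores each expert group with ggml_argmax, i.e. by its single best expert, instead of the sum of its top-2 experts used by the disabled #if 0 path, since (as the comment notes) ggml_top_k is problematic on some backends here. The standalone C++ sketch below is illustrative only: it is not part of the patch, it handles a single token and ignores the transpose/view layout of the real graph, and the group/expert counts and probabilities are made up. It shows the intended selection semantics: score each group by its maximum expert probability, keep the best n_group_used groups, and mask the experts of all other groups before the usual top-k expert pick.

// Illustrative sketch of argmax-style expert-group selection for one token.
// Not part of the patch; all constants and values are made up.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <functional>
#include <vector>

int main() {
    const int n_expert        = 16; // routed experts
    const int n_expert_groups = 4;  // experts are split into equal-sized groups
    const int n_group_used    = 2;  // groups kept per token
    const int n_per_group     = n_expert / n_expert_groups;

    // router probabilities for one token (made-up values)
    std::vector<float> probs = {
        0.10f, 0.90f, 0.20f, 0.30f,   // group 0
        0.80f, 0.70f, 0.10f, 0.20f,   // group 1
        0.00f, 0.10f, 0.20f, 0.10f,   // group 2
        0.40f, 0.30f, 0.50f, 0.20f,   // group 3
    };

    // score each group by its best expert (the argmax stand-in for the top-2 sum)
    std::vector<std::pair<float, int>> group_scores;
    for (int g = 0; g < n_expert_groups; ++g) {
        float best = probs[g*n_per_group];
        for (int j = 1; j < n_per_group; ++j) {
            best = std::max(best, probs[g*n_per_group + j]);
        }
        group_scores.push_back({best, g});
    }

    // keep the n_group_used best groups, mask out the experts of the remaining groups
    std::sort(group_scores.begin(), group_scores.end(), std::greater<>());
    std::vector<bool> keep(n_expert_groups, false);
    for (int k = 0; k < n_group_used; ++k) {
        keep[group_scores[k].second] = true;
    }
    for (int e = 0; e < n_expert; ++e) {
        if (!keep[e / n_per_group]) {
            probs[e] = -INFINITY; // excluded from the subsequent top-k expert selection
        }
    }

    for (int e = 0; e < n_expert; ++e) {
        printf("expert %2d: %s\n", e, std::isinf(probs[e]) ? "masked" : "candidate for top-k");
    }
    return 0;
}

In the graph code the equivalent masking is intended to be expressed with tensor operations on selection_probs; the sketch only mirrors the selection logic, not the tensor layout.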