Skip to content

Commit 2ce7f2c

Browse files
author
Didik Irawan
committed
add model architecture torconsmoe
1 parent e00f3fd commit 2ce7f2c

File tree

5 files changed

+251
-0
lines changed

5 files changed

+251
-0
lines changed

convert_hf_to_gguf.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3749,6 +3749,24 @@ def set_vocab(self):
37493749
super().set_vocab()
37503750

37513751

3752+
@ModelBase.register("TorconsMoeForCausalLM")
3753+
class TorconsMoeModel(Qwen2MoeModel):
3754+
model_arch = gguf.MODEL_ARCH.TORCONSMOE
3755+
3756+
def __init__(self, *args, **kwargs):
3757+
super().__init__(*args, **kwargs)
3758+
hparams = ModelBase.load_hparams(self.dir_model, False)
3759+
self.origin_hf_arch = hparams.get('architectures', [None])[0]
3760+
3761+
def set_vocab(self):
3762+
# deal with intern-s1
3763+
if self.origin_hf_arch == 'InternS1ForConditionalGeneration':
3764+
self._set_vocab_interns1()
3765+
return
3766+
3767+
super().set_vocab()
3768+
3769+
37523770
@ModelBase.register("GPT2LMHeadModel")
37533771
class GPT2Model(TextModel):
37543772
model_arch = gguf.MODEL_ARCH.GPT2

gguf-py/gguf/constants.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ class MODEL_ARCH(IntEnum):
335335
QWEN2VL = auto()
336336
QWEN3 = auto()
337337
QWEN3MOE = auto()
338+
TORCONSMOE = auto()
338339
PHI2 = auto()
339340
PHI3 = auto()
340341
PHIMOE = auto()
@@ -671,6 +672,7 @@ class MODEL_TENSOR(IntEnum):
671672
MODEL_ARCH.QWEN2VL: "qwen2vl",
672673
MODEL_ARCH.QWEN3: "qwen3",
673674
MODEL_ARCH.QWEN3MOE: "qwen3moe",
675+
MODEL_ARCH.TORCONSMOE: "torconsmoe",
674676
MODEL_ARCH.PHI2: "phi2",
675677
MODEL_ARCH.PHI3: "phi3",
676678
MODEL_ARCH.PHIMOE: "phimoe",
@@ -1462,6 +1464,23 @@ class MODEL_TENSOR(IntEnum):
14621464
MODEL_TENSOR.FFN_DOWN_EXP,
14631465
MODEL_TENSOR.FFN_UP_EXP,
14641466
],
1467+
MODEL_ARCH.TORCONSMOE: [
1468+
MODEL_TENSOR.TOKEN_EMBD,
1469+
MODEL_TENSOR.OUTPUT_NORM,
1470+
MODEL_TENSOR.OUTPUT,
1471+
MODEL_TENSOR.ATTN_NORM,
1472+
MODEL_TENSOR.ATTN_Q,
1473+
MODEL_TENSOR.ATTN_Q_NORM,
1474+
MODEL_TENSOR.ATTN_K,
1475+
MODEL_TENSOR.ATTN_K_NORM,
1476+
MODEL_TENSOR.ATTN_V,
1477+
MODEL_TENSOR.ATTN_OUT,
1478+
MODEL_TENSOR.FFN_NORM,
1479+
MODEL_TENSOR.FFN_GATE_INP,
1480+
MODEL_TENSOR.FFN_GATE_EXP,
1481+
MODEL_TENSOR.FFN_DOWN_EXP,
1482+
MODEL_TENSOR.FFN_UP_EXP,
1483+
],
14651484
MODEL_ARCH.PLAMO: [
14661485
MODEL_TENSOR.TOKEN_EMBD,
14671486
MODEL_TENSOR.OUTPUT_NORM,

src/llama-arch.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
3131
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
3232
{ LLM_ARCH_QWEN3, "qwen3" },
3333
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
34+
{ LLM_ARCH_TORCONSMOE, "torconsmoe" },
3435
{ LLM_ARCH_PHI2, "phi2" },
3536
{ LLM_ARCH_PHI3, "phi3" },
3637
{ LLM_ARCH_PHIMOE, "phimoe" },
@@ -754,6 +755,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
754755
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
755756
},
756757
},
758+
{
759+
LLM_ARCH_TORCONSMOE,
760+
{
761+
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
762+
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
763+
{ LLM_TENSOR_OUTPUT, "output" },
764+
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
765+
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
766+
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
767+
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
768+
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
769+
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
770+
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
771+
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
772+
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
773+
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
774+
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
775+
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
776+
},
777+
},
757778
{
758779
LLM_ARCH_PHI2,
759780
{

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ enum llm_arch {
3535
LLM_ARCH_QWEN2VL,
3636
LLM_ARCH_QWEN3,
3737
LLM_ARCH_QWEN3MOE,
38+
LLM_ARCH_TORCONSMOE,
3839
LLM_ARCH_PHI2,
3940
LLM_ARCH_PHI3,
4041
LLM_ARCH_PHIMOE,

src/llama-model.cpp

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -999,6 +999,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
999999
default: type = LLM_TYPE_UNKNOWN;
10001000
}
10011001
} break;
1002+
case LLM_ARCH_TORCONSMOE:
1003+
{
1004+
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
1005+
1006+
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1007+
switch (hparams.n_layer) {
1008+
case 48: type = LLM_TYPE_30B_A3B; break;
1009+
case 94: type = LLM_TYPE_235B_A22B; break;
1010+
default: type = LLM_TYPE_UNKNOWN;
1011+
}
1012+
} break;
10021013
case LLM_ARCH_PHI2:
10031014
{
10041015
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -3223,6 +3234,50 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
32233234
// MoE branch
32243235
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
32253236

3237+
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
3238+
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
3239+
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
3240+
}
3241+
} break;
3242+
case LLM_ARCH_TORCONSMOE:
3243+
{
3244+
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
3245+
3246+
// output
3247+
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
3248+
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
3249+
// if output is NULL, init from the input tok embed
3250+
if (output == NULL) {
3251+
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
3252+
}
3253+
3254+
for (int i = 0; i < n_layer; ++i) {
3255+
auto & layer = layers[i];
3256+
3257+
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
3258+
3259+
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
3260+
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
3261+
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
3262+
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
3263+
3264+
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
3265+
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
3266+
3267+
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
3268+
3269+
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
3270+
3271+
if (n_expert == 0) {
3272+
throw std::runtime_error("n_expert must be > 0 for TORCONSMOE");
3273+
}
3274+
if (n_expert_used == 0) {
3275+
throw std::runtime_error("n_expert_used must be > 0 for TORCONSMOE");
3276+
}
3277+
3278+
// MoE branch
3279+
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
3280+
32263281
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
32273282
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
32283283
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
@@ -6143,6 +6198,10 @@ void llama_model::print_info() const {
61436198
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
61446199
}
61456200

6201+
if (arch == LLM_ARCH_TORCONSMOE || arch == LLM_ARCH_OPENAI_MOE) {
6202+
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
6203+
}
6204+
61466205
if (arch == LLM_ARCH_MINICPM ||
61476206
arch == LLM_ARCH_GRANITE ||
61486207
arch == LLM_ARCH_GRANITE_MOE ||
@@ -9276,6 +9335,134 @@ struct llm_build_qwen3moe : public llm_graph_context {
92769335
}
92779336
};
92789337

9338+
struct llm_build_torconsmoe : public llm_graph_context {
9339+
llm_build_torconsmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
9340+
const int64_t n_embd_head = hparams.n_embd_head_v;
9341+
9342+
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
9343+
GGML_ASSERT(n_embd_head == hparams.n_rot);
9344+
9345+
ggml_tensor * cur;
9346+
ggml_tensor * inpL;
9347+
9348+
inpL = build_inp_embd(model.tok_embd);
9349+
9350+
// inp_pos - contains the positions
9351+
ggml_tensor * inp_pos = build_inp_pos();
9352+
9353+
auto * inp_attn = build_attn_inp_kv();
9354+
9355+
ggml_tensor * inp_out_ids = build_inp_out_ids();
9356+
9357+
for (int il = 0; il < n_layer; ++il) {
9358+
ggml_tensor * inpSA = inpL;
9359+
9360+
// norm
9361+
cur = build_norm(inpL,
9362+
model.layers[il].attn_norm, NULL,
9363+
LLM_NORM_RMS, il);
9364+
cb(cur, "attn_norm", il);
9365+
9366+
// self_attention
9367+
{
9368+
// compute Q and K and RoPE them
9369+
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
9370+
cb(Qcur, "Qcur", il);
9371+
9372+
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
9373+
cb(Kcur, "Kcur", il);
9374+
9375+
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
9376+
cb(Vcur, "Vcur", il);
9377+
9378+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
9379+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
9380+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
9381+
9382+
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
9383+
cb(Qcur, "Qcur_normed", il);
9384+
9385+
Qcur = ggml_rope_ext(
9386+
ctx0, Qcur, inp_pos, nullptr,
9387+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
9388+
ext_factor, attn_factor, beta_fast, beta_slow
9389+
);
9390+
9391+
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
9392+
cb(Kcur, "Kcur_normed", il);
9393+
9394+
Kcur = ggml_rope_ext(
9395+
ctx0, Kcur, inp_pos, nullptr,
9396+
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
9397+
ext_factor, attn_factor, beta_fast, beta_slow
9398+
);
9399+
9400+
cb(Qcur, "Qcur", il);
9401+
cb(Kcur, "Kcur", il);
9402+
cb(Vcur, "Vcur", il);
9403+
9404+
cur = build_attn(inp_attn,
9405+
model.layers[il].wo, model.layers[il].bo,
9406+
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
9407+
}
9408+
9409+
if (il == n_layer - 1 && inp_out_ids) {
9410+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
9411+
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
9412+
}
9413+
9414+
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
9415+
cb(ffn_inp, "ffn_inp", il);
9416+
9417+
// MoE branch
9418+
cur = build_norm(ffn_inp,
9419+
model.layers[il].ffn_norm, NULL,
9420+
LLM_NORM_RMS, il);
9421+
cb(cur, "ffn_norm", il);
9422+
9423+
ggml_tensor * moe_out =
9424+
build_moe_ffn(cur,
9425+
model.layers[il].ffn_gate_inp,
9426+
model.layers[il].ffn_up_exps,
9427+
model.layers[il].ffn_gate_exps,
9428+
model.layers[il].ffn_down_exps,
9429+
nullptr,
9430+
n_expert, n_expert_used,
9431+
LLM_FFN_SILU, true,
9432+
false, 0.0,
9433+
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
9434+
il);
9435+
cb(moe_out, "ffn_moe_out", il);
9436+
cur = moe_out;
9437+
9438+
cur = ggml_add(ctx0, cur, ffn_inp);
9439+
9440+
cur = build_cvec(cur, il);
9441+
cb(cur, "l_out", il);
9442+
9443+
// input for next layer
9444+
inpL = cur;
9445+
}
9446+
9447+
cur = inpL;
9448+
9449+
cur = build_norm(cur,
9450+
model.output_norm, NULL,
9451+
LLM_NORM_RMS, -1);
9452+
9453+
cb(cur, "result_norm", -1);
9454+
res->t_embd = cur;
9455+
9456+
// lm_head
9457+
cur = build_lora_mm(model.output, cur);
9458+
9459+
cb(cur, "result_output", -1);
9460+
res->t_logits = cur;
9461+
9462+
ggml_build_forward_expand(gf, cur);
9463+
}
9464+
};
9465+
92799466
struct llm_build_phi2 : public llm_graph_context {
92809467
llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
92819468
const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -19098,6 +19285,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
1909819285
{
1909919286
llm = std::make_unique<llm_build_qwen3moe>(*this, params);
1910019287
} break;
19288+
case LLM_ARCH_TORCONSMOE:
19289+
{
19290+
llm = std::make_unique<llm_build_torconsmoe>(*this, params);
19291+
} break;
1910119292
case LLM_ARCH_PHI2:
1910219293
{
1910319294
llm = std::make_unique<llm_build_phi2>(*this, params);
@@ -19552,6 +19743,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
1955219743
case LLM_ARCH_QWEN2MOE:
1955319744
case LLM_ARCH_QWEN3:
1955419745
case LLM_ARCH_QWEN3MOE:
19746+
case LLM_ARCH_TORCONSMOE:
1955519747
case LLM_ARCH_LLADA_MOE:
1955619748
case LLM_ARCH_OLMO2:
1955719749
case LLM_ARCH_OLMOE:

0 commit comments

Comments
 (0)