@@ -1551,17 +1551,25 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_SMALLTHINKER:
             {
-                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
-                hparams.n_swa = 4096;
-                hparams.set_dense_start_swa_pattern(4);
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
 
-                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+                if (found_swa && hparams.n_swa > 0) {
+                    hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+                    hparams.n_swa = 4096;
+                    hparams.set_dense_start_swa_pattern(4);
+                } else {
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+                    hparams.n_no_rope_layer_step = hparams.n_layer;
+                }
 
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false);
 
                 switch (hparams.n_layer) {
-                    default:
-                        type = LLM_TYPE_UNKNOWN;
+                    case 32: type = LLM_TYPE_4B;  break;
+                    case 52: type = LLM_TYPE_20B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
         default: throw std::runtime_error("unsupported model architecture");
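// Aside (not part of the commit): a minimal sketch of what the hunk above sets up,
// under the assumption that a "dense start" pattern of 4 means layers 0, 4, 8, ...
// keep full attention while the remaining layers use the 4096-token sliding window.
// The helper below is hypothetical and only illustrates that interpretation.

#include <cstdint>
#include <vector>

// returns true for layers that would use the sliding window under the assumed pattern
static std::vector<bool> dense_start_swa_layers(uint32_t n_layer, uint32_t n_pattern) {
    std::vector<bool> is_swa(n_layer, false);
    for (uint32_t il = 0; il < n_layer; ++il) {
        is_swa[il] = n_pattern > 0 && (il % n_pattern) != 0; // every n_pattern-th layer stays dense
    }
    return is_swa;
}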
@@ -4854,6 +4862,7 @@ void llama_model::print_info() const {
 
     if (arch == LLM_ARCH_SMALLTHINKER) {
         LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
     }
 
     vocab.print_info();
@@ -14736,7 +14745,12 @@ struct llm_build_smallthinker : public llm_graph_context{
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_kv_unified_iswa();
+        llm_graph_input_i * inp_attn = nullptr;
+        if (hparams.is_swa_any()) {
+            inp_attn = build_attn_inp_kv_unified_iswa();
+        } else {
+            inp_attn = build_attn_inp_kv_unified();
+        }
 
         for (int il = 0; il < n_layer; ++il) {
             ggml_tensor * inpSA = inpL;
@@ -14747,7 +14761,11 @@ struct llm_build_smallthinker : public llm_graph_context{
                 ggml_tensor * logits = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens]
                 cb(logits, "ffn_moe_logits", il);
 
-                probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
+                if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) {
+                    probs = ggml_soft_max(ctx0, logits); // [n_expert, n_tokens]
+                } else {
+                    probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
+                }
                 cb(probs, "ffn_moe_probs", il);
             }
 
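// Aside (not part of the commit): the branch above picks how the router logits become
// expert weights. ggml_soft_max normalizes along the expert dimension so each token's
// weights sum to 1, while ggml_sigmoid scores every expert independently in (0, 1).
// A scalar sketch of the same choice, for intuition only:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

static std::vector<float> gate_probs(const std::vector<float> & logits, bool use_softmax) {
    std::vector<float> p(logits.size());
    if (use_softmax) {
        float max_l = logits[0];
        for (float l : logits) { max_l = std::max(max_l, l); } // subtract max for numerical stability
        float sum = 0.0f;
        for (size_t i = 0; i < logits.size(); ++i) { p[i] = std::exp(logits[i] - max_l); sum += p[i]; }
        for (float & v : p) { v /= sum; }                       // normalized across experts
    } else {
        for (size_t i = 0; i < logits.size(); ++i) { p[i] = 1.0f / (1.0f + std::exp(-logits[i])); }
    }
    return p;
}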
@@ -14782,15 +14800,21 @@ struct llm_build_smallthinker : public llm_graph_context{
                 cb(Qcur, "Qcur", il);
                 cb(Kcur, "Kcur", il);
 
-                cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur,
-                        nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+                if (hparams.is_swa_any()) {
+                    cur = build_attn(static_cast<llm_graph_input_attn_kv_unified_iswa *>(inp_attn), gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur,
+                            nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+                } else {
+                    cur = build_attn(static_cast<llm_graph_input_attn_kv_unified *>(inp_attn), gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur,
+                            nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+                }
             }
 
             if (il == n_layer - 1) {
                 // skip computing output for unused tokens
                 ggml_tensor * inp_out_ids = build_inp_out_ids();
                 cur = ggml_get_rows(ctx0, cur, inp_out_ids);
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                if (probs != nullptr) { probs = ggml_get_rows(ctx0, probs, inp_out_ids); }
             }
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);