@@ -1564,6 +1564,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
15641564 case LLM_ARCH_SMOLLM3:
15651565 {
15661566 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1567+ hparams.n_no_rope_layer_step = 4;
15671568
15681569 switch (hparams.n_layer) {
15691570 case 36: type = LLM_TYPE_3B; break;
@@ -14893,6 +14894,143 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
1489314894 }
1489414895};
1489514896
14897+ struct llm_build_smollm3 : public llm_graph_context {
14898+ llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
14899+ const int64_t n_embd_head = hparams.n_embd_head_v;
14900+
14901+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
14902+ GGML_ASSERT(n_embd_head == hparams.n_rot);
14903+
14904+ ggml_tensor * cur;
14905+ ggml_tensor * inpL;
14906+
14907+ inpL = build_inp_embd(model.tok_embd);
14908+
14909+ // inp_pos - contains the positions
14910+ ggml_tensor * inp_pos = build_inp_pos();
14911+
14912+ auto * inp_attn = build_attn_inp_kv_unified();
14913+
14914+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
14915+
14916+ ggml_tensor * inp_out_ids = build_inp_out_ids();
14917+
14918+ for (int il = 0; il < n_layer; ++il) {
14919+ ggml_tensor * inpSA = inpL;
14920+
14921+ const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
14922+
14923+ // norm
14924+ cur = build_norm(inpL,
14925+ model.layers[il].attn_norm, NULL,
14926+ LLM_NORM_RMS, il);
14927+ cb(cur, "attn_norm", il);
14928+
14929+ // self-attention
14930+ {
14931+ // compute Q and K and RoPE them
14932+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
14933+ cb(Qcur, "Qcur", il);
14934+ if (model.layers[il].bq) {
14935+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
14936+ cb(Qcur, "Qcur", il);
14937+ }
14938+
14939+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
14940+ cb(Kcur, "Kcur", il);
14941+ if (model.layers[il].bk) {
14942+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
14943+ cb(Kcur, "Kcur", il);
14944+ }
14945+
14946+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
14947+ cb(Vcur, "Vcur", il);
14948+ if (model.layers[il].bv) {
14949+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
14950+ cb(Vcur, "Vcur", il);
14951+ }
14952+
14953+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
14954+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
14955+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
14956+
14957+ if (use_rope) {
14958+ Qcur = ggml_rope_ext(
14959+ ctx0, Qcur, inp_pos, nullptr,
14960+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14961+ ext_factor, attn_factor, beta_fast, beta_slow
14962+ );
14963+
14964+ Kcur = ggml_rope_ext(
14965+ ctx0, Kcur, inp_pos, nullptr,
14966+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
14967+ ext_factor, attn_factor, beta_fast, beta_slow
14968+ );
14969+ }
14970+
14971+ cb(Qcur, "Qcur", il);
14972+ cb(Kcur, "Kcur", il);
14973+ cb(Vcur, "Vcur", il);
14974+
14975+ cur = build_attn(inp_attn, gf,
14976+ model.layers[il].wo, model.layers[il].bo,
14977+ Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
14978+ cb(cur, "attn_out", il);
14979+ }
14980+
14981+ if (il == n_layer - 1 && inp_out_ids) {
14982+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
14983+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
14984+ }
14985+
14986+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
14987+ cb(ffn_inp, "ffn_inp", il);
14988+
14989+ // feed-forward network
14990+ {
14991+ cur = build_norm(ffn_inp,
14992+ model.layers[il].ffn_norm, NULL,
14993+ LLM_NORM_RMS, il);
14994+ cb(cur, "ffn_norm", il);
14995+
14996+ cur = build_ffn(cur,
14997+ model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
14998+ model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
14999+ model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
15000+ NULL,
15001+ LLM_FFN_SILU, LLM_FFN_PAR, il);
15002+ cb(cur, "ffn_out", il);
15003+ }
15004+
15005+ cur = ggml_add(ctx0, cur, ffn_inp);
15006+ cb(cur, "ffn_out", il);
15007+
15008+ cur = build_cvec(cur, il);
15009+ cb(cur, "l_out", il);
15010+
15011+ // input for next layer
15012+ inpL = cur;
15013+ }
15014+
15015+ cur = inpL;
15016+
15017+ cur = build_norm(cur,
15018+ model.output_norm, NULL,
15019+ LLM_NORM_RMS, -1);
15020+
15021+ cb(cur, "result_norm", -1);
15022+ res->t_embd = cur;
15023+
15024+ // lm_head
15025+ cur = build_lora_mm(model.output, cur);
15026+
15027+ cb(cur, "result_output", -1);
15028+ res->t_logits = cur;
15029+
15030+ ggml_build_forward_expand(gf, cur);
15031+ }
15032+ };
15033+
1489615034llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
1489715035 llama_memory_i * res;
1489815036
@@ -14996,7 +15134,6 @@ llm_graph_result_ptr llama_model::build_graph(
1499615134 llm = std::make_unique<llm_build_llama>(*this, params, gf);
1499715135 } break;
1499815136 case LLM_ARCH_LLAMA4:
14999- case LLM_ARCH_SMOLLM3:
1500015137 {
1500115138 llm = std::make_unique<llm_build_llama_iswa>(*this, params, gf);
1500215139 } break;
@@ -15278,6 +15415,10 @@ llm_graph_result_ptr llama_model::build_graph(
1527815415 {
1527915416 llm = std::make_unique<llm_build_hunyuan_moe>(*this, params, gf);
1528015417 } break;
15418+ case LLM_ARCH_SMOLLM3:
15419+ {
15420+ llm = std::make_unique<llm_build_smollm3>(*this, params, gf);
15421+ } break;
1528115422 default:
1528215423 GGML_ABORT("fatal error");
1528315424 }
0 commit comments