@@ -8489,6 +8489,100 @@ ggml_cgraph* llm_build_context::build_minimaxm2() {
     return gf;
 }
 
+ggml_cgraph * llm_build_context::build_smollm3() {
+    ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    // GGML_ASSERT(n_embd_head == hparams.n_rot); this does not hold for minimax, where head_dim = 128 but n_rot = 64
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    // auto * inp_attn = build_attn_inp_kv();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+    ggml_tensor * KQ_mask     = build_inp_KQ_mask();
+
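+    // attention scale: fall back to 1/sqrt(head_dim) when f_attention_scale is not set (0.0f)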
+    const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
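+        // NoPE: every n_no_rope_layer_step-th layer skips rotary embeddings entirely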
+        const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
+
+        // norm
+        cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, cb, il);
+        cb(cur, "attn_norm", il);
+
+        // self-attention
+        {
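+            // Q/K/V projections (presumably using the fused wqkv/wqk or the separate wq/wk/wv weights, whichever the model provides), with optional Q/K norms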
+            auto [Qcur, Kcur, Vcur] = llm_build_mul_mat_qkv(gf, cur,
+                    model.layers[il].wqkv, model.layers[il].bqkv,
+                    model.layers[il].wqk,  model.layers[il].bqk,
+                    model.layers[il].wq,   model.layers[il].bq,
+                    model.layers[il].wk,   model.layers[il].bk,
+                    model.layers[il].wv,   model.layers[il].bv,
+                    model.layers[il].attn_q_norm, model.layers[il].attn_k_norm, 0, il);
+
+            if (use_rope) {
+                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow);
+                cb(Kcur, "Kcur", il);
+            }
+
+            cur = llm_build_kv(ctx0, lctx, kv_self, gf,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
+            cb(cur, "attn_out", il);
+        }
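+        // on the last layer, keep only the rows whose outputs were requested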
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // feed-forward network
+        cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il);
+        cb(cur, "ffn_norm", il);
+
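+        // SwiGLU FFN: SiLU-activated gate multiplied with the parallel up projection, then projected back down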
+        cur = llm_build_ffn(ctx0, lctx, cur,
+                model.layers[il].ffn_up,   NULL, NULL,
+                model.layers[il].ffn_gate, NULL, NULL,
+                model.layers[il].ffn_down, NULL, NULL,
+                NULL,
+                LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+        cb(cur, "ffn_out", il);
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+        cur = lctx.cvec.apply_to(ctx0, cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+    cur = inpL;
+
+    cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, NULL, LLM_NORM_RMS, cb, -1);
+    cb(cur, "result_norm", -1);
+
+    // lm_head
+    cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+    cb(cur, "result_output", -1);
+
+    ggml_build_forward_expand(gf, cur);
+
+    return gf;
+}
+
 ggml_cgraph * llm_build_context::llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
     llama_batch dummy;
     dummy.n_tokens = 0;
@@ -8839,6 +8933,10 @@ ggml_cgraph * llm_build_context::llama_build_graph(
             {
                 result = llm.build_minimaxm2();
             } break;
+        case LLM_ARCH_SMOLLM3:
+            {
+                result = llm.build_smollm3();
+            } break;
         default:
             GGML_ABORT("fatal error");
     }