@@ -1288,6 +1288,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
12881288 default: type = LLM_TYPE_UNKNOWN;
12891289 }
12901290 } break;
1291+ case LLM_ARCH_SEED_OSS:
1292+ {
1293+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1294+ switch (hparams.n_layer) {
1295+ case 64: type = LLM_TYPE_36B; break;
1296+ }
1297+ } break;
12911298 case LLM_ARCH_OLMOE:
12921299 {
12931300 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -3967,6 +3974,42 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
39673974 layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
39683975 }
39693976 } break;
3977+ case LLM_ARCH_SEED_OSS:
3978+ {
3979+ const uint32_t num_heads = hparams.ssm_dt_rank;
3980+ const uint32_t intermediate_size = hparams.ssm_d_inner;
3981+ const uint32_t head_dim = intermediate_size / num_heads;
3982+ const int64_t n_qo_dim = n_head * head_dim;
3983+ const int64_t n_kv_dim = n_head_kv * head_dim;
3984+
3985+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
3986+
3987+ // output
3988+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
3989+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
3990+
3991+ for (int i = 0; i < n_layer; ++i) {
3992+ auto & layer = layers[i];
3993+
3994+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, 0);
3995+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, 0);
3996+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, 0);
3997+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0);
3998+
3999+ layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, TENSOR_NOT_REQUIRED);
4000+ layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED);
4001+ layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED);
4002+
4003+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
4004+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
4005+
4006+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
4007+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
4008+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
4009+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
4010+ }
4011+ } break;
4012+
39704013 case LLM_ARCH_OLMOE:
39714014 {
39724015 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -17924,6 +17967,140 @@ struct llm_build_lfm2 : public llm_graph_context {
1792417967 }
1792517968};
1792617969
17970+ struct llm_build_seed_oss : public llm_graph_context {
17971+ llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
17972+ const int64_t n_embd_head = hparams.n_embd_head_v;
17973+
17974+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
17975+ GGML_ASSERT(n_embd_head == hparams.n_rot);
17976+
17977+ ggml_tensor * cur;
17978+ ggml_tensor * inpL;
17979+
17980+ inpL = build_inp_embd(model.tok_embd);
17981+
17982+ // inp_pos - contains the positions
17983+ ggml_tensor * inp_pos = build_inp_pos();
17984+
17985+ auto * inp_attn = build_attn_inp_kv();
17986+
17987+ const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
17988+
17989+ ggml_tensor * inp_out_ids = build_inp_out_ids();
17990+
17991+ for (int il = 0; il < n_layer; ++il) {
17992+ ggml_tensor * inpSA = inpL;
17993+
17994+ // norm
17995+ cur = build_norm(inpL,
17996+ model.layers[il].attn_norm, NULL,
17997+ LLM_NORM_RMS, il);
17998+ cb(cur, "attn_norm", il);
17999+
18000+ // self-attention
18001+ {
18002+ // rope freq factors for llama3; may return nullptr for llama2 and other models
18003+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
18004+
18005+ // compute Q and K and RoPE them
18006+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
18007+ cb(Qcur, "Qcur", il);
18008+ if (model.layers[il].bq) {
18009+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
18010+ cb(Qcur, "Qcur", il);
18011+ }
18012+
18013+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
18014+ cb(Kcur, "Kcur", il);
18015+ if (model.layers[il].bk) {
18016+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
18017+ cb(Kcur, "Kcur", il);
18018+ }
18019+
18020+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
18021+ cb(Vcur, "Vcur", il);
18022+ if (model.layers[il].bv) {
18023+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
18024+ cb(Vcur, "Vcur", il);
18025+ }
18026+
18027+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
18028+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
18029+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
18030+
18031+ Qcur = ggml_rope_ext(
18032+ ctx0, Qcur, inp_pos, rope_factors,
18033+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
18034+ ext_factor, attn_factor, beta_fast, beta_slow
18035+ );
18036+
18037+ Kcur = ggml_rope_ext(
18038+ ctx0, Kcur, inp_pos, rope_factors,
18039+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
18040+ ext_factor, attn_factor, beta_fast, beta_slow
18041+ );
18042+
18043+ cb(Qcur, "Qcur", il);
18044+ cb(Kcur, "Kcur", il);
18045+ cb(Vcur, "Vcur", il);
18046+
18047+ cur = build_attn(inp_attn,
18048+ model.layers[il].wo, model.layers[il].bo,
18049+ Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
18050+ cb(cur, "attn_out", il);
18051+ }
18052+
18053+ if (il == n_layer - 1 && inp_out_ids) {
18054+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
18055+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
18056+ }
18057+
18058+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
18059+ cb(ffn_inp, "ffn_inp", il);
18060+
18061+ // feed-forward network
18062+ cur = build_norm(ffn_inp,
18063+ model.layers[il].ffn_norm, NULL,
18064+ LLM_NORM_RMS, il);
18065+ cb(cur, "ffn_norm", il);
18066+
18067+ cur = build_ffn(cur,
18068+ model.layers[il].ffn_up, NULL, NULL,
18069+ model.layers[il].ffn_gate, NULL, NULL,
18070+ model.layers[il].ffn_down, NULL, NULL,
18071+ NULL,
18072+ LLM_FFN_SILU, LLM_FFN_PAR, il);
18073+ cb(cur, "ffn_out", il);
18074+
18075+ cur = ggml_add(ctx0, cur, ffn_inp);
18076+ cb(cur, "ffn_out", il);
18077+
18078+ cur = build_cvec(cur, il);
18079+ cb(cur, "l_out", il);
18080+
18081+ // input for next layer
18082+ inpL = cur;
18083+ }
18084+
18085+ cur = inpL;
18086+
18087+ cur = build_norm(cur,
18088+ model.output_norm, NULL,
18089+ LLM_NORM_RMS, -1);
18090+
18091+ cb(cur, "result_norm", -1);
18092+ res->t_embd = cur;
18093+
18094+ // lm_head
18095+ cur = build_lora_mm(model.output, cur);
18096+
18097+ cb(cur, "result_output", -1);
18098+ res->t_logits = cur;
18099+
18100+ ggml_build_forward_expand(gf, cur);
18101+ }
18102+ };
18103+
1792718104template <bool iswa>
1792818105struct llm_build_smallthinker : public llm_graph_context{
1792918106 llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){
@@ -18462,6 +18639,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
1846218639 {
1846318640 llm = std::make_unique<llm_build_bailingmoe>(*this, params);
1846418641 } break;
18642+ case LLM_ARCH_SEED_OSS:
18643+ {
18644+ llm = std::make_unique<llm_build_seed_oss>(*this, params);
18645+ } break;
1846518646 case LLM_ARCH_DOTS1:
1846618647 {
1846718648 llm = std::make_unique<llm_build_dots1>(*this, params);
@@ -18520,6 +18701,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
1852018701 return llm->res->get_gf();
1852118702}
1852218703
18704+
1852318705//
1852418706// interface implementation
1852518707//
@@ -18689,6 +18871,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
1868918871 case LLM_ARCH_QWEN3MOE:
1869018872 case LLM_ARCH_OLMO2:
1869118873 case LLM_ARCH_OLMOE:
18874+ case LLM_ARCH_SEED_OSS:
1869218875 case LLM_ARCH_PHI2:
1869318876 case LLM_ARCH_PHI3:
1869418877 case LLM_ARCH_PHIMOE:
0 commit comments