
Commit b1afcab

pwilkin and CISC authored
model : add support for Seed-OSS (ggml-org#15490)
* First draft
* Fix linter errors
* Added missing sinks nullptr
* Don't forget the llama-arch!
* We're through to the generation stage.
* Fix post-attention norm
* Apply suggestions from code review
  Co-authored-by: Sigbjørn Skjæret <[email protected]>
* Fix RoPE type
* Fix tensor name and reorder llm_types
* Update gguf-py/gguf/constants.py
  Remove nonexistent FFN_POST_NORM tensor
  Co-authored-by: Sigbjørn Skjæret <[email protected]>
* Update src/llama-model.h
  Co-authored-by: Sigbjørn Skjæret <[email protected]>
* Add basic chat template
* Add chat template tests
* Remake chat template test
* Apply suggestions from code review
  Co-authored-by: Sigbjørn Skjæret <[email protected]>
* Update src/llama-chat.cpp
  Co-authored-by: Sigbjørn Skjæret <[email protected]>
* Reorder llm type descriptions
* Update src/llama-model.cpp
  Co-authored-by: Sigbjørn Skjæret <[email protected]>

---------

Co-authored-by: Sigbjørn Skjæret <[email protected]>
1 parent 9ef5369 commit b1afcab

File tree

convert_hf_to_gguf.py
gguf-py/gguf/constants.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-chat.cpp
src/llama-chat.h
src/llama-model.cpp
src/llama-model.h
tests/test-chat-template.cpp

9 files changed: +244 -0 lines changed

convert_hf_to_gguf.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -5854,6 +5854,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("SeedOssForCausalLM")
+class SeedOssModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.SEED_OSS
+
+
 @ModelBase.register("Olmo2ForCausalLM")
 class Olmo2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.OLMO2
```

gguf-py/gguf/constants.py

Lines changed: 16 additions & 0 deletions
```diff
@@ -385,6 +385,7 @@ class MODEL_ARCH(IntEnum):
     DREAM = auto()
     SMALLTHINKER = auto()
     LLADA = auto()
+    SEED_OSS = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
```

```diff
@@ -717,6 +718,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DREAM: "dream",
     MODEL_ARCH.SMALLTHINKER: "smallthinker",
     MODEL_ARCH.LLADA: "llada",
+    MODEL_ARCH.SEED_OSS: "seed_oss",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
```

```diff
@@ -1973,6 +1975,20 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.SEED_OSS: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_POST_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+    ],
     MODEL_ARCH.OLMOE: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
```

src/llama-arch.cpp

Lines changed: 18 additions & 0 deletions
```diff
@@ -93,6 +93,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DREAM, "dream" },
     { LLM_ARCH_SMALLTHINKER, "smallthinker" },
     { LLM_ARCH_LLADA, "llada" },
+    { LLM_ARCH_SEED_OSS, "seed_oss" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
```

```diff
@@ -2068,6 +2069,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_SEED_OSS,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
```
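The `blk.%d.*` strings in the table above are per-layer name patterns: at load time they are expanded with the layer index and a suffix such as `.weight` to produce the concrete GGUF tensor names requested by the `tn(...)` calls in src/llama-model.cpp below. A minimal sketch of that expansion, using a hypothetical `tensor_name` helper rather than llama.cpp's actual `LLM_TN` machinery:

```cpp
#include <cstdio>
#include <string>

// Hypothetical helper for illustration only: expand a per-layer pattern such
// as "blk.%d.attn_q" into a concrete GGUF tensor name like "blk.3.attn_q.weight".
static std::string tensor_name(const char * pattern, int layer, const char * suffix) {
    char buf[128];
    std::snprintf(buf, sizeof(buf), pattern, layer);
    return std::string(buf) + "." + suffix;
}

int main() {
    // Prints "blk.3.attn_q.weight" for layer 3 of the attention Q projection.
    std::printf("%s\n", tensor_name("blk.%d.attn_q", 3, "weight").c_str());
    return 0;
}
```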

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -97,6 +97,7 @@ enum llm_arch {
     LLM_ARCH_DREAM,
     LLM_ARCH_SMALLTHINKER,
     LLM_ARCH_LLADA,
+    LLM_ARCH_SEED_OSS,
     LLM_ARCH_UNKNOWN,
 };
 
```

src/llama-chat.cpp

Lines changed: 11 additions & 0 deletions
```diff
@@ -69,6 +69,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
     { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
     { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
+    { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
```

```diff
@@ -201,6 +202,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE;
     } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
         return LLM_CHAT_TEMPLATE_KIMI_K2;
+    } else if (tmpl_contains("<seed:bos>")) {
+        return LLM_CHAT_TEMPLATE_SEED_OSS;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
```

```diff
@@ -752,6 +755,14 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<|im_assistant|>assistant<|im_middle|>";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_SEED_OSS) {
+        for (auto message: chat) {
+            std::string role(message->role);
+            ss << "<seed:bos>" << role << "\n" << (role == "assistant" ? trim(message->content) : message->content) << "<seed:eos>";
+        }
+        if (add_ass) {
+            ss << "<seed:bos>assistant\n";
+        }
     } else {
         // template not supported
         return -1;
```
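To make the rendered wire format concrete, the sketch below mirrors the Seed-OSS branch added above outside of llama.cpp; the `msg` struct and `trim` helper are simplified stand-ins for the library's internals, assumed here for illustration. Each turn becomes `<seed:bos>{role}\n{content}<seed:eos>`, assistant content is trimmed, and `add_ass` opens a fresh assistant turn for generation:

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Simplified stand-in for llama_chat_message.
struct msg { std::string role, content; };

// Simplified stand-in for llama.cpp's trim helper.
static std::string trim(const std::string & s) {
    const auto b = s.find_first_not_of(" \t\r\n");
    if (b == std::string::npos) return "";
    const auto e = s.find_last_not_of(" \t\r\n");
    return s.substr(b, e - b + 1);
}

// Mirrors the LLM_CHAT_TEMPLATE_SEED_OSS branch of llm_chat_apply_template.
static std::string render_seed_oss(const std::vector<msg> & chat, bool add_ass) {
    std::ostringstream ss;
    for (const auto & m : chat) {
        ss << "<seed:bos>" << m.role << "\n"
           << (m.role == "assistant" ? trim(m.content) : m.content)
           << "<seed:eos>";
    }
    if (add_ass) {
        ss << "<seed:bos>assistant\n"; // open the turn the model should complete
    }
    return ss.str();
}

int main() {
    std::cout << render_seed_oss({ { "user", "Hello" }, { "assistant", " Hi there " } }, true) << "\n";
    // -> <seed:bos>user\nHello<seed:eos><seed:bos>assistant\nHi there<seed:eos><seed:bos>assistant\n
}
```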

src/llama-chat.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -49,6 +49,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_OPENAI_MOE,
     LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
     LLM_CHAT_TEMPLATE_KIMI_K2,
+    LLM_CHAT_TEMPLATE_SEED_OSS,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
```

src/llama-model.cpp

Lines changed: 183 additions & 0 deletions
```diff
@@ -83,6 +83,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_32B: return "32B";
         case LLM_TYPE_34B: return "34B";
         case LLM_TYPE_35B: return "35B";
+        case LLM_TYPE_36B: return "36B";
         case LLM_TYPE_40B: return "40B";
         case LLM_TYPE_65B: return "65B";
         case LLM_TYPE_70B: return "70B";
```

```diff
@@ -1288,6 +1289,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_SEED_OSS:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 64: type = LLM_TYPE_36B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_OLMOE:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
```

```diff
@@ -3967,6 +3976,43 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
                 }
             } break;
+        case LLM_ARCH_SEED_OSS:
+            {
+                const uint32_t head_dim = hparams.n_embd_head_k;
+                const int64_t n_qo_dim = n_head * head_dim;
+                const int64_t n_kv_dim = n_head_kv * head_dim;
+
+                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                // output
+                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                output      = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                // if output is NULL, init from the input tok embed
+                if (output == NULL) {
+                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                }
+
+                for (int i = 0; i < n_layer; ++i) {
+                    auto & layer = layers[i];
+
+                    layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, 0);
+                    layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, 0);
+                    layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, 0);
+                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0);
+
+                    layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, TENSOR_NOT_REQUIRED);
+                    layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED);
+                    layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED);
+
+                    layer.attn_norm      = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                    layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+                    layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                    layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff}, 0);
+                    layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+                }
+            } break;
+
         case LLM_ARCH_OLMOE:
             {
                 tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
```

```diff
@@ -17934,6 +17980,137 @@ struct llm_build_lfm2 : public llm_graph_context {
     }
 };
 
+struct llm_build_seed_oss : public llm_graph_context {
+    llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0, cur,   inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp,
+                    model.layers[il].attn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_post_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    model.layers[il].ffn_gate, NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 template <bool iswa>
 struct llm_build_smallthinker : public llm_graph_context{
     llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){
```

```diff
@@ -18472,6 +18649,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_bailingmoe>(*this, params);
             } break;
+        case LLM_ARCH_SEED_OSS:
+            {
+                llm = std::make_unique<llm_build_seed_oss>(*this, params);
+            } break;
         case LLM_ARCH_DOTS1:
             {
                 llm = std::make_unique<llm_build_dots1>(*this, params);
```

```diff
@@ -18530,6 +18711,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
     return llm->res->get_gf();
 }
 
+
 //
 // interface implementation
 //
```

```diff
@@ -18724,6 +18906,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_LFM2:
         case LLM_ARCH_SMALLTHINKER:
         case LLM_ARCH_GLM4_MOE:
+        case LLM_ARCH_SEED_OSS:
            return LLAMA_ROPE_TYPE_NEOX;
 
        case LLM_ARCH_QWEN2VL:
```
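One detail worth noting in the graph builder above: when the model metadata carries no explicit attention scale (`hparams.f_attention_scale == 0.0f`), the code falls back to the standard `1/sqrt(head_dim)` factor of scaled dot-product attention. A standalone illustration of that default; the head size of 128 here is a hypothetical example value, not one read from Seed-OSS metadata:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    // Mirrors: kq_scale = f_attention_scale == 0.0f ? 1/sqrt(n_embd_head) : f_attention_scale
    const float f_attention_scale = 0.0f; // 0.0f means "not set in the GGUF metadata"
    const int   n_embd_head       = 128;  // hypothetical head dimension, for illustration only

    const float kq_scale = f_attention_scale == 0.0f
        ? 1.0f / sqrtf((float) n_embd_head)
        : f_attention_scale;

    std::printf("kq_scale = %f\n", kq_scale); // prints kq_scale = 0.088388
    return 0;
}
```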

src/llama-model.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -76,6 +76,7 @@ enum llm_type {
     LLM_TYPE_32B,
     LLM_TYPE_34B,
     LLM_TYPE_35B,
+    LLM_TYPE_36B,
     LLM_TYPE_40B,
     LLM_TYPE_65B,
     LLM_TYPE_70B,
```

tests/test-chat-template.cpp

Lines changed: 8 additions & 0 deletions
```diff
@@ -290,6 +290,14 @@ int main(void) {
             /* .bos_token= */ "",
             /* .eos_token= */ "",
         },
+        {
+            /* .name= */ "ByteDance-Seed/Seed-OSS-36B-Instruct",
+            /* .template_str */ "{# <seed:bos> #}{%- for message in messages %}{%- if message.role in [\"user\", \"system\"] %}{{ bos_token + message.role + \"\\n\" + message.content + eos_token }}{%- elif message.role == \"assistant\" %}{{ bos_token + message.role }}{%- if message.content is defined and message.content is string and message.content|trim|length > 0 %}{{ \"\\n\" + message.content|trim + eos_token }}{%- endif %}{%- else %}{{ bos_token + message.role + \"\\n\" + message.content + eos_token }}{%- endif %}{%- endfor %}{%- if add_generation_prompt %}{{ bos_token + \"assistant\\n\" }}{%- endif %}",
+            /* .expected_output= */ "<seed:bos>system\nYou are a helpful assistant<seed:eos><seed:bos>user\nHello<seed:eos><seed:bos>assistant\nHi there<seed:eos><seed:bos>user\nWho are you<seed:eos><seed:bos>assistant\nI am an assistant<seed:eos><seed:bos>user\nAnother question<seed:eos><seed:bos>assistant\n",
+            /* .expected_output_jinja= */ "<seed:bos>system\nYou are a helpful assistant<seed:eos><seed:bos>user\nHello<seed:eos><seed:bos>assistant\nHi there<seed:eos><seed:bos>user\nWho are you<seed:eos><seed:bos>assistant\nI am an assistant<seed:eos><seed:bos>user\nAnother question<seed:eos><seed:bos>assistant\n",
+            /* .bos_token= */ "<seed:bos>",
+            /* .eos_token= */ "<seed:eos>",
+        }
     };
     std::vector<char> formatted_chat(1024);
     int32_t res;
```
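For completeness, here is a hedged sketch of exercising the new template through the public llama.h entry point by its built-in name, the same path the test above drives; it assumes the current `llama_chat_apply_template(tmpl, chat, n_msg, add_ass, buf, length)` signature, so check llama.h for your revision:

```cpp
#include <cstdio>
#include <vector>

#include "llama.h"

int main() {
    // Roles and contents mirror the conversation used by test-chat-template.cpp.
    std::vector<llama_chat_message> chat = {
        { "system",    "You are a helpful assistant" },
        { "user",      "Hello"                       },
        { "assistant", "Hi there"                    },
    };

    std::vector<char> buf(1024);
    const int32_t n = llama_chat_apply_template("seed_oss", chat.data(), chat.size(),
                                                /* add_ass = */ true, buf.data(), (int32_t) buf.size());
    if (n < 0) {
        std::fprintf(stderr, "template not supported\n");
        return 1;
    }

    // Expected prefix: <seed:bos>system\nYou are a helpful assistant<seed:eos>...
    std::printf("%.*s\n", n, buf.data());
    return 0;
}
```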
