5 changes: 5 additions & 0 deletions convert_hf_to_gguf.py
@@ -5854,6 +5854,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [(self.map_tensor_name(name), data_torch)]


@ModelBase.register("SeedOssForCausalLM")
class SeedOssModel(TextModel):
model_arch = gguf.MODEL_ARCH.SEED_OSS


@ModelBase.register("Olmo2ForCausalLM")
class Olmo2Model(TextModel):
model_arch = gguf.MODEL_ARCH.OLMO2
16 changes: 16 additions & 0 deletions gguf-py/gguf/constants.py
@@ -385,6 +385,7 @@ class MODEL_ARCH(IntEnum):
DREAM = auto()
SMALLTHINKER = auto()
LLADA = auto()
SEED_OSS = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
@@ -717,6 +718,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.DREAM: "dream",
MODEL_ARCH.SMALLTHINKER: "smallthinker",
MODEL_ARCH.LLADA: "llada",
MODEL_ARCH.SEED_OSS: "seed_oss",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -1973,6 +1975,20 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.SEED_OSS: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
],
MODEL_ARCH.OLMOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
18 changes: 18 additions & 0 deletions src/llama-arch.cpp
@@ -93,6 +93,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DREAM, "dream" },
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
{ LLM_ARCH_LLADA, "llada" },
{ LLM_ARCH_SEED_OSS, "seed_oss" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

@@ -2068,6 +2069,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_SEED_OSS,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_UNKNOWN,
{
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -97,6 +97,7 @@ enum llm_arch {
LLM_ARCH_DREAM,
LLM_ARCH_SMALLTHINKER,
LLM_ARCH_LLADA,
LLM_ARCH_SEED_OSS,
LLM_ARCH_UNKNOWN,
};

11 changes: 11 additions & 0 deletions src/llama-chat.cpp
@@ -69,6 +69,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
{ "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
};

llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -201,6 +202,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE;
} else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
return LLM_CHAT_TEMPLATE_KIMI_K2;
} else if (tmpl_contains("<seed:bos>")) {
return LLM_CHAT_TEMPLATE_SEED_OSS;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@@ -752,6 +755,14 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "<|im_assistant|>assistant<|im_middle|>";
}
} else if (tmpl == LLM_CHAT_TEMPLATE_SEED_OSS) {
for (auto message: chat) {
std::string role(message->role);
ss << "<seed:bos>" << role << "\n" << (role == "assistant" ? trim(message->content) : message->content) << "<seed:eos>";
}
if (add_ass) {
ss << "<seed:bos>assistant\n";
}
} else {
// template not supported
return -1;
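As a quick illustration (not part of this diff), the new built-in "seed_oss" template can be exercised through the public llama_chat_apply_template API. The snippet below is a minimal sketch: it assumes the current llama.h signature and an oversized output buffer, and the expected rendering mirrors the test case added further down.

// Minimal sketch: render a short conversation with the new "seed_oss" template.
// Assumes the current llama.h API; messages and buffer size are illustrative only.
#include <cstdio>
#include <vector>
#include "llama.h"

int main() {
    const llama_chat_message msgs[] = {
        { "user",      "Hello"       },
        { "assistant", "Hi there"    },
        { "user",      "Who are you" },
    };

    std::vector<char> buf(1024);
    const int32_t n = llama_chat_apply_template(
        "seed_oss", msgs, 3, /*add_ass=*/true, buf.data(), (int32_t) buf.size());

    if (n > 0) {
        // Expected shape, per the template logic above:
        // <seed:bos>user\nHello<seed:eos><seed:bos>assistant\nHi there<seed:eos><seed:bos>user\nWho are you<seed:eos><seed:bos>assistant\n
        printf("%.*s\n", n, buf.data());
    }
    return 0;
}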
1 change: 1 addition & 0 deletions src/llama-chat.h
@@ -49,6 +49,7 @@ enum llm_chat_template {
LLM_CHAT_TEMPLATE_OPENAI_MOE,
LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
LLM_CHAT_TEMPLATE_KIMI_K2,
LLM_CHAT_TEMPLATE_SEED_OSS,
LLM_CHAT_TEMPLATE_UNKNOWN,
};

182 changes: 182 additions & 0 deletions src/llama-model.cpp
@@ -83,6 +83,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_32B: return "32B";
case LLM_TYPE_34B: return "34B";
case LLM_TYPE_35B: return "35B";
case LLM_TYPE_36B: return "36B";
case LLM_TYPE_40B: return "40B";
case LLM_TYPE_65B: return "65B";
case LLM_TYPE_70B: return "70B";
@@ -1288,6 +1289,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_SEED_OSS:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 64: type = LLM_TYPE_36B; break;
}
} break;
case LLM_ARCH_OLMOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -3967,6 +3975,43 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_SEED_OSS:
{
const uint32_t head_dim = hparams.n_embd_head_k;
const int64_t n_qo_dim = n_head * head_dim;
const int64_t n_kv_dim = n_head_kv * head_dim;

tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0);

layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED);

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);

layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
}
} break;

case LLM_ARCH_OLMOE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -17934,6 +17979,137 @@ struct llm_build_lfm2 : public llm_graph_context {
}
};

struct llm_build_seed_oss : public llm_graph_context {
llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;

GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);

ggml_tensor * cur;
ggml_tensor * inpL;

inpL = build_inp_embd(model.tok_embd);

// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv();

const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

ggml_tensor * inp_out_ids = build_inp_out_ids();

for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;

// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);

// self-attention
{
// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}

ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}

ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);

cur = build_attn(inp_attn,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
}

if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}

ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);

// feed-forward network
cur = build_norm(ffn_inp,
model.layers[il].attn_post_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_post_norm", il);

cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);

cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);

cur = build_cvec(cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}

cur = inpL;

cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);

cb(cur, "result_norm", -1);
res->t_embd = cur;

// lm_head
cur = build_lora_mm(model.output, cur);

cb(cur, "result_output", -1);
res->t_logits = cur;

ggml_build_forward_expand(gf, cur);
}
};

template <bool iswa>
struct llm_build_smallthinker : public llm_graph_context{
llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){
@@ -18472,6 +18648,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_bailingmoe>(*this, params);
} break;
case LLM_ARCH_SEED_OSS:
{
llm = std::make_unique<llm_build_seed_oss>(*this, params);
} break;
case LLM_ARCH_DOTS1:
{
llm = std::make_unique<llm_build_dots1>(*this, params);
@@ -18530,6 +18710,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
return llm->res->get_gf();
}


//
// interface implementation
//
@@ -18724,6 +18905,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_LFM2:
case LLM_ARCH_SMALLTHINKER:
case LLM_ARCH_GLM4_MOE:
case LLM_ARCH_SEED_OSS:
return LLAMA_ROPE_TYPE_NEOX;

case LLM_ARCH_QWEN2VL:
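For an end-to-end sanity check of the new architecture, a converted Seed-OSS GGUF can be loaded through the public API. The snippet below is a rough sketch, not part of this diff: the file name is a placeholder for a local conversion, and the llama.h functions are assumed to keep their current signatures.

// Rough sketch: load a converted Seed-OSS GGUF and print its description,
// which should mention the new seed_oss architecture and 36B type.
// The model path is a placeholder for a locally converted file.
#include <cstdio>
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("seed-oss-36b-instruct-f16.gguf", mparams);
    if (model == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    char desc[256];
    llama_model_desc(model, desc, sizeof(desc)); // short description: architecture / size / quantization
    printf("%s\n", desc);

    llama_model_free(model);
    llama_backend_free();
    return 0;
}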
1 change: 1 addition & 0 deletions src/llama-model.h
@@ -76,6 +76,7 @@ enum llm_type {
LLM_TYPE_32B,
LLM_TYPE_34B,
LLM_TYPE_35B,
LLM_TYPE_36B,
LLM_TYPE_40B,
LLM_TYPE_65B,
LLM_TYPE_70B,
8 changes: 8 additions & 0 deletions tests/test-chat-template.cpp
@@ -290,6 +290,14 @@ int main(void) {
/* .bos_token= */ "",
/* .eos_token= */ "",
},
{
/* .name= */ "ByteDance-Seed/Seed-OSS-36B-Instruct",
/* .template_str */ "{# <seed:bos> #}{%- for message in messages %}{%- if message.role in [\"user\", \"system\"] %}{{ bos_token + message.role + \"\\n\" + message.content + eos_token }}{%- elif message.role == \"assistant\" %}{{ bos_token + message.role }}{%- if message.content is defined and message.content is string and message.content|trim|length > 0 %}{{ \"\\n\" + message.content|trim + eos_token }}{%- endif %}{%- else %}{{ bos_token + message.role + \"\\n\" + message.content + eos_token }}{%- endif %}{%- endfor %}{%- if add_generation_prompt %}{{ bos_token + \"assistant\\n\" }}{%- endif %}",
/* .expected_output= */ "<seed:bos>system\nYou are a helpful assistant<seed:eos><seed:bos>user\nHello<seed:eos><seed:bos>assistant\nHi there<seed:eos><seed:bos>user\nWho are you<seed:eos><seed:bos>assistant\nI am an assistant<seed:eos><seed:bos>user\nAnother question<seed:eos><seed:bos>assistant\n",
/* .expected_output_jinja= */ "<seed:bos>system\nYou are a helpful assistant<seed:eos><seed:bos>user\nHello<seed:eos><seed:bos>assistant\nHi there<seed:eos><seed:bos>user\nWho are you<seed:eos><seed:bos>assistant\nI am an assistant<seed:eos><seed:bos>user\nAnother question<seed:eos><seed:bos>assistant\n",
/* .bos_token= */ "<seed:bos>",
/* .eos_token= */ "<seed:eos>",
}
};
std::vector<char> formatted_chat(1024);
int32_t res;