Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5854,6 +5854,11 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [(self.map_tensor_name(name), data_torch)]


@ModelBase.register("SeedOssForCausalLM")
class SeedOssModel(TextModel):
model_arch = gguf.MODEL_ARCH.SEED_OSS


@ModelBase.register("Olmo2ForCausalLM")
class Olmo2Model(TextModel):
model_arch = gguf.MODEL_ARCH.OLMO2
Expand Down
17 changes: 17 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ class MODEL_ARCH(IntEnum):
DREAM = auto()
SMALLTHINKER = auto()
LLADA = auto()
SEED_OSS = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
Expand Down Expand Up @@ -717,6 +718,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.DREAM: "dream",
MODEL_ARCH.SMALLTHINKER: "smallthinker",
MODEL_ARCH.LLADA: "llada",
MODEL_ARCH.SEED_OSS: "seed_oss",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
Expand Down Expand Up @@ -1973,6 +1975,21 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.SEED_OSS: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_POST_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
],
MODEL_ARCH.OLMOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
2 changes: 1 addition & 1 deletion scripts/server-bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def benchmark(

logger.info("")
logger.info(f"Benchmark duration: {token_t_last:.2f} s")
logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
logger.info(f"Request throughput: {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last / 60):.2f} requests/min")
logger.info(f"Total prompt length: {np.sum(prompt_n)} tokens")
logger.info(f"Average prompt length: {np.mean(prompt_n):.2f} tokens")
logger.info(f"Average prompt latency: {1e3 * np.mean(prompt_t):.2f} ms")
Expand Down
18 changes: 18 additions & 0 deletions src/llama-arch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_DREAM, "dream" },
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
{ LLM_ARCH_LLADA, "llada" },
{ LLM_ARCH_SEED_OSS, "seed_oss" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

Expand Down Expand Up @@ -2068,6 +2069,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_SEED_OSS,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_UNKNOWN,
{
Expand Down
1 change: 1 addition & 0 deletions src/llama-arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ enum llm_arch {
LLM_ARCH_DREAM,
LLM_ARCH_SMALLTHINKER,
LLM_ARCH_LLADA,
LLM_ARCH_SEED_OSS,
LLM_ARCH_UNKNOWN,
};

Expand Down
181 changes: 181 additions & 0 deletions src/llama-model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ const char * llm_type_name(llm_type type) {
case LLM_TYPE_355B_A32B: return "355B.A32B";
case LLM_TYPE_E2B: return "E2B";
case LLM_TYPE_E4B: return "E4B";
case LLM_TYPE_36B: return "36B";
default: return "?B";
}
}
Expand Down Expand Up @@ -1288,6 +1289,13 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_SEED_OSS:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 64: type = LLM_TYPE_36B; break;
}
} break;
case LLM_ARCH_OLMOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
Expand Down Expand Up @@ -3967,6 +3975,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
}
} break;
case LLM_ARCH_SEED_OSS:
{
const uint32_t head_dim = hparams.n_embd_head_k;
const int64_t n_qo_dim = n_head * head_dim;
const int64_t n_kv_dim = n_head_kv * head_dim;

tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0);

layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED);

layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);

layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
}
} break;

case LLM_ARCH_OLMOE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
Expand Down Expand Up @@ -17934,6 +17975,140 @@ struct llm_build_lfm2 : public llm_graph_context {
}
};

struct llm_build_seed_oss : public llm_graph_context {
llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;

GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);

ggml_tensor * cur;
ggml_tensor * inpL;

inpL = build_inp_embd(model.tok_embd);

// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();

auto * inp_attn = build_attn_inp_kv();

const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;

ggml_tensor * inp_out_ids = build_inp_out_ids();

for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;

// norm
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);

// self-attention
{
// rope freq factors for llama3; may return nullptr for llama2 and other models
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);

// compute Q and K and RoPE them
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}

ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}

ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);

cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);

cur = build_attn(inp_attn,
model.layers[il].wo, model.layers[il].bo,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
cb(cur, "attn_out", il);
}

if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}

ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);

// feed-forward network
cur = build_norm(ffn_inp,
model.layers[il].attn_post_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);

cur = build_ffn(cur,
model.layers[il].ffn_up, NULL, NULL,
model.layers[il].ffn_gate, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);

cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);

cur = build_cvec(cur, il);
cb(cur, "l_out", il);

// input for next layer
inpL = cur;
}

cur = inpL;

cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);

cb(cur, "result_norm", -1);
res->t_embd = cur;

// lm_head
cur = build_lora_mm(model.output, cur);

cb(cur, "result_output", -1);
res->t_logits = cur;

ggml_build_forward_expand(gf, cur);
}
};

template <bool iswa>
struct llm_build_smallthinker : public llm_graph_context{
llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){
Expand Down Expand Up @@ -18472,6 +18647,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
{
llm = std::make_unique<llm_build_bailingmoe>(*this, params);
} break;
case LLM_ARCH_SEED_OSS:
{
llm = std::make_unique<llm_build_seed_oss>(*this, params);
} break;
case LLM_ARCH_DOTS1:
{
llm = std::make_unique<llm_build_dots1>(*this, params);
Expand Down Expand Up @@ -18530,6 +18709,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
return llm->res->get_gf();
}


//
// interface implementation
//
Expand Down Expand Up @@ -18679,6 +18859,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_ARCEE:
case LLM_ARCH_ERNIE4_5:
case LLM_ARCH_ERNIE4_5_MOE:
case LLM_ARCH_SEED_OSS:
return LLAMA_ROPE_TYPE_NORM;

// the pairs of head values are offset by n_rot/2
Expand Down
1 change: 1 addition & 0 deletions src/llama-model.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ enum llm_type {
LLM_TYPE_355B_A32B, // GLM-4.5
LLM_TYPE_E2B,
LLM_TYPE_E4B,
LLM_TYPE_36B // Seed OSS
};

std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type);
Expand Down
Loading