Merged
Changes from 1 commit
Commits (27, all by DamonFool on Sep 10, 2025)
33163bf  Extend the support of T5 models with different encoder-decoder layers
219eada  Update convert_hf_to_gguf.py
2161c30  Update gguf-py/gguf/constants.py
284ceb3  Update gguf-py/gguf/gguf_writer.py
77f0f16  Update src/llama-arch.cpp
7efe517  Update src/llama-arch.h
12a909f  Update src/llama-model.cpp
634e5a9  Update src/llama-model.cpp
ebef503  Update src/llama-model.cpp
0acda17  Update src/llama-model.cpp
19281fe  Update src/llama-hparams.h
5153072  Update src/llama-model.cpp
60821df  Update src/llama-model.cpp
de46320  Update src/llama-model.cpp
804a982  Update src/llama-model.cpp
9215087  Update src/llama-model.cpp
1167269  Update src/llama-model.cpp
678aa48  Update src/llama-model.cpp
d145ee1  Update src/llama-model.cpp
ce90f80  Update src/llama-model.cpp
01002df  Update src/llama-model.cpp
3ee2193  Update src/llama-model.cpp
6cb51f2  Update src/llama-model.cpp
42f1fdb  Update src/llama-model.cpp
6940650  Update src/llama-model.cpp
f16d8de  Rename n_dec_layer --> dec_n_layer
84e5db4  Adapt to cases when dec_n_layer > n_layer
1 change: 1 addition & 0 deletions convert_hf_to_gguf.py
@@ -6701,6 +6701,7 @@ def set_gguf_parameters(self):
self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
self.gguf_writer.add_block_count(self.hparams["num_layers"])
+ self.gguf_writer.add_num_decoder_layers(self.hparams["num_decoder_layers"])
self.gguf_writer.add_head_count(self.hparams["num_heads"])
self.gguf_writer.add_key_length(self.hparams["d_kv"])
self.gguf_writer.add_value_length(self.hparams["d_kv"])
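For context on the converter change above: Hugging Face's T5Config exposes both num_layers (encoder) and num_decoder_layers, and defaults the latter to the former when it is unset, so symmetric checkpoints keep their current behaviour. A minimal sketch, assuming the transformers package is installed; the layer counts are made up for the example:

from transformers import T5Config

cfg = T5Config(num_layers=12, num_decoder_layers=6)  # asymmetric encoder/decoder
print(cfg.num_layers)          # 12 -> written as {arch}.block_count by the converter
print(cfg.num_decoder_layers)  # 6  -> written as the new {arch}.num_decoder_layers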
1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
@@ -109,6 +109,7 @@ class LLM:
POOLING_TYPE = "{arch}.pooling_type"
LOGIT_SCALE = "{arch}.logit_scale"
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
+ NUM_DECODER_LAYERS = "{arch}.num_decoder_layers"
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
SWIN_NORM = "{arch}.swin_norm"
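A quick illustration (not part of the diff) of how the new constant expands per architecture, assuming the gguf-py package from this repository is importable:

import gguf

print(gguf.Keys.LLM.NUM_DECODER_LAYERS.format(arch="t5"))
# -> "t5.num_decoder_layers", the key name llama.cpp reads back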
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -676,6 +676,9 @@ def add_parallel_residual(self, use: bool) -> None:
def add_decoder_start_token_id(self, id: int) -> None:
self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)

+ def add_num_decoder_layers(self, value: int) -> None:
+ self.add_uint32(Keys.LLM.NUM_DECODER_LAYERS.format(arch=self.arch), value)
+
def add_embedding_length_per_layer_input(self, value: int) -> None:
self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value)

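A hedged usage sketch for the new writer method (illustrative only; the file name and layer counts are made up, and it assumes the gguf-py package from this repository):

import gguf

writer = gguf.GGUFWriter("t5-asym.gguf", "t5")  # arch string fills the "{arch}." prefix
writer.add_block_count(12)        # t5.block_count (encoder layers)
writer.add_num_decoder_layers(6)  # t5.num_decoder_layers (new key from this PR)
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.write_tensors_to_file()    # no tensors in this toy file
writer.close()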
1 change: 1 addition & 0 deletions src/llama-arch.cpp
@@ -137,6 +137,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
+ { LLM_KV_NUM_DECODER_LAYERS, "%s.num_decoder_layers" },
{ LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
{ LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
{ LLM_KV_SWIN_NORM, "%s.swin_norm" },
1 change: 1 addition & 0 deletions src/llama-arch.h
@@ -141,6 +141,7 @@ enum llm_kv {
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,
+ LLM_KV_NUM_DECODER_LAYERS,
LLM_KV_ATTN_LOGIT_SOFTCAPPING,
LLM_KV_FINAL_LOGIT_SOFTCAPPING,
LLM_KV_SWIN_NORM,
1 change: 1 addition & 0 deletions src/llama-hparams.h
@@ -159,6 +159,7 @@ struct llama_hparams {
// needed by encoder-decoder models (e.g. T5, FLAN-T5)
// ref: https://github.com/ggerganov/llama.cpp/pull/8141
llama_token dec_start_token_id = LLAMA_TOKEN_NULL;
+ uint32_t n_dec_layer = 0;

enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
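For reference, the block count and the new decoder-layer count reach these hparams fields through two GGUF keys; a made-up metadata dict (not a real reader) illustrates the mapping:

meta = {"t5.block_count": 12, "t5.num_decoder_layers": 6}

n_layer     = meta["t5.block_count"]                # -> hparams.n_layer (encoder layers)
n_dec_layer = meta.get("t5.num_decoder_layers", 0)  # -> hparams.n_dec_layer; optional key, 0 when absent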
58 changes: 39 additions & 19 deletions src/llama-model.cpp
@@ -1542,6 +1542,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
hparams.dec_start_token_id = dec_start_token_id;
}

+ uint32_t num_decoder_layers;
+ if (ml.get_key(LLM_KV_NUM_DECODER_LAYERS, num_decoder_layers, false)) {
+ hparams.n_dec_layer = num_decoder_layers;
+ GGML_ASSERT(hparams.n_dec_layer > 0 && "T5 requires num_decoder_layers > 0");
+ }
+
switch (hparams.n_layer) {
case 6: type = LLM_TYPE_60M; break; // t5-small
case 8: type = LLM_TYPE_80M; break; // flan-t5-small
@@ -4414,6 +4420,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}

+ // n_layer: number of encoder_layers
+ // n_dec_layer: number of decoder_layers
+ const int n_dec_layer = hparams.n_dec_layer;
+ layers.resize(n_layer + n_dec_layer);
+
+ // load encoder layers
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

@@ -4429,6 +4441,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ }
+
+ // load decoder layers
+ for (int i = 0; i < n_dec_layer; ++i) {
+ auto & layer = layers[i + n_layer];

layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED);
@@ -13509,35 +13526,38 @@ struct llm_build_t5_dec : public llm_graph_context {

ggml_tensor * inp_out_ids = build_inp_out_ids();

- for (int il = 0; il < n_layer; ++il) {
+ const int64_t n_dec_layer = hparams.n_dec_layer;
+
+ for (int il = 0; il < n_dec_layer; ++il) {
ggml_tensor * inpSA = inpL;
+ int il_dec = n_layer + il;

// norm
cur = build_norm(inpL,
- model.layers[il].attn_norm, NULL,
+ model.layers[il_dec].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);

// self-attention
{
- ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il_dec].wq, cur);
cb(Qcur, "Qcur", il);

- ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il_dec].wk, cur);
cb(Kcur, "Kcur", il);

- ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il_dec].wv, cur);
cb(Vcur, "Vcur", il);

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);

- ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b;
+ ggml_tensor * attn_rel_b = model.layers[il_dec].attn_rel_b ? model.layers[il_dec].attn_rel_b : model.layers[n_layer].attn_rel_b;
ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b);

cur = build_attn(inp_attn_self,
- model.layers[il].wo, model.layers[il].bo,
+ model.layers[il_dec].wo, model.layers[il_dec].bo,
Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il);
cb(cur, "kqv_out", il);
}
@@ -13549,27 +13569,27 @@ struct llm_build_t5_dec : public llm_graph_context {

// norm
cur = build_norm(cur,
- model.layers[il].attn_norm_cross, NULL,
+ model.layers[il_dec].attn_norm_cross, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm_cross", il);

// cross-attention
{
- ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_cross, cur);
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il_dec].wq_cross, cur);
cb(Qcur, "Qcur", il);

- ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_cross, embd_enc);
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il_dec].wk_cross, embd_enc);
cb(Kcur, "Kcur", il);

- ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_cross, embd_enc);
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il_dec].wv_cross, embd_enc);
cb(Vcur, "Vcur", il);

Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc);

cur = build_attn(inp_attn_cross,
- model.layers[il].wo_cross, nullptr,
+ model.layers[il_dec].wo_cross, nullptr,
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il);
cb(cur, "kqv_out", il);

@@ -13600,7 +13620,7 @@ struct llm_build_t5_dec : public llm_graph_context {
//cb(cur, "kqv_out", il);
}

- if (il == n_layer - 1 && inp_out_ids) {
+ if (il == n_dec_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids);
}
@@ -13611,18 +13631,18 @@ struct llm_build_t5_dec : public llm_graph_context {
// feed-forward network
{
cur = build_norm(ffn_inp,
- model.layers[il].ffn_norm, NULL,
+ model.layers[il_dec].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);

// T5 uses relu, flan-T5 uses gelu-gated
cur = build_ffn(cur,
- model.layers[il].ffn_up, NULL, NULL,
- model.layers[il].ffn_gate, NULL, NULL,
- model.layers[il].ffn_down, NULL, NULL,
+ model.layers[il_dec].ffn_up, NULL, NULL,
+ model.layers[il_dec].ffn_gate, NULL, NULL,
+ model.layers[il_dec].ffn_down, NULL, NULL,
NULL,
- model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU,
- model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ,
+ model.layers[il_dec].ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU,
+ model.layers[il_dec].ffn_gate ? LLM_FFN_PAR : LLM_FFN_SEQ,
il);
cb(cur, "ffn_out", il);
}
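To summarize the indexing scheme the decoder graph relies on: encoder layers occupy layers[0 .. n_layer) and decoder layers occupy layers[n_layer .. n_layer + n_dec_layer), so the decoder uses il_dec = n_layer + il, and the relative-attention-bias fallback now points at layers[n_layer] (the first decoder layer) rather than layers[0]. A small Python sketch of the mapping (illustration only, not the C++ implementation; layer counts are made up):

def decoder_layer_index(il: int, n_layer: int) -> int:
    # mirrors `int il_dec = n_layer + il;` in llm_build_t5_dec
    return n_layer + il

n_layer, n_dec_layer = 12, 6  # asymmetric T5 used for the example
layers = [f"enc.{i}" for i in range(n_layer)] + [f"dec.{i}" for i in range(n_dec_layer)]

assert layers[decoder_layer_index(0, n_layer)] == "dec.0"
assert layers[decoder_layer_index(n_dec_layer - 1, n_layer)] == "dec.5"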