
Commit 93b3228

arch : add T5Gemma encoder-decoder architecture support (#14940)
1 parent fd1234c commit 93b3228

8 files changed: +581 −2 lines

convert_hf_to_gguf.py

Lines changed: 328 additions & 0 deletions
Large diffs are not rendered by default.
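The converter changes themselves are not rendered here, but the result of a conversion can be inspected with gguf-py. A minimal sketch, not part of the commit, assuming a T5Gemma checkpoint has already been converted with convert_hf_to_gguf.py (exact invocation and flags are not shown in this diff) and using a placeholder output path:

```python
# Minimal sketch, not part of the commit: list the tensors written into a
# converted T5Gemma GGUF file using gguf-py's existing GGUFReader.
# "t5gemma-f16.gguf" is a placeholder path for the converter's output.
from gguf import GGUFReader

reader = GGUFReader("t5gemma-f16.gguf")
post_norms = sorted(t.name for t in reader.tensors if ".post_" in t.name)
print(post_norms[:4])
# expect names such as dec.blk.0.post_self_attn_norm.weight and
# enc.blk.0.post_ffn_norm.weight, per the new tensor names defined below.
```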

gguf-py/gguf/constants.py

Lines changed: 53 additions & 0 deletions
@@ -362,6 +362,7 @@ class MODEL_ARCH(IntEnum):
     BITNET = auto()
     T5 = auto()
     T5ENCODER = auto()
+    T5GEMMA = auto()  # T5Gemma architecture
     JAIS = auto()
     NEMOTRON = auto()
     EXAONE = auto()
@@ -528,6 +529,12 @@ class MODEL_TENSOR(IntEnum):
     DEC_FFN_DOWN = auto()
     DEC_FFN_UP = auto()
     DEC_OUTPUT_NORM = auto()
+    # T5GEMMA specific post layer normalization tensors
+    DEC_POST_SELF_ATTN_NORM = auto()
+    DEC_POST_CROSS_ATTN_NORM = auto()
+    DEC_POST_FFN_NORM = auto()
+    ENC_POST_SELF_ATTN_NORM = auto()
+    ENC_POST_FFN_NORM = auto()
     ENC_ATTN_NORM = auto()
     ENC_ATTN_Q = auto()
     ENC_ATTN_K = auto()
@@ -693,6 +700,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.BITNET: "bitnet",
     MODEL_ARCH.T5: "t5",
     MODEL_ARCH.T5ENCODER: "t5encoder",
+    MODEL_ARCH.T5GEMMA: "t5gemma",  # T5Gemma architecture
     MODEL_ARCH.JAIS: "jais",
     MODEL_ARCH.NEMOTRON: "nemotron",
     MODEL_ARCH.EXAONE: "exaone",
@@ -860,6 +868,12 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.DEC_FFN_DOWN: "dec.blk.{bid}.ffn_down",
     MODEL_TENSOR.DEC_FFN_UP: "dec.blk.{bid}.ffn_up",
     MODEL_TENSOR.DEC_OUTPUT_NORM: "dec.output_norm",
+    # T5GEMMA specific post layer normalization tensors
+    MODEL_TENSOR.DEC_POST_SELF_ATTN_NORM: "dec.blk.{bid}.post_self_attn_norm",
+    MODEL_TENSOR.DEC_POST_CROSS_ATTN_NORM: "dec.blk.{bid}.post_cross_attn_norm",
+    MODEL_TENSOR.DEC_POST_FFN_NORM: "dec.blk.{bid}.post_ffn_norm",
+    MODEL_TENSOR.ENC_POST_SELF_ATTN_NORM: "enc.blk.{bid}.post_self_attn_norm",
+    MODEL_TENSOR.ENC_POST_FFN_NORM: "enc.blk.{bid}.post_ffn_norm",
     MODEL_TENSOR.ENC_ATTN_NORM: "enc.blk.{bid}.attn_norm",
     MODEL_TENSOR.ENC_ATTN_Q: "enc.blk.{bid}.attn_q",
     MODEL_TENSOR.ENC_ATTN_K: "enc.blk.{bid}.attn_k",
@@ -2238,6 +2252,45 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ENC_FFN_UP,
         MODEL_TENSOR.ENC_OUTPUT_NORM,
     ],
+    MODEL_ARCH.T5GEMMA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.DEC_ATTN_NORM,
+        MODEL_TENSOR.DEC_ATTN_Q,
+        MODEL_TENSOR.DEC_ATTN_K,
+        MODEL_TENSOR.DEC_ATTN_V,
+        MODEL_TENSOR.DEC_ATTN_OUT,
+        MODEL_TENSOR.DEC_ATTN_REL_B,
+        MODEL_TENSOR.DEC_CROSS_ATTN_NORM,
+        MODEL_TENSOR.DEC_CROSS_ATTN_Q,
+        MODEL_TENSOR.DEC_CROSS_ATTN_K,
+        MODEL_TENSOR.DEC_CROSS_ATTN_V,
+        MODEL_TENSOR.DEC_CROSS_ATTN_OUT,
+        MODEL_TENSOR.DEC_CROSS_ATTN_REL_B,
+        MODEL_TENSOR.DEC_FFN_NORM,
+        MODEL_TENSOR.DEC_FFN_GATE,
+        MODEL_TENSOR.DEC_FFN_DOWN,
+        MODEL_TENSOR.DEC_FFN_UP,
+        MODEL_TENSOR.DEC_OUTPUT_NORM,
+        MODEL_TENSOR.ENC_ATTN_NORM,
+        MODEL_TENSOR.ENC_ATTN_Q,
+        MODEL_TENSOR.ENC_ATTN_K,
+        MODEL_TENSOR.ENC_ATTN_V,
+        MODEL_TENSOR.ENC_ATTN_OUT,
+        MODEL_TENSOR.ENC_ATTN_REL_B,
+        MODEL_TENSOR.ENC_FFN_NORM,
+        MODEL_TENSOR.ENC_FFN_GATE,
+        MODEL_TENSOR.ENC_FFN_DOWN,
+        MODEL_TENSOR.ENC_FFN_UP,
+        MODEL_TENSOR.ENC_OUTPUT_NORM,
+        # T5GEMMA specific post layer normalization tensors
+        MODEL_TENSOR.DEC_POST_SELF_ATTN_NORM,
+        MODEL_TENSOR.DEC_POST_CROSS_ATTN_NORM,
+        MODEL_TENSOR.DEC_POST_FFN_NORM,
+        MODEL_TENSOR.ENC_POST_SELF_ATTN_NORM,
+        MODEL_TENSOR.ENC_POST_FFN_NORM,
+    ],
     MODEL_ARCH.JAIS: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
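For reference, a small sketch (not part of the commit) of how these new constants resolve to GGUF tensor names. It assumes only the `TENSOR_NAMES` and `MODEL_TENSORS` tables from `gguf.constants` shown in the diff above and an installed `gguf` package:

```python
# Minimal sketch, not part of the commit: instantiate the new T5Gemma
# post-norm tensor name templates from gguf.constants (pip install gguf).
from gguf.constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES

new_tensors = [
    MODEL_TENSOR.DEC_POST_SELF_ATTN_NORM,
    MODEL_TENSOR.DEC_POST_CROSS_ATTN_NORM,
    MODEL_TENSOR.DEC_POST_FFN_NORM,
    MODEL_TENSOR.ENC_POST_SELF_ATTN_NORM,
    MODEL_TENSOR.ENC_POST_FFN_NORM,
]

for t in new_tensors:
    # each new tensor is registered for the t5gemma architecture ...
    assert t in MODEL_TENSORS[MODEL_ARCH.T5GEMMA]
    # ... and its per-block name comes from the template above, e.g. block 0
    print(TENSOR_NAMES[t].format(bid=0))
# expected, per the mapping above:
#   dec.blk.0.post_self_attn_norm
#   dec.blk.0.post_cross_attn_norm
#   dec.blk.0.post_ffn_norm
#   enc.blk.0.post_self_attn_norm
#   enc.blk.0.post_ffn_norm
```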

gguf-py/gguf/tensor_mapping.py

Lines changed: 44 additions & 0 deletions
@@ -27,6 +27,8 @@ class TensorNameMap:
            "embedding.word_embeddings", # chatglm
            "transformer.token_embeddings", # openelm
            "shared", # t5
+           "model.decoder.embed_tokens", # t5gemma
+           "model.encoder.embed_tokens", # t5gemma
            "rwkv.embeddings", # rwkv6
            "model.embeddings", # rwkv7
            "model.word_embeddings", # bailingmoe
@@ -887,22 +889,27 @@ class TensorNameMap:

        MODEL_TENSOR.DEC_ATTN_NORM: (
            "decoder.block.{bid}.layer.0.layer_norm", # t5
+           "model.decoder.layers.{bid}.pre_self_attn_layernorm", # t5gemma
        ),

        MODEL_TENSOR.DEC_ATTN_Q: (
            "decoder.block.{bid}.layer.0.SelfAttention.q", # t5
+           "model.decoder.layers.{bid}.self_attn.q_proj", # t5gemma
        ),

        MODEL_TENSOR.DEC_ATTN_K: (
            "decoder.block.{bid}.layer.0.SelfAttention.k", # t5
+           "model.decoder.layers.{bid}.self_attn.k_proj", # t5gemma
        ),

        MODEL_TENSOR.DEC_ATTN_V: (
            "decoder.block.{bid}.layer.0.SelfAttention.v", # t5
+           "model.decoder.layers.{bid}.self_attn.v_proj", # t5gemma
        ),

        MODEL_TENSOR.DEC_ATTN_OUT: (
            "decoder.block.{bid}.layer.0.SelfAttention.o", # t5
+           "model.decoder.layers.{bid}.self_attn.o_proj", # t5gemma
        ),

        MODEL_TENSOR.DEC_ATTN_REL_B: (
@@ -911,22 +918,27 @@ class TensorNameMap:

        MODEL_TENSOR.DEC_CROSS_ATTN_NORM: (
            "decoder.block.{bid}.layer.1.layer_norm", # t5
+           "model.decoder.layers.{bid}.pre_cross_attn_layernorm", # t5gemma
        ),

        MODEL_TENSOR.DEC_CROSS_ATTN_Q: (
            "decoder.block.{bid}.layer.1.EncDecAttention.q", # t5
+           "model.decoder.layers.{bid}.cross_attn.q_proj", # t5gemma
        ),

        MODEL_TENSOR.DEC_CROSS_ATTN_K: (
            "decoder.block.{bid}.layer.1.EncDecAttention.k", # t5
+           "model.decoder.layers.{bid}.cross_attn.k_proj", # t5gemma
        ),

        MODEL_TENSOR.DEC_CROSS_ATTN_V: (
            "decoder.block.{bid}.layer.1.EncDecAttention.v", # t5
+           "model.decoder.layers.{bid}.cross_attn.v_proj", # t5gemma
        ),

        MODEL_TENSOR.DEC_CROSS_ATTN_OUT: (
            "decoder.block.{bid}.layer.1.EncDecAttention.o", # t5
+           "model.decoder.layers.{bid}.cross_attn.o_proj", # t5gemma
        ),

        MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: (
@@ -935,43 +947,70 @@ class TensorNameMap:

        MODEL_TENSOR.DEC_FFN_NORM: (
            "decoder.block.{bid}.layer.2.layer_norm", # t5
+           "model.decoder.layers.{bid}.pre_feedforward_layernorm", # t5gemma
        ),

        MODEL_TENSOR.DEC_FFN_GATE: (
            "decoder.block.{bid}.layer.2.DenseReluDense.wi_0", # flan-t5
+           "model.decoder.layers.{bid}.mlp.gate_proj", # t5gemma
        ),

        MODEL_TENSOR.DEC_FFN_UP: (
            "decoder.block.{bid}.layer.2.DenseReluDense.wi", # t5
            "decoder.block.{bid}.layer.2.DenseReluDense.wi_1", # flan-t5
+           "model.decoder.layers.{bid}.mlp.up_proj", # t5gemma
        ),

        MODEL_TENSOR.DEC_FFN_DOWN: (
            "decoder.block.{bid}.layer.2.DenseReluDense.wo", # t5
+           "model.decoder.layers.{bid}.mlp.down_proj", # t5gemma
        ),

        MODEL_TENSOR.DEC_OUTPUT_NORM: (
            "decoder.final_layer_norm", # t5
+           "model.decoder.norm", # t5gemma
+       ),
+
+       # T5GEMMA specific post layer normalization tensors
+       MODEL_TENSOR.DEC_POST_SELF_ATTN_NORM: (
+           "model.decoder.layers.{bid}.post_self_attn_layernorm", # t5gemma
+       ),
+       MODEL_TENSOR.DEC_POST_CROSS_ATTN_NORM: (
+           "model.decoder.layers.{bid}.post_cross_attn_layernorm", # t5gemma
+       ),
+       MODEL_TENSOR.DEC_POST_FFN_NORM: (
+           "model.decoder.layers.{bid}.post_feedforward_layernorm", # t5gemma
+       ),
+       MODEL_TENSOR.ENC_POST_SELF_ATTN_NORM: (
+           "model.encoder.layers.{bid}.post_self_attn_layernorm", # t5gemma
+       ),
+       MODEL_TENSOR.ENC_POST_FFN_NORM: (
+           "model.encoder.layers.{bid}.post_feedforward_layernorm", # t5gemma
        ),

        MODEL_TENSOR.ENC_ATTN_NORM: (
            "encoder.block.{bid}.layer.0.layer_norm", # t5
+           "model.encoder.layers.{bid}.pre_self_attn_layernorm", # t5gemma
        ),

        MODEL_TENSOR.ENC_ATTN_Q: (
            "encoder.block.{bid}.layer.0.SelfAttention.q", # t5
+           "model.encoder.layers.{bid}.self_attn.q_proj", # t5gemma
        ),

        MODEL_TENSOR.ENC_ATTN_K: (
            "encoder.block.{bid}.layer.0.SelfAttention.k", # t5
+           "model.encoder.layers.{bid}.self_attn.k_proj", # t5gemma
        ),

        MODEL_TENSOR.ENC_ATTN_V: (
            "encoder.block.{bid}.layer.0.SelfAttention.v", # t5
+           "model.encoder.layers.{bid}.self_attn.v_proj", # t5gemma
        ),

        MODEL_TENSOR.ENC_ATTN_OUT: (
            "encoder.block.{bid}.layer.0.SelfAttention.o", # t5
+           "model.encoder.layers.{bid}.self_attn.o_proj", # t5gemma
        ),

        MODEL_TENSOR.ENC_ATTN_REL_B: (
@@ -980,25 +1019,30 @@ class TensorNameMap:

        MODEL_TENSOR.ENC_FFN_NORM: (
            "encoder.block.{bid}.layer.1.layer_norm", # t5
+           "model.encoder.layers.{bid}.pre_feedforward_layernorm", # t5gemma
        ),

        MODEL_TENSOR.ENC_FFN_GATE: (
            "encoder.block.{bid}.layer.1.DenseReluDense.wi_0", # flan-t5
+           "model.encoder.layers.{bid}.mlp.gate_proj", # t5gemma
        ),

        MODEL_TENSOR.ENC_FFN_UP: (
            "encoder.block.{bid}.layer.1.DenseReluDense.wi", # t5
            "encoder.block.{bid}.layer.1.DenseReluDense.wi_1", # flan-t5
+           "model.encoder.layers.{bid}.mlp.up_proj", # t5gemma
        ),

        MODEL_TENSOR.ENC_FFN_DOWN: (
            "encoder.block.{bid}.layer.1.DenseReluDense.wo", # t5
+           "model.encoder.layers.{bid}.mlp.down_proj", # t5gemma
        ),

        ############################################################################
        # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
        MODEL_TENSOR.ENC_OUTPUT_NORM: (
            "encoder.final_layer_norm", # t5
+           "model.encoder.norm", # t5gemma
            "layer_norm", # neobert
        ),

src/llama-arch.cpp

Lines changed: 46 additions & 0 deletions
@@ -66,6 +66,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_BITNET, "bitnet" },
    { LLM_ARCH_T5, "t5" },
    { LLM_ARCH_T5ENCODER, "t5encoder" },
+   { LLM_ARCH_T5GEMMA, "t5gemma" },
    { LLM_ARCH_JAIS, "jais" },
    { LLM_ARCH_NEMOTRON, "nemotron" },
    { LLM_ARCH_EXAONE, "exaone" },
@@ -1499,6 +1500,46 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
            { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
        },
    },
+   {
+       LLM_ARCH_T5GEMMA,
+       {
+           { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+           { LLM_TENSOR_OUTPUT, "output" },
+           { LLM_TENSOR_DEC_OUTPUT_NORM, "dec.output_norm" },
+           { LLM_TENSOR_DEC_ATTN_NORM, "dec.blk.%d.attn_norm" },
+           { LLM_TENSOR_DEC_ATTN_Q, "dec.blk.%d.attn_q" },
+           { LLM_TENSOR_DEC_ATTN_K, "dec.blk.%d.attn_k" },
+           { LLM_TENSOR_DEC_ATTN_V, "dec.blk.%d.attn_v" },
+           { LLM_TENSOR_DEC_ATTN_OUT, "dec.blk.%d.attn_o" },
+           { LLM_TENSOR_DEC_ATTN_REL_B, "dec.blk.%d.attn_rel_b" },
+           { LLM_TENSOR_DEC_CROSS_ATTN_NORM, "dec.blk.%d.cross_attn_norm" },
+           { LLM_TENSOR_DEC_CROSS_ATTN_Q, "dec.blk.%d.cross_attn_q" },
+           { LLM_TENSOR_DEC_CROSS_ATTN_K, "dec.blk.%d.cross_attn_k" },
+           { LLM_TENSOR_DEC_CROSS_ATTN_V, "dec.blk.%d.cross_attn_v" },
+           { LLM_TENSOR_DEC_CROSS_ATTN_OUT, "dec.blk.%d.cross_attn_o" },
+           { LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "dec.blk.%d.cross_attn_rel_b" },
+           { LLM_TENSOR_DEC_FFN_NORM, "dec.blk.%d.ffn_norm" },
+           { LLM_TENSOR_DEC_FFN_GATE, "dec.blk.%d.ffn_gate" },
+           { LLM_TENSOR_DEC_FFN_DOWN, "dec.blk.%d.ffn_down" },
+           { LLM_TENSOR_DEC_FFN_UP, "dec.blk.%d.ffn_up" },
+           { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
+           { LLM_TENSOR_ENC_ATTN_NORM, "enc.blk.%d.attn_norm" },
+           { LLM_TENSOR_ENC_ATTN_Q, "enc.blk.%d.attn_q" },
+           { LLM_TENSOR_ENC_ATTN_K, "enc.blk.%d.attn_k" },
+           { LLM_TENSOR_ENC_ATTN_V, "enc.blk.%d.attn_v" },
+           { LLM_TENSOR_ENC_ATTN_OUT, "enc.blk.%d.attn_o" },
+           { LLM_TENSOR_ENC_ATTN_REL_B, "enc.blk.%d.attn_rel_b" },
+           { LLM_TENSOR_ENC_FFN_NORM, "enc.blk.%d.ffn_norm" },
+           { LLM_TENSOR_ENC_FFN_GATE, "enc.blk.%d.ffn_gate" },
+           { LLM_TENSOR_ENC_FFN_DOWN, "enc.blk.%d.ffn_down" },
+           { LLM_TENSOR_ENC_FFN_UP, "enc.blk.%d.ffn_up" },
+           { LLM_TENSOR_DEC_POST_SELF_ATTN_NORM, "dec.blk.%d.post_self_attn_norm" },
+           { LLM_TENSOR_DEC_POST_CROSS_ATTN_NORM, "dec.blk.%d.post_cross_attn_norm" },
+           { LLM_TENSOR_DEC_POST_FFN_NORM, "dec.blk.%d.post_ffn_norm" },
+           { LLM_TENSOR_ENC_POST_SELF_ATTN_NORM, "enc.blk.%d.post_self_attn_norm" },
+           { LLM_TENSOR_ENC_POST_FFN_NORM, "enc.blk.%d.post_ffn_norm" },
+       },
+   },
    {
        LLM_ARCH_JAIS,
        {
@@ -2196,6 +2237,11 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
    {LLM_TENSOR_ENC_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_DEC_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
    {LLM_TENSOR_ENC_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
+   {LLM_TENSOR_DEC_POST_SELF_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+   {LLM_TENSOR_DEC_POST_CROSS_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+   {LLM_TENSOR_DEC_POST_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+   {LLM_TENSOR_ENC_POST_SELF_ATTN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+   {LLM_TENSOR_ENC_POST_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
    {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
    {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
    {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},

src/llama-arch.h

Lines changed: 7 additions & 0 deletions
@@ -70,6 +70,7 @@ enum llm_arch {
    LLM_ARCH_BITNET,
    LLM_ARCH_T5,
    LLM_ARCH_T5ENCODER,
+   LLM_ARCH_T5GEMMA,
    LLM_ARCH_JAIS,
    LLM_ARCH_NEMOTRON,
    LLM_ARCH_EXAONE,
@@ -381,6 +382,12 @@ enum llm_tensor {
    LLM_TENSOR_DEC_FFN_DOWN,
    LLM_TENSOR_DEC_FFN_UP,
    LLM_TENSOR_DEC_OUTPUT_NORM,
+   // T5GEMMA specific post layer normalization tensors
+   LLM_TENSOR_DEC_POST_SELF_ATTN_NORM,
+   LLM_TENSOR_DEC_POST_CROSS_ATTN_NORM,
+   LLM_TENSOR_DEC_POST_FFN_NORM,
+   LLM_TENSOR_ENC_POST_SELF_ATTN_NORM,
+   LLM_TENSOR_ENC_POST_FFN_NORM,
    LLM_TENSOR_ENC_ATTN_NORM,
    LLM_TENSOR_ENC_ATTN_Q,
    LLM_TENSOR_ENC_ATTN_K,

src/llama-context.cpp

Lines changed: 2 additions & 2 deletions
@@ -915,7 +915,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
    }

    // TODO: hacky solution
-   if (model.arch == LLM_ARCH_T5 && t_embd) {
+   if ((model.arch == LLM_ARCH_T5 || model.arch == LLM_ARCH_T5GEMMA) && t_embd) {
        //cross.t_embd = t_embd;

        synchronize();
@@ -1271,7 +1271,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
    bool has_embd = cparams.embeddings;

    // TODO: hacky enc-dec support
-   if (model.arch == LLM_ARCH_T5) {
+   if (model.arch == LLM_ARCH_T5 || model.arch == LLM_ARCH_T5GEMMA) {
        has_logits = true;
        has_embd = true;
    }
