
Commit 807b0c4

fairydreaming, sszymczy, and ggerganov authored
Inference support for T5 and FLAN-T5 model families (#5763)
* llama : add inference support and model types for T5 and FLAN-T5 model families
* llama : add new API functions to support encoder-decoder models: llama_encode(), llama_model_has_encoder(), llama_model_decoder_start_token()
* common, llama-cli, llama-batched : add support for encoder-decoder models
* convert-hf : handle shared token embeddings tensors in T5Model
* convert-hf : add support for SentencePiece BPE tokenizer in T5Model (for Pile-T5 models)
* convert-hf : add MT5ForConditionalGeneration and UMT5ForConditionalGeneration to architectures supported by T5Model
* convert : add t5 tokenizer tests, use "slow" HF tokenizer for t5

---------

Co-authored-by: Stanisław Szymczyk <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent f8c4c07 commit 807b0c4

33 files changed: +946, -31 lines

common/common.cpp

Lines changed: 18 additions & 1 deletion
@@ -2070,7 +2070,24 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
     if (params.warmup) {
         LOG("warming up the model with an empty run\n");
 
-        std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
+        std::vector<llama_token> tmp;
+        llama_token bos = llama_token_bos(model);
+        llama_token eos = llama_token_eos(model);
+        // some models (e.g. T5) don't have a BOS token
+        if (bos != -1) {
+            tmp.push_back(bos);
+        }
+        tmp.push_back(eos);
+
+        if (llama_model_has_encoder(model)) {
+            llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size(), 0, 0));
+            llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+            if (decoder_start_token_id == -1) {
+                decoder_start_token_id = bos;
+            }
+            tmp.clear();
+            tmp.push_back(decoder_start_token_id);
+        }
         llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
         llama_kv_cache_clear(lctx);
         llama_synchronize(lctx);
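
Both the encoder and decoder warmup calls above go through llama_batch_get_one(), which wraps a plain token array into a single-sequence batch. A quick reference sketch of the call shape (the comments below are descriptive paraphrases, not the header text; tmp is the warmup vector from the diff above):

// llama_batch_get_one(tokens, n_tokens, pos_0, seq_id)
//   tokens   - pointer to the first token of the input
//   n_tokens - number of tokens to submit
//   pos_0    - position of the first token within the sequence
//   seq_id   - id of the sequence the tokens belong to
llama_batch warmup_batch = llama_batch_get_one(tmp.data(), (int32_t) tmp.size(), 0, 0);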

convert-hf-to-gguf-update.py

Lines changed: 16 additions & 3 deletions
@@ -45,6 +45,7 @@ class TOKENIZER_TYPE(IntEnum):
     SPM = auto()
     BPE = auto()
     WPM = auto()
+    UGM = auto()
 
 
 # TODO: this string has to exercise as much pre-tokenizer functionality as possible
@@ -89,6 +90,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "gemma",   "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", },
     {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", },
     {"name": "jais",    "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", },
+    {"name": "t5",      "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", },
 ]
 
 
@@ -110,9 +112,13 @@ def download_model(model):
     os.makedirs(f"models/tokenizers/{name}", exist_ok=True)
 
     files = ["config.json", "tokenizer.json", "tokenizer_config.json"]
+
     if tokt == TOKENIZER_TYPE.SPM:
         files.append("tokenizer.model")
 
+    if tokt == TOKENIZER_TYPE.UGM:
+        files.append("spiece.model")
+
     for file in files:
         save_path = f"models/tokenizers/{name}/{file}"
         if os.path.isfile(save_path):
@@ -135,7 +141,7 @@ def download_model(model):
     name = model["name"]
     tokt = model["tokt"]
 
-    if tokt == TOKENIZER_TYPE.SPM:
+    if tokt == TOKENIZER_TYPE.SPM or tokt == TOKENIZER_TYPE.UGM:
         continue
 
     # Skip if the tokenizer folder does not exist or there are other download issues previously
@@ -145,7 +151,10 @@
 
     # create the tokenizer
     try:
-        tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
+        if name == "t5":
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
     except OSError as e:
         logger.error(f"Error loading tokenizer for model {name}. The model may not exist or is not accessible with the provided token. Error: {e}")
         continue  # Skip to the next model if the tokenizer can't be loaded
@@ -266,6 +275,7 @@ def get_vocab_base_pre(self, tokenizer) -> str:
     "\n =",
     "' era",
     "Hello, y'all! How are you 😁 ?我想在apple工作1314151天~",
+    "!!!!!!",
     "3",
     "33",
     "333",
@@ -304,7 +314,10 @@ def get_vocab_base_pre(self, tokenizer) -> str:
 
     # create the tokenizer
     try:
-        tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
+        if name == "t5":
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}", use_fast=False)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(f"models/tokenizers/{name}")
     except OSError as e:
         logger.error(f"Failed to load tokenizer for model {name}. Error: {e}")
         continue  # Skip this model and continue with the next one in the loop

convert-hf-to-gguf.py

Lines changed: 36 additions & 10 deletions
@@ -2853,29 +2853,47 @@ def write_tensors(self):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@Model.register("T5ForConditionalGeneration")
 @Model.register("T5WithLMHeadModel")
+@Model.register("T5ForConditionalGeneration")
+@Model.register("MT5ForConditionalGeneration")
+@Model.register("UMT5ForConditionalGeneration")
 class T5Model(Model):
     model_arch = gguf.MODEL_ARCH.T5
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.shared_token_embeddings_found = False
+
     def set_vocab(self):
         # to avoid TypeError: Descriptors cannot be created directly
         # exception when importing sentencepiece_model_pb2
         os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
         from sentencepiece import SentencePieceProcessor
         from sentencepiece import sentencepiece_model_pb2 as model
 
-        tokenizer_path = self.dir_model / 'spiece.model'
+        tokenizer_path = self.dir_model / 'tokenizer.model'
+
+        # many older models use spiece.model tokenizer model filename
+        if not tokenizer_path.is_file():
+            tokenizer_path = self.dir_model / 'spiece.model'
 
         if not tokenizer_path.is_file():
            raise FileNotFoundError(f"File not found: {tokenizer_path}")
 
         sentencepiece_model = model.ModelProto()
         sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+
+        # some models like Pile-T5 family use BPE tokenizer instead of Unigram
+        if sentencepiece_model.trainer_spec.model_type == 2: # BPE
+            # assure the tokenizer model file name is correct
+            assert tokenizer_path.name == 'tokenizer.model'
+            return self._set_vocab_sentencepiece()
+        else:
+            assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
+
         add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
         remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
         precompiled_charsmap = sentencepiece_model.normalizer_spec.precompiled_charsmap
-        assert sentencepiece_model.trainer_spec.model_type == 1 # UNIGRAM
 
         tokenizer = SentencePieceProcessor()
         tokenizer.LoadFromFile(str(tokenizer_path))
@@ -2945,7 +2963,10 @@ def set_vocab(self):
 
     def set_gguf_parameters(self):
         self.gguf_writer.add_name("T5")
-        self.gguf_writer.add_context_length(self.hparams["n_positions"])
+        if (n_ctx := self.find_hparam(["n_positions"], optional=True)) is None:
+            logger.warning("Couldn't find context length in config.json, assuming default value of 512")
+            n_ctx = 512
+        self.gguf_writer.add_context_length(n_ctx)
         self.gguf_writer.add_embedding_length(self.hparams["d_model"])
         self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
         self.gguf_writer.add_block_count(self.hparams["num_layers"])
@@ -2961,12 +2982,17 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
 
-        # Sometimes T5 and Flan-T5 based models contain "encoder.embed_tokens.weight" tensor or
-        # "decoder.embed_tokens.weight" tensors that are duplicates of "shared.weight" tensor
-        # To prevent errors caused by an unnecessary unmapped tensor, skip both of them and use only "shared.weight".
-        if name == "decoder.embed_tokens.weight" or name == "encoder.embed_tokens.weight":
-            logger.debug(f"Skipping tensor {name!r} in safetensors so that convert can end normally.")
-            return []
+        # T5 based models contain shared token embeddings tensors saved randomly as either "encoder.embed_tokens.weight",
+        # "decoder.embed_tokens.weight" or "shared.weight" tensor. In some models there are even multiple of them stored
+        # in the safetensors files. We use the first tensor from these three as the token embeddings for both encoder
+        # and decoder and ignore the remaining ones.
+        if name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight", "shared.weight"]:
+            if not self.shared_token_embeddings_found:
+                name = "shared.weight"
+                self.shared_token_embeddings_found = True
+            else:
+                logger.debug(f"Skipping shared tensor {name!r} in safetensors so that convert can end normally.")
+                return []
 
         return [(self.map_tensor_name(name), data_torch)]

examples/batched/batched.cpp

Lines changed: 27 additions & 7 deletions
@@ -93,14 +93,34 @@ int main(int argc, char ** argv) {
 
     // create a llama_batch
     // we use this object to submit token data for decoding
-    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
+    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);
+
+    std::vector<llama_seq_id> seq_ids(n_parallel, 0);
+    for (int32_t i = 0; i < n_parallel; ++i) {
+        seq_ids[i] = i;
+    }
 
     // evaluate the initial prompt
     for (size_t i = 0; i < tokens_list.size(); ++i) {
-        llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+        llama_batch_add(batch, tokens_list[i], i, seq_ids, false);
     }
     GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
 
+    if (llama_model_has_encoder(model)) {
+        if (llama_encode(ctx, batch)) {
+            LOG_TEE("%s : failed to eval\n", __func__);
+            return 1;
+        }
+
+        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+        if (decoder_start_token_id == -1) {
+            decoder_start_token_id = llama_token_bos(model);
+        }
+
+        llama_batch_clear(batch);
+        llama_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
+    }
+
     // llama_decode will output logits only for the last token of the prompt
     batch.logits[batch.n_tokens - 1] = true;
 
@@ -109,11 +129,11 @@
         return 1;
     }
 
-    // assign the system KV cache to all parallel sequences
-    // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
-    for (int32_t i = 1; i < n_parallel; ++i) {
-        llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
-    }
+    //// assign the system KV cache to all parallel sequences
+    //// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
+    //for (int32_t i = 1; i < n_parallel; ++i) {
+    //    llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+    //}
 
     if (n_parallel > 1) {
         LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
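
The llama_kv_cache_seq_cp() loop is commented out here because the prompt batch now tags every token with all n_parallel sequence ids up front, so the sequences already share the prompt's KV entries. A rough sketch of that batch-construction pattern, pulled out as a standalone helper (the function name is hypothetical; llama_batch_add() is the helper from common/common.h):

#include "common.h"   // llama_batch_add()
#include "llama.h"

#include <algorithm>
#include <vector>

// Sketch only: build one prompt batch whose tokens are attributed to every parallel
// sequence, so no separate llama_kv_cache_seq_cp() pass is needed afterwards.
static llama_batch make_shared_prompt_batch(const std::vector<llama_token> & tokens_list, int32_t n_parallel) {
    std::vector<llama_seq_id> seq_ids(n_parallel);
    for (int32_t i = 0; i < n_parallel; ++i) {
        seq_ids[i] = i;                                   // sequences 0 .. n_parallel-1
    }

    llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);
    for (size_t i = 0; i < tokens_list.size(); ++i) {
        // each token is recorded once but belongs to all sequences in seq_ids
        llama_batch_add(batch, tokens_list[i], i, seq_ids, false);
    }
    batch.logits[batch.n_tokens - 1] = true;              // logits only for the last prompt token
    return batch;
}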

examples/main/main.cpp

Lines changed: 21 additions & 1 deletion
@@ -255,7 +255,9 @@ int main(int argc, char ** argv) {
     }
 
     const bool add_bos = llama_should_add_bos_token(model);
-    GGML_ASSERT(llama_add_eos_token(model) != 1);
+    if (!llama_model_has_encoder(model)) {
+        GGML_ASSERT(llama_add_eos_token(model) != 1);
+    }
     LOG("add_bos: %d\n", add_bos);
 
     std::vector<llama_token> embd_inp;
@@ -517,6 +519,24 @@
         exit(1);
     }
 
+    if (llama_model_has_encoder(model)) {
+        int enc_input_size = embd_inp.size();
+        llama_token * enc_input_buf = embd_inp.data();
+
+        if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size, 0, 0))) {
+            LOG_TEE("%s : failed to eval\n", __func__);
+            return 1;
+        }
+
+        llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
+        if (decoder_start_token_id == -1) {
+            decoder_start_token_id = llama_token_bos(model);
+        }
+
+        embd_inp.clear();
+        embd_inp.push_back(decoder_start_token_id);
+    }
+
     while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
         // predict
         if (!embd.empty()) {

include/llama.h

Lines changed: 15 additions & 0 deletions
@@ -485,6 +485,13 @@ extern "C" {
     // Get a llama model tensor
     LLAMA_API struct ggml_tensor * llama_get_model_tensor(struct llama_model * model, const char * name);
 
+    // Returns true if the model contains an encoder that requires llama_encode() call
+    LLAMA_API bool llama_model_has_encoder(const struct llama_model * model);
+
+    // For encoder-decoder models, this function returns id of the token that must be provided
+    // to the decoder to start generating output sequence. For other models, it returns -1.
+    LLAMA_API llama_token llama_model_decoder_start_token(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
@@ -770,6 +777,14 @@ extern "C" {
     // Frees a batch of tokens allocated with llama_batch_init()
     LLAMA_API void llama_batch_free(struct llama_batch batch);
 
+    // Processes a batch of tokens with the encoder part of the encoder-decoder model.
+    // Stores the encoder output internally for later use by the decoder cross-attention layers.
+    //   0 - success
+    // < 0 - error
+    LLAMA_API int32_t llama_encode(
+            struct llama_context * ctx,
+            struct llama_batch     batch);
+
     // Positive return values does not mean a fatal error, but rather a warning.
     //   0 - success
     //   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
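
Taken together, the new declarations compose as sketched below. This is a minimal sketch of the intended call sequence, not code from the commit: the wrapper function name is hypothetical, and model/context loading, tokenization, sampling, and cleanup are elided. It mirrors the pattern used in examples/main/main.cpp and common/common.cpp above.

#include "llama.h"

#include <vector>

// Sketch: one encoder pass over the tokenized prompt, then decoding as usual.
static int32_t encode_then_decode(llama_context * ctx, const llama_model * model,
                                  std::vector<llama_token> inp) {
    if (llama_model_has_encoder(model)) {
        // run the encoder once over the whole prompt; the output is kept inside the context
        if (llama_encode(ctx, llama_batch_get_one(inp.data(), (int32_t) inp.size(), 0, 0)) != 0) {
            return -1;
        }

        // decoding starts from the model's dedicated start token, falling back to BOS
        llama_token dec_start = llama_model_decoder_start_token(model);
        if (dec_start == -1) {
            dec_start = llama_token_bos(model);
        }
        inp.assign(1, dec_start);
    }

    // decode as usual; for encoder-decoder models the cross-attention layers
    // read the encoder output stored by llama_encode()
    return llama_decode(ctx, llama_batch_get_one(inp.data(), (int32_t) inp.size(), 0, 0));
}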

models/ggml-vocab-bert-bge.gguf.inp

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

models/ggml-vocab-bert-bge.gguf.out

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@
  1027
  1005 3690
  7592 1010 1061 1005 2035 999 2129 2024 2017 100 1029 1855 100 100 6207 100 100 14677 23632 22203 1811 1995
+ 999 999 999 999 999 999
  1017
  3943
  21211

models/ggml-vocab-command-r.gguf.inp

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,8 @@ __ggml_vocab_test__
 __ggml_vocab_test__
 Hello, y'all! How are you 😁 ?我想在apple工作1314151天~
 __ggml_vocab_test__
+!!!!!!
+__ggml_vocab_test__
 3
 __ggml_vocab_test__
 33

models/ggml-vocab-command-r.gguf.out

Lines changed: 1 addition & 0 deletions
@@ -31,6 +31,7 @@
  206 1857
  14 4515
  28339 19 1770 14 1954 8 4070 1955 1933 80503 231 5691 12081 13336 2648 29325 14315 24 26 24 27 24 28 24 5123 18372
+ 57178 10251
  26
  26 26
  26 26 26
