
Commit 0f9ed68

flan t5 trt model (#221)
* Doesn't support batching
* Beam width is fixed at the model level in the truss config
1 parent 9fd4602 commit 0f9ed68

File tree: 4 files changed, +591 -0 lines changed

Lines changed: 18 additions & 0 deletions

base_image:
  image: docker.io/baseten/trtllm-server:r23.12_baseten_v0.9.0.dev2024022000
  python_executable_path: /usr/bin/python3
description: Flan T5 finetuned
environment_variables:
  HF_HUB_ENABLE_HF_TRANSFER: true
model_metadata:
  beam_width: 1
  engine_repository: baseten/flan-t5-large-trt-engine
  engine_name: flan-t5-large
model_name: flan t5 large tensorrt-llm
requirements:
- hf_transfer
resources:
  accelerator: L4
  use_gpu: true
runtime:
  predict_concurrency: 1
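For context, a minimal sketch of calling the deployed truss with the payload shape the model code below expects. The endpoint format, model ID, and API key are placeholders assuming a standard Baseten prediction endpoint; nothing in this commit defines them.

import requests

resp = requests.post(
    "https://model-MODEL_ID.api.baseten.co/production/predict",  # hypothetical model ID; assumed endpoint format
    headers={"Authorization": "Api-Key YOUR_API_KEY"},  # placeholder credential
    json={"prompt": "translate English to German: The house is wonderful.", "max_new_tokens": 32},
)
print(resp.json())  # e.g. {"status": "success", "data": [["..."]]}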

tensorrt-llm/flan-t5-trt-llm/model/__init__.py

Whitespace-only changes.
Lines changed: 62 additions & 0 deletions

import torch
from enc_dec.enc_dec_model import TRTLLMEncDecModel
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoTokenizer

HF_MODEL_NAME = "google-t5/t5-large"
DEFAULT_MAX_NEW_TOKENS = 20


class Model:
    def __init__(self, **kwargs):
        self._engine_dir = str(kwargs["data_dir"])
        model_metadata = kwargs["config"]["model_metadata"]
        self._engine_repo = model_metadata["engine_repository"]
        self._engine_name = model_metadata["engine_name"]
        self._beam_width = model_metadata["beam_width"]

    def load(self):
        # Download the prebuilt TRT-LLM engine from the Hugging Face Hub into the data dir.
        snapshot_download(repo_id=self._engine_repo, local_dir=self._engine_dir)
        self._tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
        model_config = AutoConfig.from_pretrained(HF_MODEL_NAME)
        self._decoder_start_token_id = model_config.decoder_start_token_id
        self._tllm_model = TRTLLMEncDecModel.from_engine(
            self._engine_name, self._engine_dir
        )

    def predict(self, model_input):
        try:
            input_text = model_input.pop("prompt")
            max_new_tokens = model_input.pop("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)

            tokenized_inputs = self._tokenizer(
                input_text, return_tensors="pt", padding=True
            )
            input_ids = tokenized_inputs.input_ids.type(torch.IntTensor).to("cuda")
            # Every sequence starts decoding from the model's decoder_start_token_id.
            decoder_input_ids = torch.IntTensor([[self._decoder_start_token_id]]).to(
                "cuda"
            )
            decoder_input_ids = decoder_input_ids.repeat((input_ids.shape[0], 1))

            tllm_output = self._tllm_model.generate(
                encoder_input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                max_new_tokens=max_new_tokens,
                num_beams=self._beam_width,
                bos_token_id=self._tokenizer.bos_token_id,
                pad_token_id=self._tokenizer.pad_token_id,
                eos_token_id=self._tokenizer.eos_token_id,
                return_dict=True,
                attention_mask=tokenized_inputs.attention_mask,
            )
            tllm_output_ids = tllm_output["output_ids"]
            # Decode each beam separately; output_ids is indexed as (batch, beam, seq_len).
            decoded_output = []
            for i in range(self._beam_width):
                output_ids = tllm_output_ids[:, i, :]
                output_text = self._tokenizer.batch_decode(
                    output_ids, skip_special_tokens=True
                )
                decoded_output.append(output_text)
            return {"status": "success", "data": decoded_output}
        except Exception as exc:
            return {"status": "error", "data": None, "message": str(exc)}
