
Commit 61280fe

flan t5 trtllm follow up (#222)
* Stop early: added a stopping criteria for this; couldn't find another way
* Print inference time
* Use hf_access_token from secrets
1 parent c7d2151 commit 61280fe

File tree

3 files changed: +26 -2 lines changed


tensorrt-llm/flan-t5-trt-llm/config.yaml

Lines changed: 2 additions & 0 deletions
@@ -16,3 +16,5 @@ resources:
   use_gpu: true
 runtime:
   predict_concurrency: 1
+secrets:
+  hf_access_token: placeholder__bound_at_runtime
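The placeholder__bound_at_runtime value is only a stand-in; the real Hugging Face token is bound to the hf_access_token secret when the model is deployed, and the model reads it through the constructor kwargs. A minimal sketch of that pattern, assuming Truss passes secrets as a dict-like object under kwargs["secrets"] (the class and attribute names here are illustrative, not part of the commit):

# Sketch only: reading an optional HF token from Truss-style secrets.
# An empty or unset value is normalized to None so that downloads of
# public repos fall back to anonymous access.
class Model:
    def __init__(self, **kwargs):
        secrets = kwargs["secrets"]           # dict-like, per the diff below
        token = secrets["hf_access_token"]    # placeholder is replaced at runtime
        self._hf_access_token = token or None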

tensorrt-llm/flan-t5-trt-llm/model/model.py

Lines changed: 14 additions & 2 deletions
@@ -1,9 +1,11 @@
+import time
+
 import torch
 from enc_dec.enc_dec_model import TRTLLMEncDecModel
 from huggingface_hub import snapshot_download
 from transformers import AutoConfig, AutoTokenizer

-HF_MODEL_NAME = "google-t5/t5-large"
+HF_MODEL_NAME = "google/flan-t5-large"
 DEFAULT_MAX_NEW_TOKENS = 20


@@ -14,9 +16,17 @@ def __init__(self, **kwargs):
         self._engine_repo = model_metadata["engine_repository"]
         self._engine_name = model_metadata["engine_name"]
         self._beam_width = model_metadata["beam_width"]
+        self._secrets = kwargs["secrets"]
+        self._hf_access_token = self._secrets["hf_access_token"]
+        if not self._hf_access_token:
+            self._hf_access_token = None

     def load(self):
-        snapshot_download(repo_id=self._engine_repo, local_dir=self._engine_dir)
+        snapshot_download(
+            repo_id=self._engine_repo,
+            local_dir=self._engine_dir,
+            token=self._hf_access_token,
+        )
         self._tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
         model_config = AutoConfig.from_pretrained(HF_MODEL_NAME)
         self._decoder_start_token_id = model_config.decoder_start_token_id
@@ -25,6 +35,7 @@ def load(self):
         )

     def predict(self, model_input):
+        start_time = time.time()
         try:
             input_text = model_input.pop("prompt")
             max_new_tokens = model_input.pop("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
@@ -57,6 +68,7 @@ def predict(self, model_input):
                 output_ids, skip_special_tokens=True
             )
             decoded_output.append(output_text)
+            print(f"Inference time: {(time.time() - start_time)*1000}ms")
             return {"status": "success", "data": decoded_output}
         except Exception as exc:
             return {"status": "error", "data": None, "message": str(exc)}

tensorrt-llm/flan-t5-trt-llm/packages/enc_dec/enc_dec_model.py

Lines changed: 10 additions & 0 deletions
@@ -497,6 +497,15 @@ def generate(
         )
         torch.cuda.synchronize()

+        # TODO(pankaj) Figure out a better way to stop this.
+        # Using stopping criteria is expensive, but I couldn't find
+        # another way of stopping generation early.
+        def stopping_criteria(
+            step: int, input_ids: torch.Tensor, scores: torch.Tensor
+        ) -> bool:
+            # If the generated token is EOS, stop.
+            return input_ids[0][step + 1] == eos_token_id
+
         output = self.decoder_session.decode(
             decoder_input_ids,
             decoder_input_lengths,
@@ -505,6 +514,7 @@ def generate(
             encoder_input_lengths=encoder_input_lengths,
             return_dict=return_dict,
             cross_attention_mask=cross_attention_mask,
+            stopping_criteria=stopping_criteria,
         )
         torch.cuda.synchronize()
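The callback receives the decode step and the running input_ids buffer. Assuming index 0 of input_ids[0] holds the decoder start token, the token produced at a given step sits at index step + 1, so the callback returns True as soon as that token is EOS and decoding stops early instead of running to the token limit. A toy check of that indexing (tensor values are made up, not from the model):

# Toy example only: verifying the step + 1 indexing used above.
import torch

eos_token_id = 1
# [decoder_start, tok@step0, tok@step1, tok@step2, ...]; EOS generated at step 2.
input_ids = torch.tensor([[0, 42, 17, 1, 0]])

def stopping_criteria(step: int, input_ids: torch.Tensor, scores: torch.Tensor) -> bool:
    return bool(input_ids[0][step + 1] == eos_token_id)

for step in range(4):
    if stopping_criteria(step, input_ids, None):
        print(f"stop at step {step}")  # -> stop at step 2
        break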
