@@ -1,9 +1,11 @@
+import time
+
 import torch
 from enc_dec.enc_dec_model import TRTLLMEncDecModel
 from huggingface_hub import snapshot_download
 from transformers import AutoConfig, AutoTokenizer
 
-HF_MODEL_NAME = "google-t5/t5-large"
+HF_MODEL_NAME = "google/flan-t5-large"
 DEFAULT_MAX_NEW_TOKENS = 20
 
 
@@ -14,9 +16,17 @@ def __init__(self, **kwargs):
         self._engine_repo = model_metadata["engine_repository"]
         self._engine_name = model_metadata["engine_name"]
         self._beam_width = model_metadata["beam_width"]
+        self._secrets = kwargs["secrets"]
+        self._hf_access_token = self._secrets["hf_access_token"]
+        if not self._hf_access_token:
+            self._hf_access_token = None
 
     def load(self):
-        snapshot_download(repo_id=self._engine_repo, local_dir=self._engine_dir)
+        snapshot_download(
+            repo_id=self._engine_repo,
+            local_dir=self._engine_dir,
+            token=self._hf_access_token,
+        )
         self._tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_NAME)
         model_config = AutoConfig.from_pretrained(HF_MODEL_NAME)
         self._decoder_start_token_id = model_config.decoder_start_token_id
@@ -25,6 +35,7 @@ def load(self):
         )
 
     def predict(self, model_input):
+        start_time = time.time()
         try:
             input_text = model_input.pop("prompt")
             max_new_tokens = model_input.pop("max_new_tokens", DEFAULT_MAX_NEW_TOKENS)
@@ -57,6 +68,7 @@ def predict(self, model_input):
                     output_ids, skip_special_tokens=True
                 )
                 decoded_output.append(output_text)
+            print(f"Inference time: {(time.time() - start_time) * 1000} ms")
             return {"status": "success", "data": decoded_output}
         except Exception as exc:
             return {"status": "error", "data": None, "message": str(exc)}