# t5-client-triton-ft.py
import argparse
import os
import sys
from datetime import datetime

import numpy as np
import torch

# Make sibling packages importable when the script is run from its own directory.
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(dir_path + "/../")

from transformers import PreTrainedTokenizerFast
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
)  # transformers-4.10.0-py3

import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient
from tritonclient.utils import np_to_triton_dtype
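
# A minimal sketch (hypothetical helper, not part of the original script): the
# same server could also be reached over gRPC by swapping the client module.
# Triton's gRPC endpoint conventionally listens on port 8001; note that, unlike
# the HTTP client, the gRPC client takes no `concurrency` argument.
def make_grpc_client(host="nvdl-smc-02", verbose=False):
    return grpcclient.InferenceServerClient(f"{host}:8001", verbose=verbose)
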

def translate(args_dict):
    torch.set_printoptions(precision=6)
    batch_size = args_dict["batch_size"]  # parsed but unused below; prompts are sent one at a time
    # t5_model = T5ForConditionalGeneration.from_pretrained(args_dict["model"])
    tokenizer = T5Tokenizer.from_pretrained(args_dict["model"])
    fast_tokenizer = PreTrainedTokenizerFast.from_pretrained(args_dict["model"])

    client_util = httpclient
    url = "nvdl-smc-02:8000"  # hard-coded Triton HTTP endpoint; adjust to your server
    model_name = "fastertransformer"
    request_parallelism = 10
    verbose = False
    client = client_util.InferenceServerClient(
        url, concurrency=request_parallelism, verbose=verbose
    )
    while True:
        t5_task_input = input(
            "What do you want T5 to do? Format: <task>: <sequence>; type exit to quit\n\n"
        )
        if t5_task_input == "exit":
            return

        print(t5_task_input)
        sys.stdout.flush()

        # Tokenize the prompt; FasterTransformer expects uint32 token ids.
        input_token = tokenizer(t5_task_input, return_tensors="pt", padding=True)
        input_ids = input_token.input_ids.numpy().astype(np.uint32)
        # True (unpadded) length of each sequence, reshaped to [batch, 1].
        mem_seq_len = (
            torch.sum(input_token.attention_mask, dim=1).numpy().astype(np.uint32)
        )
        mem_seq_len = mem_seq_len.reshape([mem_seq_len.shape[0], 1])
        inputs = [
            client_util.InferInput(
                "INPUT_ID", input_ids.shape, np_to_triton_dtype(input_ids.dtype)
            ),
            client_util.InferInput(
                "REQUEST_INPUT_LEN",
                mem_seq_len.shape,
                np_to_triton_dtype(mem_seq_len.dtype),
            ),
        ]
        inputs[0].set_data_from_numpy(input_ids)
        inputs[1].set_data_from_numpy(mem_seq_len)

        print("sent request\n")
        result = client.infer(model_name, inputs)
        print("got response\n")

        # OUTPUT0 holds generated token ids, indexed here as [batch, beam, tokens];
        # OUTPUT1 holds the corresponding output lengths.
        ft_decoding_outputs = result.as_numpy("OUTPUT0")
        ft_decoding_seq_lens = result.as_numpy("OUTPUT1")
        # print(type(ft_decoding_outputs), type(ft_decoding_seq_lens))
        # print(ft_decoding_outputs, ft_decoding_seq_lens)
        # Decode the first beam of the first batch entry, trimmed to its true length.
        tokens = fast_tokenizer.decode(
            ft_decoding_outputs[0][0][: ft_decoding_seq_lens[0][0]],
            skip_special_tokens=True,
        )
        print(tokens)
        print("\n")

        # Reference path through the Hugging Face model, kept for comparison:
        # print("output from T5 model using HF library:\n")
        # input_ids = tokenizer(t5_task_input, return_tensors="pt").input_ids
        # outputs = t5_model.generate(input_ids)
        # print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        # print("\n")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "-batch",
        "--batch_size",
        type=int,
        default=1,
        metavar="NUMBER",
        help="batch size (default: 1)",
    )
    parser.add_argument(
        "-model",
        "--model",
        type=str,
        default="t5-base",
        metavar="STRING",
        help="T5 model size.",
        choices=["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"],
    )
    args = parser.parse_args()

    translate(vars(args))
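
# Example invocation (assuming a Triton server with the "fastertransformer"
# model is reachable at the URL hard-coded in translate()):
#   python t5-client-triton-ft.py --model t5-base --batch_size 1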