import triton_python_backend_utils as pb_utils
from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer

+INVALID_UNICODE_CHAR = "�"
+

class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
@@ -55,7 +57,8 @@ def initialize(self, args):
        """
        # Parse model configs
        model_config = json.loads(args["model_config"])
-        tokenizer_dir = os.environ["triton_tokenizer_repository"]
+        # NOTE: Keep this in sync with the truss model.py variable
+        tokenizer_dir = os.environ["TRITON_TOKENIZER_REPOSITORY"]
        tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"]

        if tokenizer_type == "t5":
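The branch bodies fall outside this hunk. A plausible sketch of the loading logic they imply, assuming the standard from_pretrained constructors (the "llama" branch name is a guess based on the imports above):

# Hypothetical reconstruction -- the actual branch bodies are not shown here.
if tokenizer_type == "t5":
    self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_dir)
elif tokenizer_type == "llama":
    self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_dir)
else:
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)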
@@ -115,24 +118,48 @@ def execute(self, requests):
                .as_numpy()
                .flatten()
            )
+
            if len(tokens_batch) == 0:
                continue

            # Postprocess output data
-            prev_token = self._get_prev_token(request_id)
-            self._store_prev_token(request_id, tokens_batch[-1])
+            prev_token = self._get_var(request_id, "prev_token")
+            token_buffer = self._get_var(request_id, "token_buffer")
+            token_buffer = token_buffer if token_buffer is not None else []
+            current_tokens = np.concatenate(
+                (np.array(token_buffer, dtype=int), tokens_batch), dtype=int
+            )
+            current_tokens_decoded = self.tokenizer.decode(current_tokens)
+
+            if len(current_tokens_decoded) == 0:
+                responses.append(pb_utils.InferenceResponse())
+                continue
+
+            if current_tokens_decoded[-1] == INVALID_UNICODE_CHAR:
+                # If the last character is the replacement character, keep the
+                # tokens in the buffer so the next request can check whether
+                # they complete a multi-token unicode character.
+                self._store_var(request_id, "token_buffer", current_tokens)
+                responses.append(pb_utils.InferenceResponse())
+                continue
+
            if prev_token is None:
-                delta = self.tokenizer.decode(tokens_batch)
+                delta = current_tokens_decoded
            else:
                # TODO(pankaj) Figure out how to make tokenizer.decode not
                # ignore initial whitespace so we can avoid this hack.
                # Get string with and without previous token and diff. This hack
                # is needed because tokenizer.decode strips initial whitespace.
-                old_string = self.tokenizer.decode([prev_token])
-                with_prev_token = np.concatenate(([prev_token], tokens_batch))
+                old_string = self.tokenizer.decode(prev_token)
+                with_prev_token = np.concatenate((prev_token, current_tokens))
                new_string = self.tokenizer.decode(with_prev_token)
                delta = self._compute_delta(old_string, new_string)

+            # Store the whole decoded sequence, including any multi-token
+            # unicode character, as the previous tokens and clear the buffer.
+            self._store_var(request_id, "prev_token", current_tokens)
+            self._store_var(request_id, "token_buffer", None)
+
            # Create output tensor
            output_tensor = pb_utils.Tensor(
                "OUTPUT", np.array([delta]).astype(self.output_dtype)
@@ -147,22 +174,19 @@ def execute(self, requests):
    def finalize(self):
        print("Cleaning up...")

-    def _store_prev_token(self, request_id, token):
+    def _store_var(self, request_id, var_name, var):
        if request_id in self.state_dict:
-            self.state_dict[request_id]["prev_token"] = token
-
-            # Move request ID to end of queue to prevent it from being evicted
+            self.state_dict[request_id][var_name] = var
            self.state_dict.move_to_end(request_id)
        else:
-            # Evict least recently used item if cache is full
            if len(self.state_dict) > self.cache_size:
                self.state_dict.popitem(last=False)
+            self.state_dict[request_id] = {"prev_token": None, "token_buffer": None}
+            self.state_dict[request_id][var_name] = var

-            self.state_dict[request_id] = {"prev_token": token}
-
-    def _get_prev_token(self, request_id):
+    def _get_var(self, request_id, var_name):
        if request_id in self.state_dict:
-            return self.state_dict[request_id]["prev_token"]
+            return self.state_dict[request_id][var_name]
        return None

    def _compute_delta(self, prev_str, new_str):
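The body of _compute_delta is truncated in this view. A minimal sketch of a suffix diff consistent with its call site above (an assumption, not the committed implementation):

# Hypothetical sketch: prev_str is the decode of the previous tokens alone,
# new_str the decode of previous plus current tokens, so the newly generated
# text is the suffix that prev_str does not cover.
def _compute_delta(self, prev_str, new_str):
    if new_str.startswith(prev_str):
        return new_str[len(prev_str):]
    # Fall back to the full string if re-decoding altered earlier text.
    return new_str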