@@ -480,7 +480,7 @@ def detokenize(
         Returns:
             The detokenized string.
         """
-        return self.tokenizer_.detokenize(tokens, prev_tokens)
+        return self.tokenizer_.detokenize(tokens, prev_tokens=prev_tokens)
 
     def set_cache(self, cache: Optional[BaseLlamaCache]):
         """Set the cache.
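The wrapper now forwards `prev_tokens` by keyword, so tokenizer implementations that accept it only as a keyword argument still receive it. The point of the parameter is context-aware detokenization: with SentencePiece-style vocabularies, the text a token renders to (e.g. a leading space) can depend on what precedes it. A minimal usage sketch; the model path is a placeholder:

```python
from llama_cpp import Llama

llm = Llama(model_path="model.gguf")  # placeholder path

tokens = llm.tokenize(b"Hello world")
prompt, completion = tokens[:1], tokens[1:]

# Detokenizing the completion in isolation can render boundary characters
# (such as a leading space) differently than detokenizing it as a
# continuation of the prompt:
isolated = llm.detokenize(completion)
in_context = llm.detokenize(completion, prev_tokens=prompt)
```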
@@ -1016,13 +1016,13 @@ def logit_bias_processor(
             grammar=grammar,
         ):
             if token == self._token_eos:
-                text = self.detokenize(completion_tokens)
+                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
                 finish_reason = "stop"
                 break
 
             completion_tokens.append(token)
 
-            all_text = self.detokenize(completion_tokens)
+            all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
 
             # Contains multi-byte UTF8
             for k, char in enumerate(all_text[-3:]):
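With the whole completion in hand (end-of-sequence reached), the prompt alone is the preceding context, hence `prev_tokens=prompt_tokens` at these call sites. Reusing the names from the hunk above:

```python
# Whole-completion case: the prompt is the entire preceding context.
text = llm.detokenize(completion_tokens, prev_tokens=prompt_tokens)
```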
@@ -1046,7 +1046,7 @@ def logit_bias_processor(
 
             if stream:
                 remaining_tokens = completion_tokens[returned_tokens:]
-                remaining_text = self.detokenize(remaining_tokens)
+                remaining_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
                 remaining_length = len(remaining_text)
 
                 # We want to avoid yielding any characters from
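In the streaming path the context grows as tokens are yielded: everything before the not-yet-returned tokens, i.e. the prompt plus the already-emitted part of the completion, forms the detokenization prefix. A self-contained sketch of that bookkeeping, with illustrative token ids:

```python
prompt_tokens = [1, 15043]          # illustrative ids for the prompt
completion_tokens = [3186, 29991]   # tokens generated so far
returned_tokens = 1                 # tokens already yielded to the caller

# Context for the tokens still to be rendered: the prompt plus the slice
# of the completion that has already been emitted.
prefix = prompt_tokens + completion_tokens[:returned_tokens]
remaining = completion_tokens[returned_tokens:]
# remaining_text = llm.detokenize(remaining, prev_tokens=prefix)
```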
@@ -1068,17 +1068,17 @@ def logit_bias_processor(
                     for token in remaining_tokens:
                         if token == self.token_bos():
                             continue
-                        token_end_position += len(self.detokenize([token]))
+                        token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]))
                         # Check if stop sequence is in the token
                         if token_end_position > (
                             remaining_length - first_stop_position
                         ):
                             break
-                        token_str = self.detokenize([token]).decode(
+                        token_str = self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
                             "utf-8", errors="ignore"
                         )
                         text_offset = len(prompt) + len(
-                            self.detokenize(completion_tokens[:returned_tokens]).decode(
+                            self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
                                 "utf-8", errors="ignore"
                             )
                         )
@@ -1100,7 +1100,7 @@ def logit_bias_processor(
                         top_logprob.update({token_str: current_logprobs[int(token)]})
                         logprobs_or_none = {
                             "tokens": [
-                                self.detokenize([token]).decode(
+                                self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
                                     "utf-8", errors="ignore"
                                 )
                             ],
@@ -1116,7 +1116,7 @@ def logit_bias_processor(
                         "model": model_name,
                         "choices": [
                             {
-                                "text": self.detokenize([token]).decode(
+                                "text": self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens]).decode(
                                     "utf-8", errors="ignore"
                                 ),
                                 "index": 0,
@@ -1130,7 +1130,7 @@ def logit_bias_processor(
                     decode_success = False
                     for i in range(1, len(remaining_tokens) + 1):
                         try:
-                            bs = self.detokenize(remaining_tokens[:i])
+                            bs = self.detokenize(remaining_tokens[:i], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
                             ts = bs.decode("utf-8")
                             decode_success = True
                             break
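The loop above handles a multi-byte UTF-8 character whose bytes are split across token boundaries: it widens the token window one token at a time until the accumulated bytes decode cleanly. A self-contained illustration of the same retry idea, without a model:

```python
# "€" is three bytes in UTF-8; here they arrive split across two chunks.
chunks = [b"\xe2\x82", b"\xac"]

decoded = None
for i in range(1, len(chunks) + 1):
    try:
        decoded = b"".join(chunks[:i]).decode("utf-8")
        break  # enough bytes accumulated to form valid UTF-8
    except UnicodeDecodeError:
        continue  # partial character; widen the window and retry

assert decoded == "€"
```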
@@ -1165,22 +1165,22 @@ def logit_bias_processor(
             }
 
             if len(completion_tokens) >= max_tokens:
-                text = self.detokenize(completion_tokens)
+                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
                 finish_reason = "length"
                 break
 
         if stopping_criteria is not None and stopping_criteria(
             self._input_ids, self._scores[-1, :]
         ):
-            text = self.detokenize(completion_tokens)
+            text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
             finish_reason = "stop"
 
         if self.verbose:
             self._ctx.print_timings()
 
         if stream:
             remaining_tokens = completion_tokens[returned_tokens:]
-            all_text = self.detokenize(remaining_tokens)
+            all_text = self.detokenize(remaining_tokens, prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
             any_stop = [s for s in stop_sequences if s in all_text]
             if len(any_stop) > 0:
                 end = min(all_text.index(stop) for stop in any_stop)
@@ -1189,7 +1189,7 @@ def logit_bias_processor(
11891189
11901190 token_end_position = 0
11911191 for token in remaining_tokens :
1192- token_end_position += len (self .detokenize ([token ]))
1192+ token_end_position += len (self .detokenize ([token ], prev_tokens = prompt_tokens + completion_tokens [: returned_tokens ] ))
11931193
11941194 logprobs_or_none : Optional [CompletionLogprobs ] = None
11951195 if logprobs is not None :
@@ -1199,7 +1199,7 @@ def logit_bias_processor(
                         "utf-8", errors="ignore"
                     )
                     text_offset = len(prompt) + len(
-                        self.detokenize(completion_tokens[:returned_tokens])
+                        self.detokenize(completion_tokens[:returned_tokens], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens - 1
                     logits = self._scores[token_offset, :]
@@ -1313,8 +1313,8 @@ def logit_bias_processor(
                 all_tokens = completion_tokens
 
             all_token_strs = [
-                self.detokenize([token]).decode("utf-8", errors="ignore")
-                for token in all_tokens
+                self.detokenize([token], prev_tokens=all_tokens[:i]).decode("utf-8", errors="ignore")
+                for i, token in enumerate(all_tokens)
             ]
             all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
             # TODO: may be able to change this loop to use np.take_along_dim
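For the non-streaming logprobs block, the prefix grows by one token per position: token `i` is rendered in the context of tokens `0..i-1`, which the list comprehension gets from `enumerate` (the `idx` in the next hunk is the analogous index from an enclosing loop outside this diff's context lines). The same pattern in isolation:

```python
# Each position's context is every token before it.
seq = [101, 102, 103]  # illustrative token ids
prefixes = [seq[:i] for i, _ in enumerate(seq)]
assert prefixes == [[], [101], [101, 102]]
```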
@@ -1339,7 +1339,7 @@ def logit_bias_processor(
                 )
                 token_logprobs.append(logprobs_token[int(token)])
                 top_logprob: Optional[Dict[str, float]] = {
-                    self.detokenize([i]).decode("utf-8", errors="ignore"): logprob
+                    self.detokenize([i], prev_tokens=all_tokens[:idx]).decode("utf-8", errors="ignore"): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
                 top_logprob.update({token_str: logprobs_token[int(token)]})
@@ -1594,6 +1594,8 @@ def create_chat_completion(
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
         logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        top_logprobs: Optional[int] = None,
     ) -> Union[
         CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse]
     ]:
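The new `logprobs` / `top_logprobs` parameters mirror the OpenAI chat-completions API: a boolean toggle plus a count of alternatives per position. A hedged usage sketch; the model path and chat format are placeholders:

```python
from llama_cpp import Llama

llm = Llama(model_path="model.gguf", chat_format="llama-2")  # placeholders

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Name a prime number."}],
    logprobs=True,    # request per-token log probabilities
    top_logprobs=5,   # include the 5 most likely alternatives per position
)
```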