|
8 | 8 | from transformers import (AutoTokenizer, PreTrainedTokenizer, |
9 | 9 | PreTrainedTokenizerFast) |
10 | 10 |
|
11 | | -from vllm.inputs import token_inputs |
12 | | -from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup |
13 | | -from vllm.transformers_utils.detokenizer import Detokenizer |
14 | | -from vllm.transformers_utils.tokenizer import get_tokenizer |
| 11 | +from vllm.sampling_params import SamplingParams |
15 | 12 | from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer |
16 | 13 | from vllm.v1.engine import EngineCoreRequest |
17 | 14 | from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer, |
@@ -217,138 +214,3 @@ def test_oov_decode(tokenizer, fast): |
217 | 214 |
|
218 | 215 | assert decoded_text == '' |
219 | 216 | assert out_ids == [len(tokenizer)] |
220 | | - |
221 | | - |
222 | | -@pytest.fixture |
223 | | -def detokenizer(tokenizer_name: str) -> Detokenizer: |
224 | | - tokenizer = get_tokenizer( |
225 | | - tokenizer_name, |
226 | | - tokenizer_mode="mistral" if "mistral" in tokenizer_name else "auto", |
227 | | - trust_remote_code=False, |
228 | | - revision=None, |
229 | | - ) |
230 | | - |
231 | | - return Detokenizer(tokenizer) |
232 | | - |
233 | | - |
234 | | -@pytest.fixture(name="complete_sequence_token_ids") |
235 | | -def create_complete_sequence_token_ids(complete_sequence: str, |
236 | | - tokenizer) -> list[int]: |
237 | | - return tokenizer(complete_sequence, add_special_tokens=False).input_ids |
238 | | - |
239 | | - |
240 | | -def create_sequence(prompt_token_ids=None): |
241 | | - prompt_token_ids = prompt_token_ids or [] |
242 | | - return Sequence( |
243 | | - seq_id=0, |
244 | | - inputs=token_inputs(prompt_token_ids), |
245 | | - block_size=16, |
246 | | - ) |
247 | | - |
248 | | - |
249 | | -def create_dummy_logprobs( |
250 | | - complete_sequence_token_ids: list[int]) -> list[dict[int, Logprob]]: |
251 | | - return [{ |
252 | | - token_id: Logprob(logprob=0.0), |
253 | | - token_id + 1: Logprob(logprob=0.1) |
254 | | - } for token_id in complete_sequence_token_ids] |
255 | | - |
256 | | - |
257 | | -def create_dummy_prompt_logprobs( |
258 | | - complete_sequence_token_ids: list[int] |
259 | | -) -> list[Optional[dict[int, Any]]]: |
260 | | - # logprob for the first prompt token is None. |
261 | | - logprobs: list[Optional[dict[int, Any]]] = [None] |
262 | | - logprobs.extend(create_dummy_logprobs(complete_sequence_token_ids)[1:]) |
263 | | - return logprobs |
264 | | - |
265 | | - |
266 | | -@pytest.mark.parametrize("complete_sequence", TRUTH) |
267 | | -@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) |
268 | | -@pytest.mark.parametrize("skip_special_tokens", [True, False], indirect=True) |
269 | | -def test_decode_sequence_logprobs(complete_sequence: str, |
270 | | - complete_sequence_token_ids: list[int], |
271 | | - detokenizer: Detokenizer, |
272 | | - skip_special_tokens: bool): |
273 | | - """Verify Detokenizer decodes logprobs correctly.""" |
274 | | - sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, |
275 | | - logprobs=2) |
276 | | - |
277 | | - # Run sequentially. |
278 | | - seq = create_sequence() |
279 | | - dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids) |
280 | | - sequential_logprobs_text_chosen_token: list[str] = [] |
281 | | - sequential_logprobs_text_other_token: list[str] = [] |
282 | | - for new_token, logprobs in zip(complete_sequence_token_ids, |
283 | | - dummy_logprobs): |
284 | | - seq.append_token_id(new_token, logprobs) |
285 | | - detokenizer.decode_sequence_inplace(seq, sampling_params) |
286 | | - sequential_logprobs_text_chosen_token.append( |
287 | | - seq.output_logprobs[-1][new_token].decoded_token) |
288 | | - sequential_logprobs_text_other_token.append( |
289 | | - seq.output_logprobs[-1][new_token + 1].decoded_token) |
290 | | - sequential_result = seq.output_text |
291 | | - |
292 | | - assert sequential_result == "".join(sequential_logprobs_text_chosen_token) |
293 | | - assert sequential_result != "".join(sequential_logprobs_text_other_token) |
294 | | - |
295 | | - if not skip_special_tokens: |
296 | | - # Text for logprobs for the chosen token should be the same as the |
297 | | -        # generated text. Note that this will only be true if we do |
298 | | -        # not skip special tokens. |
299 | | - assert sequential_result == complete_sequence |
300 | | - |
301 | | - |
302 | | -@pytest.mark.parametrize("complete_sequence", TRUTH) |
303 | | -@pytest.mark.parametrize("tokenizer_name", TOKENIZERS) |
304 | | -def test_decode_prompt_logprobs(complete_sequence: str, |
305 | | - complete_sequence_token_ids: list[int], |
306 | | - detokenizer: Detokenizer): |
307 | | - |
308 | | - # We want to use skip_special_tokens=False here but Mistral tokenizers |
309 | | - # don't support that. |
310 | | - if complete_sequence not in SPECIAL_TOKS_TRUTH: |
311 | | - skip_special_tokens = True |
312 | | - elif not isinstance(detokenizer.tokenizer, MistralTokenizer): |
313 | | - skip_special_tokens = False |
314 | | - else: |
315 | | - pytest.skip("MistralTokenizers don't support " |
316 | | - "skip_special_tokens=False") |
317 | | - return |
318 | | - """Verify Detokenizer decodes prompt logprobs correctly.""" |
319 | | - sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens, |
320 | | - prompt_logprobs=1) |
321 | | - |
322 | | - # Run sequentially. |
323 | | - seq = create_sequence(complete_sequence_token_ids) |
324 | | - seq_group = SequenceGroup(request_id="1", |
325 | | - seqs=[seq], |
326 | | - sampling_params=sampling_params, |
327 | | - arrival_time=0.0) |
328 | | - dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids) |
329 | | - detokenizer.decode_prompt_logprobs_inplace(seq_group, |
330 | | - dummy_logprobs, |
331 | | - position_offset=0) |
332 | | - # First logprob is None. |
333 | | - decoded_prompt_logprobs: list[dict[int, Any]] = dummy_logprobs[ |
334 | | - 1:] # type: ignore |
335 | | - |
336 | | - # decoded_prompt_logprobs doesn't contain the first token. |
337 | | - token_ids = complete_sequence_token_ids |
338 | | - tokenizer = detokenizer.tokenizer |
339 | | - text_full = tokenizer.decode(token_ids, |
340 | | - skip_special_tokens=skip_special_tokens) |
341 | | - text_first = tokenizer.decode(token_ids[0], |
342 | | - skip_special_tokens=skip_special_tokens) |
343 | | - text = text_full[len(text_first):] |
344 | | - |
345 | | - # Text for logprobs for the chosen token should be the same as the |
346 | | - # prompt text. Note that the first logprob is None. |
347 | | - assert text == "".join([ |
348 | | - logprobs[token_id].decoded_token |
349 | | - for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs) |
350 | | - ]) |
351 | | - assert text != "".join([ |
352 | | - logprobs[token_id + 1].decoded_token |
353 | | - for token_id, logprobs in zip(token_ids[1:], decoded_prompt_logprobs) |
354 | | - ]) |