 from grpc_reflection.v1alpha import reflection
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.multiprocessing import MQEngineDeadError
+from vllm.entrypoints.openai.fim import get_fim_encoder_lookup
 from vllm.entrypoints.openai.serving_completion import merge_async_iterators
 from vllm.inputs import LLMInputs
 from vllm.sampling_params import RequestOutputKind, SamplingParams
@@ -209,6 +210,8 @@ def __init__(
         )
         self.health_servicer = health_servicer

+        self.get_fim_encoder = get_fim_encoder_lookup(args.fim)
+
     async def post_init(self) -> None:
         self.config = await self.engine.get_model_config()

@@ -254,7 +257,12 @@ async def Generate(

         for i, req in enumerate(request.requests):
             input_ids, max_is_token_limit[i] = await self._validate_prompt_and_tokenize(
-                sampling_params, truncate_input_tokens, req.text, tokenizer, context
+                sampling_params,
+                truncate_input_tokens,
+                req.text,
+                req.suffix,
+                tokenizer,
+                context,
             )

             inputs = LLMInputs(
@@ -352,6 +360,7 @@ async def GenerateStream(  # noqa: PLR0915, C901
             sampling_params,
             truncate_input_tokens,
             request.request.text,
+            request.request.suffix,
             tokenizer,
             context,
         )
@@ -782,30 +791,47 @@ def _convert_tokens(  # noqa: PLR0913
             )
             token_infos.append(token_info)

-    async def _validate_prompt_and_tokenize(
+    async def _validate_prompt_and_tokenize(  # noqa: PLR0913
         self,
         sampling_params: SamplingParams,
         truncate_input_tokens: int | None,
         prompt: str,
+        suffix: str | None,
         tokenizer: AnyTokenizer,
         context: ServicerContext,
     ) -> tuple[list[int], bool]:
         assert self.config is not None

-        max_model_len = self.config.max_model_len
-
-        tokenizer_kwargs: dict[str, Any] = {"add_special_tokens": ADD_SPECIAL_TOKENS}
-        if truncate_input_tokens is not None:
-            tokenizer_kwargs.update(
-                {
-                    "truncation": True,
-                    "max_length": truncate_input_tokens,
-                }
-            )
+        if suffix:
+            if not (get_fim_encoder := self.get_fim_encoder):
+                await context.abort(
+                    StatusCode.INVALID_ARGUMENT,
+                    "fim support must be enabled to use suffix",
+                )
+            if truncate_input_tokens is not None:
+                await context.abort(
+                    StatusCode.INVALID_ARGUMENT,
+                    "truncate_input_tokens cannot be used with suffix",
+                )
+            fim_encoder = get_fim_encoder(tokenizer)
+            input_ids = fim_encoder.encode_with_suffix(prefix=prompt, suffix=suffix)
+        else:
+            tokenizer_kwargs: dict[str, Any] = {
+                "add_special_tokens": ADD_SPECIAL_TOKENS
+            }
+            if truncate_input_tokens is not None:
+                tokenizer_kwargs.update(
+                    {
+                        "truncation": True,
+                        "max_length": truncate_input_tokens,
+                    }
+                )

-        input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
+            input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
         token_num = len(input_ids)

+        max_model_len = self.config.max_model_len
+
         try:
             validate_input(sampling_params, token_num, max_model_len)
         except ValueError as tgis_validation_error:
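
For reference: the calls added in this patch imply that get_fim_encoder_lookup(args.fim) returns either None (FIM disabled) or a factory that, given a tokenizer, yields an encoder whose encode_with_suffix(prefix=..., suffix=...) produces the combined fill-in-the-middle token ids. The sketch below is only an assumption-based outline of that shape, not the actual vllm.entrypoints.openai.fim implementation:

# Assumption-based sketch of the interface implied by the diff above;
# the real definitions live in vllm.entrypoints.openai.fim.
from typing import Any, Callable, Protocol


class FIMEncoder(Protocol):
    def encode_with_suffix(self, prefix: str, suffix: str) -> list[int]:
        """Tokenize prefix and suffix into one fill-in-the-middle prompt."""
        ...


def get_fim_encoder_lookup(fim: str | None) -> Callable[[Any], FIMEncoder] | None:
    """Return a tokenizer -> FIMEncoder factory, or None when FIM is not configured."""
    raise NotImplementedError  # placeholder for the real lookup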