@@ -14,6 +14,7 @@
 from grpc_health.v1 import health, health_pb2, health_pb2_grpc
 from grpc_reflection.v1alpha import reflection
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.fim import get_fim_encoder_lookup
 from vllm.entrypoints.openai.serving_completion import merge_async_iterators
 from vllm.inputs import LLMInputs
 from vllm.sampling_params import RequestOutputKind, SamplingParams
@@ -205,6 +206,8 @@ def __init__(
         )
         self.health_servicer = health_servicer

+        self.get_fim_encoder = get_fim_encoder_lookup(args.fim)
+
     async def post_init(self) -> None:
         self.config = await self.engine.get_model_config()

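The constructor stores the result of the lookup on the servicer, and the suffix handling later in this diff relies on that value being falsy when FIM support is not enabled. A minimal sketch of the assumed contract (hypothetical names, not the actual vllm.entrypoints.openai.fim API):

# Sketch of the assumed contract: the lookup returns None when no FIM template
# is configured, otherwise a factory that binds an encoder to a tokenizer.
# Names here are hypothetical and only illustrate the shape of the interface.
from typing import Callable, Optional, Protocol


class FIMEncoderLike(Protocol):
    def encode_with_suffix(self, prefix: str, suffix: str) -> list[int]: ...


def get_fim_encoder_lookup_sketch(
    fim_name: Optional[str],
) -> Optional[Callable[[object], FIMEncoderLike]]:
    if not fim_name:
        # FIM disabled: requests that carry a suffix are rejected later with
        # INVALID_ARGUMENT in _validate_prompt_and_tokenize.
        return None
    raise NotImplementedError(f"resolve FIM template {fim_name!r} to a factory")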
@@ -250,7 +253,12 @@ async def Generate(

         for i, req in enumerate(request.requests):
             input_ids, max_is_token_limit[i] = await self._validate_prompt_and_tokenize(
-                sampling_params, truncate_input_tokens, req.text, tokenizer, context
+                sampling_params,
+                truncate_input_tokens,
+                req.text,
+                req.suffix,
+                tokenizer,
+                context,
             )

             inputs = LLMInputs(
@@ -348,6 +356,7 @@ async def GenerateStream( # noqa: PLR0915, C901
             sampling_params,
             truncate_input_tokens,
             request.request.text,
+            request.request.suffix,
             tokenizer,
             context,
         )
@@ -778,30 +787,47 @@ def _convert_tokens( # noqa: PLR0913
             )
             token_infos.append(token_info)

-    async def _validate_prompt_and_tokenize(
+    async def _validate_prompt_and_tokenize(  # noqa: PLR0913
         self,
         sampling_params: SamplingParams,
         truncate_input_tokens: int | None,
         prompt: str,
+        suffix: str | None,
         tokenizer: AnyTokenizer,
         context: ServicerContext,
     ) -> tuple[list[int], bool]:
         assert self.config is not None

-        max_model_len = self.config.max_model_len
-
-        tokenizer_kwargs: dict[str, Any] = {"add_special_tokens": ADD_SPECIAL_TOKENS}
-        if truncate_input_tokens is not None:
-            tokenizer_kwargs.update(
-                {
-                    "truncation": True,
-                    "max_length": truncate_input_tokens,
-                }
-            )
+        if suffix:
+            if not (get_fim_encoder := self.get_fim_encoder):
+                await context.abort(
+                    StatusCode.INVALID_ARGUMENT,
+                    "fim support must be enabled to use suffix",
+                )
+            if truncate_input_tokens is not None:
+                await context.abort(
+                    StatusCode.INVALID_ARGUMENT,
+                    "truncate_input_tokens cannot be used with suffix",
+                )
+            fim_encoder = get_fim_encoder(tokenizer)
+            input_ids = fim_encoder.encode_with_suffix(prefix=prompt, suffix=suffix)
+        else:
+            tokenizer_kwargs: dict[str, Any] = {
+                "add_special_tokens": ADD_SPECIAL_TOKENS
+            }
+            if truncate_input_tokens is not None:
+                tokenizer_kwargs.update(
+                    {
+                        "truncation": True,
+                        "max_length": truncate_input_tokens,
+                    }
+                )

-        input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
+            input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
         token_num = len(input_ids)

+        max_model_len = self.config.max_model_len
+
         try:
             validate_input(sampling_params, token_num, max_model_len)
         except ValueError as tgis_validation_error:
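For context, encode_with_suffix is expected to build a fill-in-the-middle prompt by framing the prefix and suffix with model-specific sentinel tokens; truncate_input_tokens is rejected alongside suffix, presumably because truncating a combined prefix+suffix prompt has no single obvious semantics. Below is an illustrative sketch only, assuming StarCoder-style sentinels and a Hugging Face tokenizer, not the encoder provided by the import added above:

# Illustrative PSM-style (prefix, suffix, middle) FIM encoding, assuming
# StarCoder-like sentinel tokens are present in the tokenizer vocabulary.
# This is a sketch, not the encoder returned by get_fim_encoder_lookup.
def encode_with_suffix_sketch(tokenizer, prefix: str, suffix: str) -> list[int]:
    # Resolve the sentinel token ids; a real encoder would verify they exist.
    fim_prefix = tokenizer.convert_tokens_to_ids("<fim_prefix>")
    fim_suffix = tokenizer.convert_tokens_to_ids("<fim_suffix>")
    fim_middle = tokenizer.convert_tokens_to_ids("<fim_middle>")

    # Encode the surrounding text without BOS/EOS; the sentinels frame the prompt.
    prefix_ids = tokenizer.encode(prefix, add_special_tokens=False)
    suffix_ids = tokenizer.encode(suffix, add_special_tokens=False)

    # <fim_prefix> {prefix} <fim_suffix> {suffix} <fim_middle>: the model is
    # asked to generate the "middle" that joins prefix and suffix.
    return [fim_prefix, *prefix_ids, fim_suffix, *suffix_ids, fim_middle]


# Usage (illustrative; any Hugging Face tokenizer with FIM sentinels works):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("bigcode/starcoderbase")
#   ids = encode_with_suffix_sketch(tok, "def add(a, b):\n    return ", "\n")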