
Commit 2bc833d

Support suffix for fill-in-the-middle
1 parent 1d62372 commit 2bc833d

File tree: 2 files changed, +41 −13 lines

src/vllm_tgis_adapter/grpc/grpc_server.py

Lines changed: 39 additions & 13 deletions

@@ -14,6 +14,7 @@
 from grpc_health.v1 import health, health_pb2, health_pb2_grpc
 from grpc_reflection.v1alpha import reflection
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.fim import get_fim_encoder_lookup
 from vllm.entrypoints.openai.serving_completion import merge_async_iterators
 from vllm.inputs import LLMInputs
 from vllm.sampling_params import RequestOutputKind, SamplingParams

@@ -205,6 +206,8 @@ def __init__(
         )
         self.health_servicer = health_servicer

+        self.get_fim_encoder = get_fim_encoder_lookup(args.fim)
+
     async def post_init(self) -> None:
         self.config = await self.engine.get_model_config()

@@ -250,7 +253,12 @@ async def Generate(

         for i, req in enumerate(request.requests):
             input_ids, max_is_token_limit[i] = await self._validate_prompt_and_tokenize(
-                sampling_params, truncate_input_tokens, req.text, tokenizer, context
+                sampling_params,
+                truncate_input_tokens,
+                req.text,
+                req.suffix,
+                tokenizer,
+                context,
             )

         inputs = LLMInputs(

@@ -348,6 +356,7 @@ async def GenerateStream(  # noqa: PLR0915, C901
             sampling_params,
             truncate_input_tokens,
             request.request.text,
+            request.request.suffix,
             tokenizer,
             context,
         )

@@ -778,30 +787,47 @@ def _convert_tokens(  # noqa: PLR0913
             )
             token_infos.append(token_info)

-    async def _validate_prompt_and_tokenize(
+    async def _validate_prompt_and_tokenize(  # noqa: PLR0913
         self,
         sampling_params: SamplingParams,
         truncate_input_tokens: int | None,
         prompt: str,
+        suffix: str | None,
         tokenizer: AnyTokenizer,
         context: ServicerContext,
     ) -> tuple[list[int], bool]:
         assert self.config is not None

-        max_model_len = self.config.max_model_len
-
-        tokenizer_kwargs: dict[str, Any] = {"add_special_tokens": ADD_SPECIAL_TOKENS}
-        if truncate_input_tokens is not None:
-            tokenizer_kwargs.update(
-                {
-                    "truncation": True,
-                    "max_length": truncate_input_tokens,
-                }
-            )
+        if suffix:
+            if not (get_fim_encoder := self.get_fim_encoder):
+                await context.abort(
+                    StatusCode.INVALID_ARGUMENT,
+                    "fim support must be enabled to use suffix",
+                )
+            if truncate_input_tokens is not None:
+                await context.abort(
+                    StatusCode.INVALID_ARGUMENT,
+                    "truncate_input_tokens cannot be used with suffix",
+                )
+            fim_encoder = get_fim_encoder(tokenizer)
+            input_ids = fim_encoder.encode_with_suffix(prefix=prompt, suffix=suffix)
+        else:
+            tokenizer_kwargs: dict[str, Any] = {
+                "add_special_tokens": ADD_SPECIAL_TOKENS
+            }
+            if truncate_input_tokens is not None:
+                tokenizer_kwargs.update(
+                    {
+                        "truncation": True,
+                        "max_length": truncate_input_tokens,
+                    }
+                )

-        input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
+            input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
         token_num = len(input_ids)

+        max_model_len = self.config.max_model_len
+
         try:
             validate_input(sampling_params, token_num, max_model_len)
         except ValueError as tgis_validation_error:
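When a suffix is present, the FIM path replaces plain tokenization with a template built around the model's fill-in-the-middle sentinel tokens. Below is a minimal sketch of what an encode_with_suffix implementation may do for a StarCoder-style model; the sentinel token names and the PSM (prefix-suffix-middle) ordering are illustrative assumptions, not the actual vllm encoder returned by get_fim_encoder_lookup.

from transformers import AutoTokenizer


def encode_with_suffix(tokenizer, prefix: str, suffix: str) -> list[int]:
    """Sketch: arrange prefix and suffix around FIM sentinel tokens (PSM order).

    Sentinel names below are StarCoder-style and assumed for illustration.
    """
    fim_prefix = tokenizer.convert_tokens_to_ids("<fim_prefix>")
    fim_suffix = tokenizer.convert_tokens_to_ids("<fim_suffix>")
    fim_middle = tokenizer.convert_tokens_to_ids("<fim_middle>")
    prefix_ids = tokenizer(prefix, add_special_tokens=False).input_ids
    suffix_ids = tokenizer(suffix, add_special_tokens=False).input_ids
    # The model generates the "middle" span after this template.
    return [fim_prefix, *prefix_ids, fim_suffix, *suffix_ids, fim_middle]


# Example usage (model name is an assumption):
tok = AutoTokenizer.from_pretrained("bigcode/starcoderbase")
ids = encode_with_suffix(tok, prefix="def add(a, b):\n    ", suffix="\n    return result")

Note that this template consumes tokens from both sides of the insertion point, which is one reason the server rejects truncate_input_tokens alongside suffix: naive truncation of the combined sequence could cut into the suffix or the sentinel tokens.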

src/vllm_tgis_adapter/grpc/pb/generation.proto

Lines changed: 2 additions & 0 deletions

@@ -51,6 +51,8 @@ message BatchedGenerationResponse {

 message GenerationRequest {
   string text = 2;
+  /// Optional, for fill-in-middle
+  string suffix = 3;
 }

 message GenerationResponse {
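With the new field, a client supplies the code before the cursor in text and the code after it in suffix. A hypothetical Python client call follows, assuming stubs generated from generation.proto with grpc_tools; the stub/module names, the requests field on the batched message, and the port are assumptions based on typical TGIS deployments, not confirmed by this commit.

import grpc

from generation_pb2 import BatchedGenerationRequest, GenerationRequest
from generation_pb2_grpc import GenerationServiceStub

channel = grpc.insecure_channel("localhost:8033")  # assumed gRPC port
stub = GenerationServiceStub(channel)

request = BatchedGenerationRequest(
    requests=[
        GenerationRequest(
            # prefix: code before the insertion point
            text="def fib(n):\n    ",
            # suffix: code after the insertion point (new in this commit)
            suffix="\n    return fib(n - 1) + fib(n - 2)",
        )
    ],
)
# Per the server change above, this aborts with INVALID_ARGUMENT
# unless the adapter was started with FIM support enabled.
response = stub.Generate(request)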
