
Commit b9f6005

Support suffix for fill-in-the-middle
1 parent 2c80c72 commit b9f6005
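
For context: "fill-in-the-middle" (FIM) lets a client send both the code before and after a cursor position so the model generates the gap between them. A minimal sketch of the request shape this commit enables (field names taken from the proto change below; the example values are invented):

# The client sends the code before the cursor as `text` and the code after
# the cursor as the new `suffix` field; a FIM-capable model is expected to
# generate the "middle", here likely "a + b".
prefix = "def add(a, b):\n    return "  # goes in GenerationRequest.text
suffix = "\n\nprint(add(1, 2))"         # goes in GenerationRequest.suffix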

File tree

2 files changed, +41 −13 lines


src/vllm_tgis_adapter/grpc/grpc_server.py

Lines changed: 39 additions & 13 deletions
@@ -15,6 +15,7 @@
 from grpc_reflection.v1alpha import reflection
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.multiprocessing import MQEngineDeadError
+from vllm.entrypoints.openai.fim import get_fim_encoder_lookup
 from vllm.entrypoints.openai.serving_completion import merge_async_iterators
 from vllm.inputs import LLMInputs
 from vllm.sampling_params import RequestOutputKind, SamplingParams
@@ -209,6 +210,8 @@ def __init__(
         )
         self.health_servicer = health_servicer

+        self.get_fim_encoder = get_fim_encoder_lookup(args.fim)
+
     async def post_init(self) -> None:
         self.config = await self.engine.get_model_config()

@@ -254,7 +257,12 @@ async def Generate(

         for i, req in enumerate(request.requests):
             input_ids, max_is_token_limit[i] = await self._validate_prompt_and_tokenize(
-                sampling_params, truncate_input_tokens, req.text, tokenizer, context
+                sampling_params,
+                truncate_input_tokens,
+                req.text,
+                req.suffix,
+                tokenizer,
+                context,
             )

         inputs = LLMInputs(
@@ -352,6 +360,7 @@ async def GenerateStream(  # noqa: PLR0915, C901
             sampling_params,
             truncate_input_tokens,
             request.request.text,
+            request.request.suffix,
             tokenizer,
             context,
         )
@@ -782,30 +791,47 @@ def _convert_tokens(  # noqa: PLR0913
             )
             token_infos.append(token_info)

-    async def _validate_prompt_and_tokenize(
+    async def _validate_prompt_and_tokenize(  # noqa: PLR0913
         self,
         sampling_params: SamplingParams,
         truncate_input_tokens: int | None,
         prompt: str,
+        suffix: str | None,
         tokenizer: AnyTokenizer,
         context: ServicerContext,
     ) -> tuple[list[int], bool]:
         assert self.config is not None

-        max_model_len = self.config.max_model_len
-
-        tokenizer_kwargs: dict[str, Any] = {"add_special_tokens": ADD_SPECIAL_TOKENS}
-        if truncate_input_tokens is not None:
-            tokenizer_kwargs.update(
-                {
-                    "truncation": True,
-                    "max_length": truncate_input_tokens,
-                }
-            )
+        if suffix:
+            if not (get_fim_encoder := self.get_fim_encoder):
+                await context.abort(
+                    StatusCode.INVALID_ARGUMENT,
+                    "fim support must be enabled to use suffix",
+                )
+            if truncate_input_tokens is not None:
+                await context.abort(
+                    StatusCode.INVALID_ARGUMENT,
+                    "truncate_input_tokens cannot be used with suffix",
+                )
+            fim_encoder = get_fim_encoder(tokenizer)
+            input_ids = fim_encoder.encode_with_suffix(prefix=prompt, suffix=suffix)
+        else:
+            tokenizer_kwargs: dict[str, Any] = {
+                "add_special_tokens": ADD_SPECIAL_TOKENS
+            }
+            if truncate_input_tokens is not None:
+                tokenizer_kwargs.update(
+                    {
+                        "truncation": True,
+                        "max_length": truncate_input_tokens,
+                    }
+                )

-        input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
+            input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
         token_num = len(input_ids)

+        max_model_len = self.config.max_model_len
+
         try:
             validate_input(sampling_params, token_num, max_model_len)
         except ValueError as tgis_validation_error:
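
The suffix branch above delegates prompt construction to an encoder obtained from get_fim_encoder_lookup(args.fim), which evidently returns None when no FIM flavor is configured (hence the INVALID_ARGUMENT abort). The sketch below is a hypothetical reading of the encoder interface implied by the call site; the real encoders live in vllm.entrypoints.openai.fim, and the StarCoder-style sentinel tokens are an assumption:

from typing import Protocol


class FIMEncoder(Protocol):
    # Interface implied by get_fim_encoder(tokenizer).encode_with_suffix(...).
    def encode_with_suffix(self, prefix: str, suffix: str) -> list[int]: ...


class StarCoderStyleFIMEncoder:
    """Hypothetical encoder; the sentinel token names are an assumption."""

    def __init__(self, tokenizer) -> None:
        self.tokenizer = tokenizer

    def encode_with_suffix(self, prefix: str, suffix: str) -> list[int]:
        # Arrange prefix and suffix around the model's FIM sentinels so the
        # model generates the "middle" after <fim_middle>.
        text = f"<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>"
        return self.tokenizer(text, add_special_tokens=False).input_ids

This also suggests why truncate_input_tokens is rejected together with suffix: truncating the combined prefix/suffix/sentinel sequence could drop sentinel tokens or cut into the suffix in ways the encoder cannot account for.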

src/vllm_tgis_adapter/grpc/pb/generation.proto

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ message BatchedGenerationResponse {

 message GenerationRequest {
   string text = 2;
+  /// Optional, for fill-in-middle
+  string suffix = 3;
 }

 message GenerationResponse {
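
With the new field in place, a client can exercise fill-in-the-middle end to end. A hedged usage sketch, assuming Python stubs generated from generation.proto with grpcio-tools and the usual TGIS GenerationService; the stub module names, address, and port are assumptions, not part of this commit:

import grpc

# Stubs assumed to be generated via:
#   python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. generation.proto
from generation_pb2 import BatchedGenerationRequest, GenerationRequest
from generation_pb2_grpc import GenerationServiceStub

channel = grpc.insecure_channel("localhost:8033")
stub = GenerationServiceStub(channel)

response = stub.Generate(
    BatchedGenerationRequest(
        requests=[
            GenerationRequest(
                text="def add(a, b):\n    return ",  # prefix (field 2)
                suffix="\n\nprint(add(1, 2))",       # new suffix (field 3)
            )
        ],
    )
)
print(response.responses[0].text)

Per the server-side validation added above, such a request is rejected with INVALID_ARGUMENT if the adapter was started without FIM support or if truncate_input_tokens is also set.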
