
Commit b7c8d56

jberkhahn authored and joerunde committed
Stop caching LoRA requests and query vllm api server cache to check if it contains incoming LoRA requests
1 parent 3920c90 commit b7c8d56

File tree: 5 files changed (+269, -44 lines)
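
In short, this commit changes who owns LoRA caching: instead of the adapter store minting and caching its own LoRARequest objects, the gRPC adapter now asks the vllm API server's OpenAIServingModels handler to load adapters and reuses whatever LoRARequest the server already tracks. A minimal sketch of the new lookup order (resolve_lora_request is a hypothetical helper used only for illustration, not part of this diff):

async def resolve_lora_request(adapter_id, vllm_model_handler):
    # Reuse an adapter the vllm API server has already loaded and cached.
    for existing in vllm_model_handler.lora_requests:
        if existing.lora_name == adapter_id:
            return existing
    # Otherwise validate_adapters() asks the server to load it (see
    # _load_lora_adapter in the diff below) and returns the server's cached entry.
    return None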


src/vllm_tgis_adapter/__main__.py

Lines changed: 4 additions & 3 deletions
@@ -19,7 +19,7 @@
 from vllm_tgis_adapter.tgis_utils.logs import add_logging_wrappers

 from .grpc import run_grpc_server
-from .http import run_http_server
+from .http import build_http_server, run_http_server
 from .logging import DEFAULT_LOGGER_NAME, init_logger
 from .tgis_utils.args import EnvVarArgumentParser, add_tgis_args, postprocess_tgis_args
 from .utils import check_for_failed_tasks, write_termination_log
@@ -43,15 +43,16 @@ async def start_servers(args: argparse.Namespace) -> None:
     async with build_async_engine_client(args) as engine:
         add_logging_wrappers(engine)

+        vllm_server = await build_http_server(args, engine)
         http_server_task = loop.create_task(
-            run_http_server(args, engine, sock),
+            run_http_server(args, vllm_server, sock),
             name="http_server",
         )
         # The http server task will catch interrupt signals for us
         tasks.append(http_server_task)

         grpc_server_task = loop.create_task(
-            run_grpc_server(args, engine),
+            run_grpc_server(args, engine, vllm_server),
             name="grpc_server",
         )
         tasks.append(grpc_server_task)
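
The FastAPI app is now built once, up front, and shared by both servers. A condensed sketch of the new ordering in start_servers (simplified; the real function also wires up signal handling and failure checks, and "start" here is just an illustrative name):

import asyncio

async def start(args, engine, sock):
    loop = asyncio.get_running_loop()
    # Build the vllm FastAPI app first so a reference to it (and later its
    # app.state) can be handed to both the HTTP and gRPC servers.
    vllm_server = await build_http_server(args, engine)
    tasks = [
        loop.create_task(run_http_server(args, vllm_server, sock), name="http_server"),
        loop.create_task(run_grpc_server(args, engine, vllm_server), name="grpc_server"),
    ]
    await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)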

src/vllm_tgis_adapter/grpc/adapters.py

Lines changed: 54 additions & 10 deletions
@@ -15,20 +15,29 @@
 from pathlib import Path
 from typing import TYPE_CHECKING

-from vllm.lora.request import LoRARequest
+from vllm.entrypoints.openai.protocol import ErrorResponse
 from vllm.prompt_adapter.request import PromptAdapterRequest

 from vllm_tgis_adapter.logging import init_logger
 from vllm_tgis_adapter.tgis_utils.convert_pt_to_prompt import convert_pt_to_peft

 from .validation import TGISValidationError

+try:
+    from vllm.entrypoints.openai.protocol import LoadLoRAAdapterRequest
+except ImportError:
+    from vllm.entrypoints.openai.protocol import (
+        LoadLoraAdapterRequest as LoadLoRAAdapterRequest,
+    )
+
 if TYPE_CHECKING:
     from vllm.entrypoints.grpc.pb.generation_pb2 import (
         BatchedGenerationRequest,
         BatchedTokenizeRequest,
         SingleGenerationRequest,
     )
+    from vllm.entrypoints.openai.serving_models import OpenAIServingModels
+    from vllm.lora.request import LoRARequest

 global_thread_pool = None  # used for loading adapter files from disk

@@ -49,7 +58,8 @@ class AdapterMetadata:
 class AdapterStore:
     cache_path: str  # Path to local store of adapters to load from
     adapters: dict[str, AdapterMetadata]
-    next_unique_id: int = 1
+    # Pick a large number to avoid colliding with vllm's adapter IDs
+    next_unique_id: int = 1000001
     load_locks: dict[str, asyncio.Lock] = dataclasses.field(default_factory=dict)


@@ -58,6 +68,7 @@ async def validate_adapters(
     | BatchedGenerationRequest
     | BatchedTokenizeRequest,
     adapter_store: AdapterStore | None,
+    vllm_model_handler: OpenAIServingModels,
 ) -> dict[str, LoRARequest | PromptAdapterRequest]:
     """Validate the adapters.

@@ -81,6 +92,12 @@

     # Guard against concurrent access for the same adapter
     async with adapter_store.load_locks.setdefault(adapter_id, asyncio.Lock()):
+        # Check VLLM server lora cache if this request matches an existing
+        # LoRA adapter
+        for existing_lora_request in vllm_model_handler.lora_requests:
+            if existing_lora_request.lora_name == adapter_id:
+                return {"lora_request": existing_lora_request}
+
         # If not already cached, we need to validate that files exist and
         # grab the type out of the adapter_config.json file
         if (adapter_metadata := adapter_store.adapters.get(adapter_id)) is None:
@@ -107,16 +124,19 @@
             )

         # Add to cache
+        # Query vllm's cache for lora requests
+        if adapter_metadata.adapter_type == "LORA":
+            lora_request = await _load_lora_adapter(
+                request,
+                adapter_id,
+                adapter_metadata,
+                vllm_model_handler,
+            )
+            return {"lora_request": lora_request}
+        # Use our cache for everything else
         adapter_store.adapters[adapter_id] = adapter_metadata

         # Build the proper vllm request object
-        if adapter_metadata.adapter_type == "LORA":
-            lora_request = LoRARequest(
-                lora_name=adapter_id,
-                lora_int_id=adapter_metadata.unique_id,
-                lora_path=adapter_metadata.full_path,
-            )
-            return {"lora_request": lora_request}
         if adapter_metadata.adapter_type == "PROMPT_TUNING":
             prompt_adapter_request = PromptAdapterRequest(
                 prompt_adapter_id=adapter_metadata.unique_id,
@@ -126,12 +146,36 @@
                     "num_virtual_tokens", 0
                 ),
             )
-        return {"prompt_adapter_request": prompt_adapter_request}
+            return {"prompt_adapter_request": prompt_adapter_request}

         # All other types unsupported
         TGISValidationError.AdapterUnsupported.error(adapter_metadata.adapter_type)  # noqa: RET503


+async def _load_lora_adapter(
+    request: SingleGenerationRequest
+    | BatchedGenerationRequest
+    | BatchedTokenizeRequest,
+    adapter_id: str,
+    adapter_metadata: AdapterMetadata,
+    vllm_model_handler: OpenAIServingModels,
+) -> LoRARequest:
+    load_request = LoadLoRAAdapterRequest(
+        lora_path=adapter_metadata.full_path,
+        lora_name=adapter_id,
+    )
+    load_result = await vllm_model_handler.load_lora_adapter(
+        request=load_request,
+        base_model_name=request.model_id,
+    )
+    if isinstance(load_result, ErrorResponse):
+        raise ValueError(load_result.message)  ## noqa: TRY004
+    for existing_lora_request in vllm_model_handler.lora_requests:
+        if existing_lora_request.lora_name == adapter_id:
+            return existing_lora_request
+    raise RuntimeError("vllm server failed to load LoRA adapter")
+
+
 def _load_adapter_metadata(adapter_id: str, adapter_path: str, unique_id: int) -> dict:
     """Get adapter metadata from files.
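
From the caller's side, the new contract looks roughly like this (a sketch, assuming app is the FastAPI server returned by build_http_server and that request and adapter_store are already in scope). For LoRA adapters the returned LoRARequest is the one cached by vllm's OpenAIServingModels, so repeated requests for the same adapter_id reuse a single entry; prompt-tuning adapters still go through the adapter store's local cache as before:

adapter_kwargs = await validate_adapters(
    request=request,
    adapter_store=adapter_store,
    vllm_model_handler=app.state.openai_serving_models,
)
lora_request = adapter_kwargs.get("lora_request")  # vllm's cached LoRARequest, if any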

src/vllm_tgis_adapter/grpc/grpc_server.py

Lines changed: 33 additions & 6 deletions
@@ -59,10 +59,12 @@
     import argparse
     from collections.abc import AsyncIterator, MutableSequence

+    from fastapi import FastAPI
     from grpc.aio import ServicerContext
     from vllm import CompletionOutput, RequestOutput
     from vllm.config import ModelConfig
     from vllm.engine.protocol import EngineClient
+    from vllm.entrypoints.openai.serving_models import OpenAIServingModels
     from vllm.lora.request import LoRARequest
     from vllm.sequence import Logprob
     from vllm.transformers_utils.tokenizer import AnyTokenizer
@@ -167,9 +169,11 @@ def __init__(
         args: argparse.Namespace,
         health_servicer: health.HealthServicer,
         stop_event: asyncio.Event,
+        vllm_server: FastAPI,
     ):
         self.engine: EngineClient = engine
         self.stop_event = stop_event
+        self.vllm_server = vllm_server

         # This is set in post_init()
         self.config: ModelConfig | None = None
@@ -218,7 +222,11 @@ async def Generate(
         start_time = time.time()
         service_metrics.count_generate_request(len(request.requests))
         request_id = self.request_id(context)
-        kwargs = await self._validate_adapters(request, context)
+        kwargs = await self._validate_adapters(
+            request,
+            context,
+            self.vllm_server.state.openai_serving_models,
+        )
         tokenizer = await self._get_tokenizer(kwargs)

         sampling_params, deadline = await self._validate_and_convert_params(
@@ -308,7 +316,11 @@ async def GenerateStream( # noqa: PLR0915, C901
         start_time = time.time()
         service_metrics.count_generate_request()
         request_id = self.request_id(context)
-        adapter_kwargs = await self._validate_adapters(request, context)
+        adapter_kwargs = await self._validate_adapters(
+            request,
+            context,
+            self.vllm_server.state.openai_serving_models,
+        )
         tokenizer = await self._get_tokenizer(adapter_kwargs)

         sampling_params, deadline = await self._validate_and_convert_params(
@@ -628,10 +640,13 @@ async def _validate_adapters(
         | BatchedGenerationRequest
         | BatchedTokenizeRequest,
         context: ServicerContext,
+        vllm_model_handler: OpenAIServingModels,
     ) -> dict[str, LoRARequest | PromptAdapterRequest]:
         try:
             adapters = await validate_adapters(
-                request=request, adapter_store=self.adapter_store
+                request=request,
+                adapter_store=self.adapter_store,
+                vllm_model_handler=vllm_model_handler,
             )
         except ValueError as e:
             service_metrics.count_request_failure(FailureReasonLabel.VALIDATION)
@@ -812,7 +827,11 @@ async def Tokenize(
         service_metrics.count_tokenization_request(request)

         # TODO simplify to only check for lora adapter
-        adapter_kwargs = await self._validate_adapters(request, context)
+        adapter_kwargs = await self._validate_adapters(
+            request,
+            context,
+            self.vllm_server.state.openai_serving_models,
+        )
         tokenizer = await self._get_tokenizer(adapter_kwargs)

         responses: list[TokenizeResponse] = []
@@ -886,13 +905,20 @@ async def start_grpc_server(
     args: argparse.Namespace,
     engine: EngineClient,
     stop_event: asyncio.Event,
+    vllm_server: FastAPI,
 ) -> aio.Server:
     server = aio.server()

     health_servicer = health.HealthServicer()
     health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)

-    generation = TextGenerationService(engine, args, health_servicer, stop_event)
+    generation = TextGenerationService(
+        engine,
+        args,
+        health_servicer,
+        stop_event,
+        vllm_server,
+    )
     await generation.post_init()
     generation_pb2_grpc.add_GenerationServiceServicer_to_server(generation, server)

@@ -951,9 +977,10 @@ async def start_grpc_server(
 async def run_grpc_server(
     args: argparse.Namespace,
     engine: EngineClient,
+    vllm_server: FastAPI,
 ) -> None:
     stop_event = asyncio.Event()
-    server = await start_grpc_server(args, engine, stop_event)
+    server = await start_grpc_server(args, engine, stop_event, vllm_server)

     # Add a task to watch for the stop event, so that the server can kill
     # itself from within its own handlers
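
The gRPC service never receives the OpenAIServingModels handler directly; it holds the shared FastAPI app and reads the handler out of the app state on every request. A sketch of that access pattern (that openai_serving_models is placed on app.state by vllm's API-server initialization is an assumption here, not something shown in this diff):

# Inside a TextGenerationService RPC handler (sketch):
vllm_model_handler = self.vllm_server.state.openai_serving_models  # populated by vllm at server startup (assumed)
kwargs = await self._validate_adapters(request, context, vllm_model_handler)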

src/vllm_tgis_adapter/http.py

Lines changed: 17 additions & 7 deletions
@@ -13,7 +13,7 @@
     import argparse
     import socket

-    from fastapi import Request, Response
+    from fastapi import FastAPI, Request, Response
     from vllm.engine.async_llm_engine import AsyncLLMEngine
     from vllm.engine.protocol import AsyncEngineClient

@@ -22,14 +22,12 @@
 logger = init_logger(__name__)


-async def run_http_server(
+async def build_http_server(
     args: argparse.Namespace,
     engine: AsyncLLMEngine | AsyncEngineClient,
-    sock: socket.socket | None = None,
-    **uvicorn_kwargs,  # noqa: ANN003
-) -> None:
-    # modified copy of vllm.entrypoints.openai.api_server.run_server that
-    # allows passing of the engine
+) -> FastAPI:
+    # builds the vllm api server so we can pass reference to it
+    # within the tgis adapter

     app = build_app(args)

@@ -53,6 +51,18 @@ async def set_correlation_id(request: Request, call_next: Callable) -> Response:
     if inspect.isawaitable(maybe_coroutine):
         await maybe_coroutine

+    return app
+
+
+async def run_http_server(
+    args: argparse.Namespace,
+    app: FastAPI,
+    sock: socket.socket | None = None,
+    **uvicorn_kwargs,  # noqa: ANN003
+) -> None:
+    # modified copy of vllm.entrypoints.openai.api_server.run_server that
+    # allows passing of the engine
+
     serve_kwargs = {
         "host": args.host,
         "port": args.port,
