
Commit 1ab51e1

akramiamemilio authored and committed
feat: Add dynamic authentication token forwarding support for vLLM (llamastack#3388)
# What does this PR do?

*Add dynamic authentication token forwarding support for vLLM provider*

This enables per-request authentication tokens for vLLM providers, supporting use cases like RAG operations where different requests may need different authentication tokens. The implementation follows the same pattern as other providers like Together AI, Fireworks, and Passthrough.

- Add LiteLLMOpenAIMixin that manages the vllm_api_token properly

Usage:
- Static: `VLLM_API_TOKEN` env var or `config.api_token`
- Dynamic: `X-LlamaStack-Provider-Data` header with `vllm_api_token`

All existing functionality is preserved while adding new dynamic capabilities.

## Test Plan

```
curl -X POST "http://localhost:8000/v1/chat/completions" -H "Authorization: Bearer my-dynamic-token" \
  -H "X-LlamaStack-Provider-Data: {\"vllm_api_token\": \"Bearer my-dynamic-token\", \"vllm_url\": \"http://dynamic-server:8000\"}" \
  -H "Content-Type: application/json" \
  -d '{"model": "llama-3.1-8b", "messages": [{"role": "user", "content": "Hello!"}]}'
```

Signed-off-by: Akram Ben Aissi <[email protected]>
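For reference, a rough Python equivalent of the test-plan request above, sketched with httpx; the endpoint, model name, and token are the same placeholders used in the curl example:

```python
import json

import httpx

# Per-request provider data: a dynamic token (and optional URL) forwarded to the
# vLLM provider via the X-LlamaStack-Provider-Data header, as described above.
provider_data = {
    "vllm_api_token": "Bearer my-dynamic-token",
    "vllm_url": "http://dynamic-server:8000",
}

response = httpx.post(
    "http://localhost:8000/v1/chat/completions",
    headers={
        "Authorization": "Bearer my-dynamic-token",
        "X-LlamaStack-Provider-Data": json.dumps(provider_data),
    },
    json={
        "model": "llama-3.1-8b",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
response.raise_for_status()
print(response.json())
```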
1 parent b35f588 commit 1ab51e1

File tree: 4 files changed, +219 -48 lines changed


llama_stack/providers/registry/inference.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -78,6 +78,7 @@ def available_providers() -> list[ProviderSpec]:
             pip_packages=[],
             module="llama_stack.providers.remote.inference.vllm",
             config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
+            provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
             description="Remote vLLM inference provider for connecting to vLLM servers.",
         ),
     ),
```
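The `provider_data_validator` value is a dotted class path. As a sketch of the mechanism only (this is not the actual llama-stack resolver), such a path is typically turned into a class roughly like this:

```python
import importlib


def resolve_class(dotted_path: str) -> type:
    """Split 'pkg.module.ClassName' into module and attribute, then import it."""
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# Resolves to the VLLMProviderDataValidator added in this commit (requires llama_stack installed).
validator_cls = resolve_class(
    "llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator"
)
```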

llama_stack/providers/remote/inference/vllm/__init__.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -4,9 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from pydantic import BaseModel
+
 from .config import VLLMInferenceAdapterConfig
 
 
+class VLLMProviderDataValidator(BaseModel):
+    vllm_api_token: str | None = None
+
+
 async def get_adapter_impl(config: VLLMInferenceAdapterConfig, _deps):
     from .vllm import VLLMInferenceAdapter
 
```
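A minimal sketch of what this validator does with a per-request `X-LlamaStack-Provider-Data` payload, assuming pydantic v2; the header value below is a placeholder, and how the stack invokes the validator is governed by the `provider_data_validator` registration shown earlier:

```python
from llama_stack.providers.remote.inference.vllm import VLLMProviderDataValidator

# Placeholder header payload, matching the field declared on the validator.
header_value = '{"vllm_api_token": "Bearer my-dynamic-token"}'
provider_data = VLLMProviderDataValidator.model_validate_json(header_value)
print(provider_data.vllm_api_token)  # -> "Bearer my-dynamic-token"

# The field defaults to None, so requests without the header fall back to static config.
empty = VLLMProviderDataValidator.model_validate_json("{}")
print(empty.vllm_api_token)  # -> None
```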

llama_stack/providers/remote/inference/vllm/vllm.py

Lines changed: 53 additions & 15 deletions
```diff
@@ -4,8 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import json
-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, AsyncIterator
 from typing import Any
+from urllib.parse import urljoin
 
 import httpx
 from openai import APIConnectionError, AsyncOpenAI
@@ -55,13 +56,15 @@
     HealthStatus,
     ModelsProtocolPrivate,
 )
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_hf_repo_model_entry,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
     UnparseableToolCall,
     convert_message_to_openai_dict,
+    convert_openai_chat_completion_stream,
     convert_tool_call,
     get_sampling_options,
     process_chat_completion_stream_response,
@@ -281,22 +284,39 @@ async def _process_vllm_chat_completion_stream_response(
         yield c
 
 
-class VLLMInferenceAdapter(OpenAIMixin, Inference, ModelsProtocolPrivate):
+class VLLMInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin, Inference, ModelsProtocolPrivate):
     # automatically set by the resolver when instantiating the provider
     __provider_id__: str
     model_store: ModelStore | None = None
 
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
+        LiteLLMOpenAIMixin.__init__(
+            self,
+            build_hf_repo_model_entries(),
+            litellm_provider_name="vllm",
+            api_key_from_config=config.api_token,
+            provider_data_api_key_field="vllm_api_token",
+            openai_compat_api_base=config.url,
+        )
         self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries())
         self.config = config
 
+    get_api_key = LiteLLMOpenAIMixin.get_api_key
+
+    def get_base_url(self) -> str:
+        """Get the base URL from config."""
+        if not self.config.url:
+            raise ValueError("No base URL configured")
+        return self.config.url
+
     async def initialize(self) -> None:
         if not self.config.url:
             raise ValueError(
                 "You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
             )
 
     async def should_refresh_models(self) -> bool:
+        # Strictly respecting the refresh_models directive
         return self.config.refresh_models
 
     async def list_models(self) -> list[Model] | None:
@@ -325,13 +345,19 @@ async def health(self) -> HealthResponse:
         Performs a health check by verifying connectivity to the remote vLLM server.
         This method is used by the Provider API to verify
         that the service is running correctly.
+        Uses the unauthenticated /health endpoint.
         Returns:
 
             HealthResponse: A dictionary containing the health status.
         """
         try:
-            _ = [m async for m in self.client.models.list()]  # Ensure the client is initialized
-            return HealthResponse(status=HealthStatus.OK)
+            base_url = self.get_base_url()
+            health_url = urljoin(base_url, "health")
+
+            async with httpx.AsyncClient() as client:
+                response = await client.get(health_url)
+                response.raise_for_status()
+                return HealthResponse(status=HealthStatus.OK)
         except Exception as e:
             return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@@ -340,16 +366,10 @@ async def _get_model(self, model_id: str) -> Model:
             raise ValueError("Model store not set")
         return await self.model_store.get_model(model_id)
 
-    def get_api_key(self):
-        return self.config.api_token
-
-    def get_base_url(self):
-        return self.config.url
-
     def get_extra_client_params(self):
         return {"http_client": httpx.AsyncClient(verify=self.config.tls_verify)}
 
-    async def completion(
+    async def completion( # type: ignore[override] # Return type more specific than base class which is allows for both streaming and non-streaming responses.
         self,
         model_id: str,
         content: InterleavedContent,
@@ -411,13 +431,14 @@ async def chat_completion(
             tool_config=tool_config,
         )
         if stream:
-            return self._stream_chat_completion(request, self.client)
+            return self._stream_chat_completion_with_client(request, self.client)
         else:
             return await self._nonstream_chat_completion(request, self.client)
 
     async def _nonstream_chat_completion(
         self, request: ChatCompletionRequest, client: AsyncOpenAI
     ) -> ChatCompletionResponse:
+        assert self.client is not None
         params = await self._get_params(request)
         r = await client.chat.completions.create(**params)
         choice = r.choices[0]
@@ -431,9 +452,24 @@ async def _nonstream_chat_completion(
         )
         return result
 
-    async def _stream_chat_completion(
+    async def _stream_chat_completion(self, response: Any) -> AsyncIterator[ChatCompletionResponseStreamChunk]:
+        # This method is called from LiteLLMOpenAIMixin.chat_completion
+        # The response parameter contains the litellm response
+        # We need to convert it to our format
+        async def _stream_generator():
+            async for chunk in response:
+                yield chunk
+
+        async for chunk in convert_openai_chat_completion_stream(
+            _stream_generator(), enable_incremental_tool_calls=True
+        ):
+            yield chunk
+
+    async def _stream_chat_completion_with_client(
         self, request: ChatCompletionRequest, client: AsyncOpenAI
     ) -> AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
+        """Helper method for streaming with explicit client parameter."""
+        assert self.client is not None
         params = await self._get_params(request)
 
         stream = await client.chat.completions.create(**params)
@@ -445,15 +481,17 @@ async def _stream_chat_completion(
             yield chunk
 
     async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
-        assert self.client is not None
+        if self.client is None:
+            raise RuntimeError("Client is not initialized")
         params = await self._get_params(request)
         r = await self.client.completions.create(**params)
         return process_completion_response(r)
 
     async def _stream_completion(
         self, request: CompletionRequest
     ) -> AsyncGenerator[CompletionResponseStreamChunk, None]:
-        assert self.client is not None
+        if self.client is None:
+            raise RuntimeError("Client is not initialized")
         params = await self._get_params(request)
 
         stream = await self.client.completions.create(**params)
```
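The key wiring in `__init__` is `api_key_from_config=config.api_token` together with `provider_data_api_key_field="vllm_api_token"`: `LiteLLMOpenAIMixin.get_api_key` is expected to prefer a per-request token from the `X-LlamaStack-Provider-Data` header and fall back to the static `config.api_token` / `VLLM_API_TOKEN` value. A standalone sketch of that precedence (illustrative only, not the mixin's actual code):

```python
# Illustrative precedence only; the real resolution happens inside LiteLLMOpenAIMixin.
def resolve_vllm_api_token(
    provider_data: dict | None, config_api_token: str | None
) -> str | None:
    """Prefer the per-request vllm_api_token, then the static config token."""
    if provider_data and provider_data.get("vllm_api_token"):
        return provider_data["vllm_api_token"]
    return config_api_token


# Dynamic request: the header token wins.
assert resolve_vllm_api_token({"vllm_api_token": "req-token"}, "static-token") == "req-token"
# No provider data: fall back to config.api_token / VLLM_API_TOKEN.
assert resolve_vllm_api_token(None, "static-token") == "static-token"
```

This mirrors the Static/Dynamic usage described in the commit message.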

0 commit comments
