
Commit b7803bc

Merge pull request #14778 from otaviofbrito/chore/fix-vllm-pasthrough
fix vllm passthrough
2 parents 88f9cad + ffd9117 commit b7803bc

File tree

4 files changed (+164, -16 lines)


litellm/llms/vllm/common_utils.py

Lines changed: 18 additions & 3 deletions

@@ -11,7 +11,21 @@
 
 
 class VLLMError(BaseLLMException):
-    pass
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        request: Optional[httpx.Request] = None,
+        response: Optional[httpx.Response] = None,
+        headers: Optional[Union[httpx.Headers, dict]] = None,
+    ):
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            request=request,
+            response=response,
+            headers=headers,
+        )
 
 
 class VLLMModelInfo(BaseLLMModelInfo):

@@ -25,7 +39,8 @@ def validate_environment(
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
-        """Google AI Studio sends api key in query params"""
+        if api_key is not None:
+            headers["x-api-key"] = api_key
         return headers
 
     @staticmethod

@@ -53,7 +68,7 @@ def get_models(
         endpoint = "/v1/models"
         if api_base is None or api_key is None:
             raise ValueError(
-                "GEMINI_API_BASE or GEMINI_API_KEY is not set. Please set the environment variable, to query Gemini's `/models` endpoint."
+                "VLLM_API_BASE or VLLM_API_KEY is not set. Please set the environment variable, to query VLLM's `/models` endpoint."
             )
 
         url = _add_path_to_api_base(api_base, endpoint)
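
A minimal sketch of what the reworked VLLMError enables (illustrative only; the status code, message, and header values are made up, and the attribute access assumes BaseLLMException stores them like other litellm exceptions):

# Illustrative only: VLLMError now accepts HTTP context instead of being a bare
# `pass` subclass, so callers can surface status codes and headers from vLLM.
from litellm.llms.vllm.common_utils import VLLMError

try:
    # hypothetical failure path; values are made up for the example
    raise VLLMError(
        status_code=401,
        message="vLLM server rejected the request",
        headers={"x-request-id": "abc-123"},
    )
except VLLMError as e:
    print(e.status_code, e.message)  # assumes BaseLLMException exposes these attributes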

litellm/proxy/common_utils/http_parsing_utils.py

Lines changed: 13 additions & 11 deletions

@@ -233,14 +233,16 @@ async def get_request_body(request: Request) -> Dict[str, Any]:
     """
     Read the request body and parse it as JSON.
     """
-    if request.headers.get("content-type") == "application/json":
-        return await _read_request_body(request)
-    elif (
-        request.headers.get("content-type") == "multipart/form-data"
-        or request.headers.get("content-type") == "application/x-www-form-urlencoded"
-    ):
-        return await get_form_data(request)
-    else:
-        raise ValueError(
-            f"Unsupported content type: {request.headers.get('content-type')}"
-        )
+    if request.method == "POST":
+        if request.headers.get("content-type", "") == "application/json":
+            return await _read_request_body(request)
+        elif (
+            "multipart/form-data" in request.headers.get("content-type", "")
+            or "application/x-www-form-urlencoded" in request.headers.get("content-type", "")
+        ):
+            return await get_form_data(request)
+        else:
+            raise ValueError(
+                f"Unsupported content type: {request.headers.get('content-type')}"
+            )
+    return {}
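
The key behavioral change above: non-POST requests now return an empty dict instead of raising, and content-type matching is by substring. A small illustration of why the substring check matters (example boundary value is made up):

# Illustrative: real multipart uploads include a boundary parameter in the
# content-type header, so the old equality check never matched and raised.
content_type = "multipart/form-data; boundary=----WebKitFormBoundary123"  # typical client value
print(content_type == "multipart/form-data")   # False -> old code fell through to ValueError
print("multipart/form-data" in content_type)   # True  -> new code routes to get_form_data()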

litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py

Lines changed: 15 additions & 2 deletions

@@ -108,7 +108,16 @@ async def llm_passthrough_factory_proxy_route(
 
     # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
-    updated_url = base_url.copy_with(path=encoded_endpoint)
+    # Join paths correctly by removing trailing/leading slashes as needed
+    if not base_url.path or base_url.path == "/":
+        # If base URL has no path, just use the new path
+        updated_url = base_url.copy_with(path=encoded_endpoint)
+    else:
+        # Otherwise, combine the paths
+        base_path = base_url.path.rstrip("/")
+        clean_path = encoded_endpoint.lstrip("/")
+        full_path = f"{base_path}/{clean_path}"
+        updated_url = base_url.copy_with(path=full_path)
 
     # Add or update query parameters
     provider_api_key = passthrough_endpoint_router.get_credentials(
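
For context, httpx.URL.copy_with(path=...) replaces the entire path, which is why a base URL that already carries a path (e.g. a vLLM server mounted under /v1) lost it before this change. A quick sketch of the difference (the example URL is made up):

import httpx

base_url = httpx.URL("https://vllm.internal.example.com/v1")  # hypothetical api_base with a path

# Old behavior: copy_with(path=...) swaps out the whole path, dropping "/v1"
print(base_url.copy_with(path="/chat/completions"))
# -> https://vllm.internal.example.com/chat/completions

# New behavior: strip slashes and join, preserving the base path
full_path = f"{base_url.path.rstrip('/')}/{'/chat/completions'.lstrip('/')}"
print(base_url.copy_with(path=full_path))
# -> https://vllm.internal.example.com/v1/chat/completions
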
@@ -130,7 +139,11 @@
     is_streaming_request = False
     # anthropic is streaming when 'stream' = True is in the body
     if request.method == "POST":
-        _request_body = await request.json()
+        if "multipart/form-data" not in request.headers.get("content-type", ""):
+            _request_body = await request.json()
+        else:
+            _request_body = await get_form_data(request)
+
         if _request_body.get("stream"):
             is_streaming_request = True
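
This second hunk matters for multipart passthrough calls (e.g. file uploads), where request.json() would raise. A hedged client-side sketch against a locally running proxy (the URL, route prefix, key, and model name below are assumptions for illustration):

import requests

# Hypothetical call through a LiteLLM proxy on localhost; before this fix the
# passthrough route tried request.json() on this multipart body and failed.
resp = requests.post(
    "http://localhost:4000/vllm/audio/transcriptions",  # assumed passthrough route
    headers={"Authorization": "Bearer sk-1234"},         # assumed proxy key
    data={"model": "my-hosted-vllm-model"},              # assumed model name
    files={"file": ("sample.wav", open("sample.wav", "rb"), "audio/wav")},
)
print(resp.status_code)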

tests/test_litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py

Lines changed: 118 additions & 0 deletions

@@ -19,6 +19,8 @@
     BaseOpenAIPassThroughHandler,
     RouteChecks,
     create_pass_through_route,
+    llm_passthrough_factory_proxy_route,
+    vllm_proxy_route,
     vertex_discovery_proxy_route,
     vertex_proxy_route,
     bedrock_llm_proxy_route,

@@ -914,3 +916,119 @@ async def test_bedrock_llm_proxy_route_regular_model(self):
         # For regular models, model should be just the model ID
         assert call_kwargs["model"] == "anthropic.claude-3-sonnet-20240229-v1:0"
         assert result == "success"
+
+
+class TestLLMPassthroughFactoryProxyRoute:
+    @pytest.mark.asyncio
+    async def test_llm_passthrough_factory_proxy_route_success(self):
+        from litellm.types.utils import LlmProviders
+        mock_request = MagicMock(spec=Request)
+        mock_request.method = "POST"
+        mock_request.json = AsyncMock(return_value={"stream": False})
+        mock_fastapi_response = MagicMock(spec=Response)
+        mock_user_api_key_dict = MagicMock()
+
+        with patch(
+            "litellm.utils.ProviderConfigManager.get_provider_model_info"
+        ) as mock_get_provider, patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router.get_credentials"
+        ) as mock_get_creds, patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
+        ) as mock_create_route:
+            mock_provider_config = MagicMock()
+            mock_provider_config.get_api_base.return_value = "https://example.com/v1"
+            mock_provider_config.validate_environment.return_value = {
+                "x-api-key": "dummy"
+            }
+            mock_get_provider.return_value = mock_provider_config
+            mock_get_creds.return_value = "dummy"
+
+            mock_endpoint_func = AsyncMock(return_value="success")
+            mock_create_route.return_value = mock_endpoint_func
+
+            result = await llm_passthrough_factory_proxy_route(
+                custom_llm_provider=LlmProviders.VLLM,
+                endpoint="/chat/completions",
+                request=mock_request,
+                fastapi_response=mock_fastapi_response,
+                user_api_key_dict=mock_user_api_key_dict,
+            )
+
+            assert result == "success"
+            mock_get_provider.assert_called_once_with(
+                provider=litellm.LlmProviders(LlmProviders.VLLM), model=None
+            )
+            mock_get_creds.assert_called_once_with(
+                custom_llm_provider=LlmProviders.VLLM, region_name=None
+            )
+            mock_create_route.assert_called_once_with(
+                endpoint="/chat/completions",
+                target="https://example.com/v1/chat/completions",
+                custom_headers={"x-api-key": "dummy"},
+            )
+            mock_endpoint_func.assert_awaited_once()
+
+
+class TestVLLMProxyRoute:
+    @pytest.mark.asyncio
+    @patch(
+        "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.get_request_body",
+        return_value={"model": "router-model", "stream": False},
+    )
+    @patch(
+        "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.is_passthrough_request_using_router_model",
+        return_value=True,
+    )
+    @patch("litellm.proxy.proxy_server.llm_router")
+    async def test_vllm_proxy_route_with_router_model(
+        self, mock_llm_router, mock_is_router, mock_get_body
+    ):
+        mock_request = MagicMock(spec=Request)
+        mock_request.method = "POST"
+        mock_request.headers = {"content-type": "application/json"}
+        mock_request.query_params = {}
+        mock_fastapi_response = MagicMock(spec=Response)
+        mock_user_api_key_dict = MagicMock()
+        mock_llm_router.allm_passthrough_route = AsyncMock(
+            return_value=httpx.Response(200, json={"response": "success"})
+        )
+
+        await vllm_proxy_route(
+            endpoint="/chat/completions",
+            request=mock_request,
+            fastapi_response=mock_fastapi_response,
+            user_api_key_dict=mock_user_api_key_dict,
+        )
+
+        mock_is_router.assert_called_once()
+        mock_llm_router.allm_passthrough_route.assert_awaited_once()
+
+    @pytest.mark.asyncio
+    @patch(
+        "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.get_request_body",
+        return_value={"model": "other-model"},
+    )
+    @patch(
+        "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.is_passthrough_request_using_router_model",
+        return_value=False,
+    )
+    @patch(
+        "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.llm_passthrough_factory_proxy_route"
+    )
+    async def test_vllm_proxy_route_fallback_to_factory(
+        self, mock_factory_route, mock_is_router, mock_get_body
+    ):
+        mock_request = MagicMock(spec=Request)
+        mock_fastapi_response = MagicMock(spec=Response)
+        mock_user_api_key_dict = MagicMock()
+        mock_factory_route.return_value = "factory_success"
+
+        result = await vllm_proxy_route(
+            endpoint="/chat/completions",
+            request=mock_request,
+            fastapi_response=mock_fastapi_response,
+            user_api_key_dict=mock_user_api_key_dict,
+        )
+
+        assert result == "factory_success"
+        mock_factory_route.assert_awaited_once()
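
The new tests should be runnable in isolation with pytest's -k filter, e.g. pytest tests/test_litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py -k "TestLLMPassthroughFactoryProxyRoute or TestVLLMProxyRoute" (path taken from the diff; the filter simply names the two new test classes).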
