Commit 33a11c4

fix vllm passthrough
1 parent 52a56bd commit 33a11c4

File tree

4 files changed: +180 -16 lines changed

litellm/llms/vllm/common_utils.py

Lines changed: 18 additions & 3 deletions
@@ -11,7 +11,21 @@


 class VLLMError(BaseLLMException):
-    pass
+    def __init__(
+        self,
+        status_code: int,
+        message: str,
+        request: Optional[httpx.Request] = None,
+        response: Optional[httpx.Response] = None,
+        headers: Optional[Union[httpx.Headers, dict]] = None,
+    ):
+        super().__init__(
+            status_code=status_code,
+            message=message,
+            request=request,
+            response=response,
+            headers=headers,
+        )


 class VLLMModelInfo(BaseLLMModelInfo):
@@ -25,7 +39,8 @@ def validate_environment(
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
-        """Google AI Studio sends api key in query params"""
+        if api_key is not None:
+            headers["x-api-key"] = api_key
         return headers

     @staticmethod
@@ -53,7 +68,7 @@ def get_models(
         endpoint = "/v1/models"
         if api_base is None or api_key is None:
             raise ValueError(
-                "GEMINI_API_BASE or GEMINI_API_KEY is not set. Please set the environment variable, to query Gemini's `/models` endpoint."
+                "VLLM_API_BASE or VLLM_API_KEY is not set. Please set the environment variable, to query VLLM's `/models` endpoint."
             )

         url = _add_path_to_api_base(api_base, endpoint)
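
In plain terms, VLLMError now forwards its constructor arguments to BaseLLMException instead of leaving a bare pass, and validate_environment attaches the key as an x-api-key header rather than carrying over the copy-pasted Gemini docstring. A minimal stand-alone sketch of the resulting header behavior; build_vllm_headers and the key value are hypothetical, not part of litellm:

from typing import Optional


def build_vllm_headers(headers: dict, api_key: Optional[str]) -> dict:
    """Hypothetical mirror of the new validate_environment body."""
    if api_key is not None:
        headers["x-api-key"] = api_key
    return headers


# With a key configured, requests to the vLLM server carry it as x-api-key;
# without one, the headers pass through unchanged.
assert build_vllm_headers({}, "sk-local-vllm") == {"x-api-key": "sk-local-vllm"}
assert build_vllm_headers({"Accept": "application/json"}, None) == {"Accept": "application/json"}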

litellm/proxy/common_utils/http_parsing_utils.py

Lines changed: 13 additions & 11 deletions
@@ -233,14 +233,16 @@ async def get_request_body(request: Request) -> Dict[str, Any]:
     """
     Read the request body and parse it as JSON.
     """
-    if request.headers.get("content-type") == "application/json":
-        return await _read_request_body(request)
-    elif (
-        request.headers.get("content-type") == "multipart/form-data"
-        or request.headers.get("content-type") == "application/x-www-form-urlencoded"
-    ):
-        return await get_form_data(request)
-    else:
-        raise ValueError(
-            f"Unsupported content type: {request.headers.get('content-type')}"
-        )
+    if request.method == "POST":
+        if request.headers.get("content-type", "") == "application/json":
+            return await _read_request_body(request)
+        elif (
+            "multipart/form-data" in request.headers.get("content-type", "")
+            or "application/x-www-form-urlencoded" in request.headers.get("content-type", "")
+        ):
+            return await get_form_data(request)
+        else:
+            raise ValueError(
+                f"Unsupported content type: {request.headers.get('content-type')}"
+            )
+    return {}
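
The practical effect: non-POST requests now return an empty dict instead of raising, and form content types are matched by substring, since real multipart requests send "multipart/form-data; boundary=..." and never matched the old equality check. A stand-alone sketch of that dispatch; classify_body_parser is a hypothetical helper for illustration only:

def classify_body_parser(method: str, content_type: str) -> str:
    """Hypothetical mirror of get_request_body's new dispatch logic."""
    if method != "POST":
        return "empty"  # non-POST requests now yield {} instead of raising
    if content_type == "application/json":
        return "json"  # parsed via _read_request_body
    if "multipart/form-data" in content_type or "application/x-www-form-urlencoded" in content_type:
        return "form"  # parsed via get_form_data
    raise ValueError(f"Unsupported content type: {content_type}")


# The boundary parameter is why the old equality check never matched multipart bodies.
assert classify_body_parser("POST", "multipart/form-data; boundary=abc123") == "form"
assert classify_body_parser("GET", "") == "empty"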

litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py

Lines changed: 15 additions & 2 deletions
@@ -108,7 +108,16 @@ async def llm_passthrough_factory_proxy_route(

     # Construct the full target URL using httpx
     base_url = httpx.URL(base_target_url)
-    updated_url = base_url.copy_with(path=encoded_endpoint)
+    # Join paths correctly by removing trailing/leading slashes as needed
+    if not base_url.path or base_url.path == "/":
+        # If base URL has no path, just use the new path
+        updated_url = base_url.copy_with(path=encoded_endpoint)
+    else:
+        # Otherwise, combine the paths
+        base_path = base_url.path.rstrip("/")
+        clean_path = encoded_endpoint.lstrip("/")
+        full_path = f"{base_path}/{clean_path}"
+        updated_url = base_url.copy_with(path=full_path)

     # Add or update query parameters
     provider_api_key = passthrough_endpoint_router.get_credentials(
@@ -130,7 +139,11 @@
     is_streaming_request = False
     # anthropic is streaming when 'stream' = True is in the body
     if request.method == "POST":
-        _request_body = await request.json()
+        if "multipart/form-data" not in request.headers.get("content-type", ""):
+            _request_body = await request.json()
+        else:
+            _request_body = dict(request._form)
+
         if _request_body.get("stream"):
             is_streaming_request = True

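The reason for the path-join branch: httpx.URL.copy_with(path=...) replaces the entire path, so a base URL that already carries a prefix such as /v1 would lose it. A small sketch with a hypothetical vLLM base URL shows the difference (requires httpx):

import httpx

base_url = httpx.URL("http://localhost:8000/v1")  # hypothetical passthrough target
encoded_endpoint = "/chat/completions"

# Old behavior: the whole path is replaced, silently dropping "/v1".
print(base_url.copy_with(path=encoded_endpoint))  # http://localhost:8000/chat/completions

# New behavior: keep the base path and append the endpoint.
if not base_url.path or base_url.path == "/":
    updated_url = base_url.copy_with(path=encoded_endpoint)
else:
    full_path = f"{base_url.path.rstrip('/')}/{encoded_endpoint.lstrip('/')}"
    updated_url = base_url.copy_with(path=full_path)
print(updated_url)  # http://localhost:8000/v1/chat/completions
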
tests/test_litellm/proxy/pass_through_endpoints/test_llm_pass_through_endpoints.py

Lines changed: 134 additions & 0 deletions
@@ -19,6 +19,8 @@
     BaseOpenAIPassThroughHandler,
     RouteChecks,
     create_pass_through_route,
+    llm_passthrough_factory_proxy_route,
+    vllm_proxy_route,
     vertex_discovery_proxy_route,
     vertex_proxy_route,
     bedrock_llm_proxy_route,
@@ -914,3 +916,135 @@ async def test_bedrock_llm_proxy_route_regular_model(self):
         # For regular models, model should be just the model ID
         assert call_kwargs["model"] == "anthropic.claude-3-sonnet-20240229-v1:0"
         assert result == "success"
+
+
+class TestLLMPassthroughFactoryProxyRoute:
+    @pytest.mark.asyncio
+    async def test_llm_passthrough_factory_proxy_route_success(self):
+        mock_request = MagicMock(spec=Request)
+        mock_request.method = "POST"
+        mock_request.json = AsyncMock(return_value={"stream": False})
+        mock_fastapi_response = MagicMock(spec=Response)
+        mock_user_api_key_dict = MagicMock()
+
+        with patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.ProviderConfigManager.get_provider_model_info"
+        ) as mock_get_provider, patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.passthrough_endpoint_router.get_credentials"
+        ) as mock_get_creds, patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
+        ) as mock_create_route:
+            mock_provider_config = MagicMock()
+            mock_provider_config.get_api_base.return_value = "https://example.com/v1"
+            mock_provider_config.validate_environment.return_value = {
+                "Authorization": "Bearer test-key"
+            }
+            mock_get_provider.return_value = mock_provider_config
+            mock_get_creds.return_value = "test-api-key"
+
+            mock_endpoint_func = AsyncMock(return_value="success")
+            mock_create_route.return_value = mock_endpoint_func
+
+            result = await llm_passthrough_factory_proxy_route(
+                custom_llm_provider="custom_provider",
+                endpoint="/chat/completions",
+                request=mock_request,
+                fastapi_response=mock_fastapi_response,
+                user_api_key_dict=mock_user_api_key_dict,
+            )
+
+            assert result == "success"
+            mock_get_provider.assert_called_once_with(
+                provider=litellm.LlmProviders("custom_provider"), model=None
+            )
+            mock_get_creds.assert_called_once_with(
+                custom_llm_provider="custom_provider", region_name=None
+            )
+            mock_create_route.assert_called_once_with(
+                endpoint="/chat/completions",
+                target="https://example.com/v1/chat/completions",
+                custom_headers={"Authorization": "Bearer test-key"},
+            )
+            mock_endpoint_func.assert_awaited_once()
+
+
+class TestVLLMProxyRoute:
+    @pytest.mark.asyncio
+    async def test_vllm_proxy_route_with_router_model(self):
+        mock_request = MagicMock(spec=Request)
+        mock_request.method = "POST"
+        mock_request.headers = {"content-type": "application/json"}
+        mock_request.query_params = {}
+
+        mock_fastapi_response = MagicMock(spec=Response)
+        mock_user_api_key_dict = MagicMock()
+
+        with patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.get_request_body",
+            return_value={"model": "router-model", "stream": False},
+        ), patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.is_passthrough_request_using_router_model",
+            return_value=True,
+        ) as mock_is_router, patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.llm_router"
+        ) as mock_llm_router:
+            mock_llm_router.allm_passthrough_route = AsyncMock()
+            mock_response = httpx.Response(200, json={"response": "success"})
+            mock_llm_router.allm_passthrough_route.return_value = mock_response
+
+            await vllm_proxy_route(
+                endpoint="/chat/completions",
+                request=mock_request,
+                fastapi_response=mock_fastapi_response,
+                user_api_key_dict=mock_user_api_key_dict,
+            )
+
+            mock_is_router.assert_called_once()
+            mock_llm_router.allm_passthrough_route.assert_awaited_once_with(
+                model="router-model",
+                method="POST",
+                endpoint="/chat/completions",
+                request_query_params={},
+                request_headers={"content-type": "application/json"},
+                stream=False,
+                content=None,
+                data=None,
+                files=None,
+                json={"model": "router-model", "stream": False},
+                params=None,
+                headers=None,
+                cookies=None,
+            )
+
+    @pytest.mark.asyncio
+    async def test_vllm_proxy_route_fallback_to_factory(self):
+        mock_request = MagicMock(spec=Request)
+        mock_fastapi_response = MagicMock(spec=Response)
+        mock_user_api_key_dict = MagicMock()
+
+        with patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.get_request_body",
+            return_value={"model": "other-model"},
+        ), patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.is_passthrough_request_using_router_model",
+            return_value=False,
+        ), patch(
+            "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.llm_passthrough_factory_proxy_route"
+        ) as mock_factory_route:
+            mock_factory_route.return_value = "factory_success"
+
+            result = await vllm_proxy_route(
+                endpoint="/chat/completions",
+                request=mock_request,
+                fastapi_response=mock_fastapi_response,
+                user_api_key_dict=mock_user_api_key_dict,
+            )
+
+            assert result == "factory_success"
+            mock_factory_route.assert_awaited_once_with(
+                endpoint="/chat/completions",
+                request=mock_request,
+                fastapi_response=mock_fastapi_response,
+                user_api_key_dict=mock_user_api_key_dict,
+                custom_llm_provider="vllm",
+            )
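
Taken together, the two TestVLLMProxyRoute cases pin down the routing decision: a body whose model is registered on the router goes through llm_router.allm_passthrough_route, anything else falls back to llm_passthrough_factory_proxy_route with custom_llm_provider="vllm". A condensed sketch of that branch; names and return values here are illustrative, not litellm's actual implementation:

import asyncio


async def route_vllm_request(body: dict, router_models: set) -> str:
    """Illustrative reduction of the branch the two tests above exercise."""
    if body.get("model") in router_models:
        # Mirrors is_passthrough_request_using_router_model returning True
        return "router_passthrough"
    # Mirrors the fallback to llm_passthrough_factory_proxy_route(custom_llm_provider="vllm")
    return "factory_passthrough"


assert asyncio.run(route_vllm_request({"model": "router-model"}, {"router-model"})) == "router_passthrough"
assert asyncio.run(route_vllm_request({"model": "other-model"}, {"router-model"})) == "factory_passthrough"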
