diff --git a/litellm/proxy/common_utils/http_parsing_utils.py b/litellm/proxy/common_utils/http_parsing_utils.py index 6b3b06e4af6c..a1f11b2183a3 100644 --- a/litellm/proxy/common_utils/http_parsing_utils.py +++ b/litellm/proxy/common_utils/http_parsing_utils.py @@ -68,7 +68,9 @@ async def _read_request_body(request: Optional[Request]) -> Dict: parsed_body = json.loads(body_str) except json.JSONDecodeError: # If both orjson and json.loads fail, throw a proper error - verbose_proxy_logger.error(f"Invalid JSON payload received: {str(e)}") + verbose_proxy_logger.error( + f"Invalid JSON payload received: {str(e)}" + ) raise ProxyException( message=f"Invalid JSON payload: {str(e)}", type="invalid_request_error", @@ -104,6 +106,7 @@ def _safe_get_request_parsed_body(request: Optional[Request]) -> Optional[dict]: return {key: parsed_body[key] for key in accepted_keys} return None + def _safe_get_request_query_params(request: Optional[Request]) -> Dict: if request is None: return {} @@ -117,6 +120,7 @@ def _safe_get_request_query_params(request: Optional[Request]) -> Dict: ) return {} + def _safe_set_request_parsed_body( request: Optional[Request], parsed_body: dict, @@ -239,9 +243,10 @@ async def get_request_body(request: Request) -> Dict[str, Any]: if request.method == "POST": if request.headers.get("content-type", "") == "application/json": return await _read_request_body(request) - elif ( - "multipart/form-data" in request.headers.get("content-type", "") - or "application/x-www-form-urlencoded" in request.headers.get("content-type", "") + elif "multipart/form-data" in request.headers.get( + "content-type", "" + ) or "application/x-www-form-urlencoded" in request.headers.get( + "content-type", "" ): return await get_form_data(request) else: @@ -254,10 +259,10 @@ async def get_request_body(request: Request) -> Dict[str, Any]: def get_tags_from_request_body(request_body: dict) -> List[str]: """ Extract tags from request body metadata. - + Args: request_body: The request body dictionary - + Returns: List of tag names (strings), empty list if no valid tags found """ @@ -276,4 +281,3 @@ def get_tags_from_request_body(request_body: dict) -> List[str]: combined_tags.extend(tags_in_request_body) ###################################### return [tag for tag in combined_tags if isinstance(tag, str)] - diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index 6f9f04e5cc27..d833ec434131 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -38,6 +38,7 @@ from litellm.secret_managers.main import get_secret_str from .passthrough_endpoint_router import PassthroughEndpointRouter +import asyncio vertex_llm_base = VertexBase() router = APIRouter() @@ -86,6 +87,8 @@ async def llm_passthrough_factory_proxy_route( """ Factory function for creating pass-through endpoints for LLM providers. """ + from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import \ + passthrough_endpoint_router from litellm.types.utils import LlmProviders from litellm.utils import ProviderConfigManager @@ -124,11 +127,38 @@ async def llm_passthrough_factory_proxy_route( full_path = f"{base_path}/{clean_path}" updated_url = base_url.copy_with(path=full_path) - # Add or update query parameters - provider_api_key = passthrough_endpoint_router.get_credentials( - custom_llm_provider=custom_llm_provider, - region_name=None, - ) + # Use asyncio.gather for concurrent I/O if POST (stream-detection and credential lookup can be parallelized) + provider_api_key, is_streaming_request = None, False + if request.method == "POST": + content_type = request.headers.get("content-type", "") + gather_tasks = [] + + # Both tasks below are I/O bound, can run concurrently + gather_tasks.append( + passthrough_endpoint_router.get_credentials( + custom_llm_provider=custom_llm_provider, + region_name=None, + ) + ) + + if "multipart/form-data" not in content_type: + gather_tasks.append(request.json()) + else: + gather_tasks.append(get_form_data(request)) + + # Run the credential and request-body lookup in parallel + results = await asyncio.gather(*gather_tasks) + provider_api_key = results[0] + _request_body = results[1] + + if _request_body.get("stream"): + is_streaming_request = True + else: + # For non-POST, credentials are still required (serial since only one call) + provider_api_key = passthrough_endpoint_router.get_credentials( + custom_llm_provider=custom_llm_provider, + region_name=None, + ) auth_headers = provider_config.validate_environment( headers={}, @@ -140,18 +170,6 @@ async def llm_passthrough_factory_proxy_route( api_base=base_target_url, ) - ## check for streaming - is_streaming_request = False - # anthropic is streaming when 'stream' = True is in the body - if request.method == "POST": - if "multipart/form-data" not in request.headers.get("content-type", ""): - _request_body = await request.json() - else: - _request_body = await get_form_data(request) - - if _request_body.get("stream"): - is_streaming_request = True - ## CREATE PASS-THROUGH endpoint_func = create_pass_through_route( endpoint=endpoint,