
Commit d6be26d

Merge branch 'main' into ISSUE-15105
2 parents 7ef71d4 + 7e56600

8 files changed (+675, −336 lines)


litellm/proxy/common_utils/callback_utils.py

Lines changed: 5 additions & 5 deletions
@@ -289,8 +289,8 @@ def initialize_callbacks_on_proxy(  # noqa: PLR0915
 
 def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]:
     _litellm_params = kwargs.get("litellm_params", None) or {}
-    _metadata = _litellm_params.get(get_metadata_variable_name_from_kwargs(kwargs)) or {}
-    _model_group = _metadata.get("model_group", None)
+    _metadata = _litellm_params.get(get_metadata_variable_name_from_litellm_params(_litellm_params)) or {}
+    _model_group = _metadata.get("model_group", None) or kwargs.get("model", None)
     if _model_group is not None:
         return _model_group
 
@@ -367,8 +367,8 @@ def add_guardrail_to_applied_guardrails_header(
     _metadata["applied_guardrails"] = [guardrail_name]
 
 
-def get_metadata_variable_name_from_kwargs(
-    kwargs: dict
+def get_metadata_variable_name_from_litellm_params(
+    litellm_params: dict
 ) -> Literal["metadata", "litellm_metadata"]:
     """
     Helper to return what the "metadata" field should be called in the request data
@@ -381,4 +381,4 @@ def get_metadata_variable_name_from_kwargs(
     - OpenAI then started using this field for their metadata
     - LiteLLM is now moving to using `litellm_metadata` for our metadata
     """
-    return "litellm_metadata" if "litellm_metadata" in kwargs else "metadata"
+    return "litellm_metadata" if "litellm_metadata" in litellm_params else "metadata"

litellm/proxy/hooks/parallel_request_limiter_v3.py

Lines changed: 3 additions & 2 deletions
@@ -25,6 +25,7 @@
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.types.llms.openai import BaseLiteLLMOpenAIResponseObject
+from fastapi import HTTPException
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
@@ -843,7 +844,7 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
             _get_parent_otel_span_from_kwargs,
         )
         from litellm.proxy.common_utils.callback_utils import (
-            get_metadata_variable_name_from_kwargs,
+            get_metadata_variable_name_from_litellm_params,
             get_model_group_from_litellm_kwargs,
         )
         from litellm.types.caching import RedisPipelineIncrementOperation
@@ -861,7 +862,7 @@ async def async_log_success_event(self, kwargs, response_obj, start_time, end_ti
 
         # Get metadata from kwargs
         litellm_metadata = kwargs["litellm_params"].get(
-            get_metadata_variable_name_from_kwargs(kwargs), {}
+            get_metadata_variable_name_from_litellm_params(kwargs["litellm_params"]), {}
         )
         if litellm_metadata is None:
             return

litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py

Lines changed: 29 additions & 89 deletions
@@ -57,9 +57,7 @@ def create_request_copy(request: Request):
     }
 
 
-def is_passthrough_request_using_router_model(
-    request_body: dict, llm_router: Optional[litellm.Router]
-) -> bool:
+def is_passthrough_request_using_router_model(request_body: dict, llm_router: Optional[litellm.Router]) -> bool:
     """
     Returns True if the model is in the llm_router model names
     """
@@ -95,16 +93,12 @@ async def llm_passthrough_factory_proxy_route(
         model=None,
     )
     if provider_config is None:
-        raise HTTPException(
-            status_code=404, detail=f"Provider {custom_llm_provider} not found"
-        )
+        raise HTTPException(status_code=404, detail=f"Provider {custom_llm_provider} not found")
 
     base_target_url = provider_config.get_api_base()
 
     if base_target_url is None:
-        raise HTTPException(
-            status_code=404, detail=f"Provider {custom_llm_provider} api base not found"
-        )
+        raise HTTPException(status_code=404, detail=f"Provider {custom_llm_provider} api base not found")
 
     encoded_endpoint = httpx.URL(endpoint).path
 
@@ -183,17 +177,11 @@ async def gemini_proxy_route(
     [Docs](https://docs.litellm.ai/docs/pass_through/google_ai_studio)
     """
     ## CHECK FOR LITELLM API KEY IN THE QUERY PARAMS - ?..key=LITELLM_API_KEY
-    google_ai_studio_api_key = request.query_params.get("key") or request.headers.get(
-        "x-goog-api-key"
-    )
+    google_ai_studio_api_key = request.query_params.get("key") or request.headers.get("x-goog-api-key")
 
-    user_api_key_dict = await user_api_key_auth(
-        request=request, api_key=f"Bearer {google_ai_studio_api_key}"
-    )
+    user_api_key_dict = await user_api_key_auth(request=request, api_key=f"Bearer {google_ai_studio_api_key}")
 
-    base_target_url = (
-        os.getenv("GEMINI_API_BASE") or "https://generativelanguage.googleapis.com"
-    )
+    base_target_url = os.getenv("GEMINI_API_BASE") or "https://generativelanguage.googleapis.com"
     encoded_endpoint = httpx.URL(endpoint).path
 
     # Ensure endpoint starts with '/' for proper URL construction
@@ -226,6 +214,7 @@ async def gemini_proxy_route(
     endpoint_func = create_pass_through_route(
         endpoint=endpoint,
         target=str(updated_url),
+        custom_llm_provider="gemini",
     )  # dynamically construct pass-through endpoint based on incoming path
     received_value = await endpoint_func(
         request,
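
As a usage sketch: the docs linked above mount this route under /gemini on the proxy, so a call against a locally running proxy would look roughly like the following (base URL, key, and model are placeholders, and the mount path is an assumption taken from the docs):

import httpx

# The route accepts the LiteLLM key either as a `?key=...` query param or an
# `x-goog-api-key` header, mirroring Google AI Studio's own auth style.
resp = httpx.post(
    "http://localhost:4000/gemini/v1beta/models/gemini-1.5-flash:generateContent",
    headers={"x-goog-api-key": "sk-1234"},  # placeholder proxy key
    json={"contents": [{"parts": [{"text": "Say hello"}]}]},
)
print(resp.status_code, resp.json())

The added custom_llm_provider="gemini" argument explicitly tags the dynamically constructed route with its provider.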
@@ -310,9 +299,7 @@ async def vllm_proxy_route(
     from litellm.proxy.proxy_server import llm_router
 
     request_body = await get_request_body(request)
-    is_router_model = is_passthrough_request_using_router_model(
-        request_body, llm_router
-    )
+    is_router_model = is_passthrough_request_using_router_model(request_body, llm_router)
     is_streaming_request = is_passthrough_request_streaming(request_body)
     if is_router_model and llm_router:
         result = cast(
@@ -327,11 +314,7 @@ async def vllm_proxy_route(
             content=None,
             data=None,
             files=None,
-            json=(
-                request_body
-                if request.headers.get("content-type") == "application/json"
-                else None
-            ),
+            json=(request_body if request.headers.get("content-type") == "application/json" else None),
             params=None,
             headers=None,
             cookies=None,
@@ -509,9 +492,7 @@ async def handle_bedrock_count_tokens(
     # Extract model from request body
     model = request_body.get("model")
     if not model:
-        raise HTTPException(
-            status_code=400, detail={"error": "Model is required in request body"}
-        )
+        raise HTTPException(status_code=400, detail={"error": "Model is required in request body"})
 
     # Get model parameters from router
     litellm_params = {"user_api_key_dict": user_api_key_dict}
@@ -550,9 +531,7 @@ async def handle_bedrock_count_tokens(
         raise
     except Exception as e:
         verbose_proxy_logger.error(f"Error in handle_bedrock_count_tokens: {str(e)}")
-        raise HTTPException(
-            status_code=500, detail={"error": f"CountTokens processing error: {str(e)}"}
-        )
+        raise HTTPException(status_code=500, detail={"error": f"CountTokens processing error: {str(e)}"})
 
 
 async def bedrock_llm_proxy_route(
@@ -604,8 +583,7 @@
         raise HTTPException(
             status_code=400,
             detail={
-                "error": "Model missing from endpoint. Expected format: /model/<Model>/<endpoint>. Got: "
-                + endpoint,
+                "error": "Model missing from endpoint. Expected format: /model/<Model>/<endpoint>. Got: " + endpoint,
             },
         )
 
@@ -669,9 +647,7 @@ async def bedrock_proxy_route(
 
     aws_region_name = litellm.utils.get_secret(secret_name="AWS_REGION_NAME")
     if _is_bedrock_agent_runtime_route(endpoint=endpoint):  # handle bedrock agents
-        base_target_url = (
-            f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com"
-        )
+        base_target_url = f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com"
     else:
         return await bedrock_llm_proxy_route(
             endpoint=endpoint,
@@ -701,9 +677,7 @@
         data = await request.json()
     except Exception as e:
         raise HTTPException(status_code=400, detail={"error": e})
-    _request = AWSRequest(
-        method="POST", url=str(updated_url), data=json.dumps(data), headers=headers
-    )
+    _request = AWSRequest(method="POST", url=str(updated_url), data=json.dumps(data), headers=headers)
    sigv4.add_auth(_request)
     prepped = _request.prepare()
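
The AWSRequest construction collapsed above is standard botocore SigV4 signing. A self-contained sketch of the same flow, assuming credentials resolve from the environment (region, model, and request body are placeholders):

import json

from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.session import Session

session = Session()
sigv4 = SigV4Auth(session.get_credentials(), "bedrock", "us-east-1")
_request = AWSRequest(
    method="POST",
    url="https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-v2/invoke",
    data=json.dumps({"prompt": "Hello"}),
    headers={"Content-Type": "application/json"},
)
sigv4.add_auth(_request)      # mutates _request, adding Authorization / X-Amz-Date headers
prepped = _request.prepare()  # frozen request with the signed headers, ready to send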

@@ -764,14 +738,8 @@ async def assemblyai_proxy_route(
     [Docs](https://api.assemblyai.com)
     """
     # Set base URL based on the route
-    assembly_region = AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
-        url=str(request.url)
-    )
-    base_target_url = (
-        AssemblyAIPassthroughLoggingHandler._get_assembly_base_url_from_region(
-            region=assembly_region
-        )
-    )
+    assembly_region = AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(url=str(request.url))
+    base_target_url = AssemblyAIPassthroughLoggingHandler._get_assembly_base_url_from_region(region=assembly_region)
     encoded_endpoint = httpx.URL(endpoint).path
     # Ensure endpoint starts with '/' for proper URL construction
     if not encoded_endpoint.startswith("/"):
@@ -829,18 +797,14 @@ async def azure_proxy_route(
     """
     base_target_url = get_secret_str(secret_name="AZURE_API_BASE")
     if base_target_url is None:
-        raise Exception(
-            "Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure."
-        )
+        raise Exception("Required 'AZURE_API_BASE' in environment to make pass-through calls to Azure.")
     # Add or update query parameters
     azure_api_key = passthrough_endpoint_router.get_credentials(
         custom_llm_provider=litellm.LlmProviders.AZURE.value,
         region_name=None,
     )
     if azure_api_key is None:
-        raise Exception(
-            "Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure."
-        )
+        raise Exception("Required 'AZURE_API_KEY' in environment to make pass-through calls to Azure.")
 
     return await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler(
         endpoint=endpoint,
@@ -864,9 +828,7 @@ def get_default_base_target_url(vertex_location: Optional[str]) -> str:
 
     @staticmethod
     @abstractmethod
-    def update_base_target_url_with_credential_location(
-        base_target_url: str, vertex_location: Optional[str]
-    ) -> str:
+    def update_base_target_url_with_credential_location(base_target_url: str, vertex_location: Optional[str]) -> str:
         pass
 
 
@@ -876,9 +838,7 @@ def get_default_base_target_url(vertex_location: Optional[str]) -> str:
         return "https://discoveryengine.googleapis.com/"
 
     @staticmethod
-    def update_base_target_url_with_credential_location(
-        base_target_url: str, vertex_location: Optional[str]
-    ) -> str:
+    def update_base_target_url_with_credential_location(base_target_url: str, vertex_location: Optional[str]) -> str:
         return base_target_url
 
 
@@ -888,9 +848,7 @@ def get_default_base_target_url(vertex_location: Optional[str]) -> str:
         return get_vertex_base_url(vertex_location)
 
     @staticmethod
-    def update_base_target_url_with_credential_location(
-        base_target_url: str, vertex_location: Optional[str]
-    ) -> str:
+    def update_base_target_url_with_credential_location(base_target_url: str, vertex_location: Optional[str]) -> str:
         return get_vertex_base_url(vertex_location)
 
 
@@ -956,18 +914,14 @@ async def _base_vertex_proxy_route(
         location=vertex_location,
     )
 
-    base_target_url = get_vertex_pass_through_handler.get_default_base_target_url(
-        vertex_location
-    )
+    base_target_url = get_vertex_pass_through_handler.get_default_base_target_url(vertex_location)
 
     headers_passed_through = False
     # Use headers from the incoming request if no vertex credentials are found
     if vertex_credentials is None or vertex_credentials.vertex_project is None:
         headers = dict(request.headers) or {}
         headers_passed_through = True
-        verbose_proxy_logger.debug(
-            "default_vertex_config not set, incoming request headers %s", headers
-        )
+        verbose_proxy_logger.debug("default_vertex_config not set, incoming request headers %s", headers)
         headers.pop("content-length", None)
         headers.pop("host", None)
     else:
@@ -1133,9 +1087,7 @@ async def openai_proxy_route(
         region_name=None,
     )
     if openai_api_key is None:
-        raise Exception(
-            "Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI."
-        )
+        raise Exception("Required 'OPENAI_API_KEY' in environment to make pass-through calls to OpenAI.")
 
     return await BaseOpenAIPassThroughHandler._base_openai_pass_through_handler(
         endpoint=endpoint,
@@ -1181,9 +1133,7 @@ async def _base_openai_pass_through_handler(
         endpoint_func = create_pass_through_route(
             endpoint=endpoint,
             target=str(updated_url),
-            custom_headers=BaseOpenAIPassThroughHandler._assemble_headers(
-                api_key=api_key, request=request
-            ),
+            custom_headers=BaseOpenAIPassThroughHandler._assemble_headers(api_key=api_key, request=request),
         )  # dynamically construct pass-through endpoint based on incoming path
         received_value = await endpoint_func(
             request,
@@ -1200,10 +1150,7 @@ def _append_openai_beta_header(headers: dict, request: Request) -> dict:
         """
         Appends the OpenAI-Beta header to the headers if the request is an OpenAI Assistants API request
        """
-        if (
-            RouteChecks._is_assistants_api_request(request) is True
-            and "OpenAI-Beta" not in headers
-        ):
+        if RouteChecks._is_assistants_api_request(request) is True and "OpenAI-Beta" not in headers:
             headers["OpenAI-Beta"] = "assistants=v2"
         return headers
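
The collapsed condition now reads as a single guard: add the header only for Assistants API requests that don't already carry one. A minimal reproduction, with a hypothetical path check standing in for RouteChecks._is_assistants_api_request:

def append_openai_beta_header(headers: dict, path: str) -> dict:
    # Hypothetical stand-in for RouteChecks._is_assistants_api_request(request).
    is_assistants_request = "/assistants" in path or "/threads" in path
    if is_assistants_request and "OpenAI-Beta" not in headers:
        headers["OpenAI-Beta"] = "assistants=v2"
    return headers


print(append_openai_beta_header({}, "/v1/threads/thread_123/runs"))
# -> {'OpenAI-Beta': 'assistants=v2'}
print(append_openai_beta_header({"OpenAI-Beta": "assistants=v1"}, "/v1/threads"))
# -> {'OpenAI-Beta': 'assistants=v1'}  (a caller-provided header is left alone)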

@@ -1219,9 +1166,7 @@ def _assemble_headers(api_key: str, request: Request) -> dict:
         )
 
     @staticmethod
-    def _join_url_paths(
-        base_url: httpx.URL, path: str, custom_llm_provider: litellm.LlmProviders
-    ) -> str:
+    def _join_url_paths(base_url: httpx.URL, path: str, custom_llm_provider: litellm.LlmProviders) -> str:
         """
         Properly joins a base URL with a path, preserving any existing path in the base URL.
         """
@@ -1237,14 +1182,9 @@ def _join_url_paths(
         joined_path_str = str(base_url.copy_with(path=full_path))
 
         # Apply OpenAI-specific path handling for both branches
-        if (
-            custom_llm_provider == litellm.LlmProviders.OPENAI
-            and "/v1/" not in joined_path_str
-        ):
+        if custom_llm_provider == litellm.LlmProviders.OPENAI and "/v1/" not in joined_path_str:
             # Insert v1 after api.openai.com for OpenAI requests
-            joined_path_str = joined_path_str.replace(
-                "api.openai.com/", "api.openai.com/v1/"
-            )
+            joined_path_str = joined_path_str.replace("api.openai.com/", "api.openai.com/v1/")
 
         return joined_path_str
 
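The v1-insertion rule in the last hunk can be exercised in isolation with httpx (the provider comparison is simplified to a plain string here; the real code compares against litellm.LlmProviders.OPENAI):

import httpx

custom_llm_provider = "openai"  # stand-in for litellm.LlmProviders.OPENAI

base_url = httpx.URL("https://api.openai.com")
joined_path_str = str(base_url.copy_with(path="/chat/completions"))

if custom_llm_provider == "openai" and "/v1/" not in joined_path_str:
    # Insert v1 after api.openai.com, as the diff does for OpenAI requests.
    joined_path_str = joined_path_str.replace("api.openai.com/", "api.openai.com/v1/")

print(joined_path_str)  # https://api.openai.com/v1/chat/completions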