diff --git a/docs/my-website/docs/providers/minimax.md b/docs/my-website/docs/providers/minimax.md index 9505c26aade..d2c6565c129 100644 --- a/docs/my-website/docs/providers/minimax.md +++ b/docs/my-website/docs/providers/minimax.md @@ -11,12 +11,16 @@ Litellm provides anthropic specs compatible support for minmax ## Supported Models -MiniMax offers three models through their Anthropic-compatible API: +MiniMax offers the following models through their Anthropic-compatible API: | Model | Description | Input Cost | Output Cost | Prompt Caching Read | Prompt Caching Write | |-------|-------------|------------|-------------|---------------------|----------------------| | **MiniMax-M2.1** | Powerful Multi-Language Programming with Enhanced Programming Experience (~60 tps) | $0.3/M tokens | $1.2/M tokens | $0.03/M tokens | $0.375/M tokens | -| **MiniMax-M2.1-lightning** | Faster and More Agile (~100 tps) | $0.3/M tokens | $2.4/M tokens | $0.03/M tokens | $0.375/M tokens | +| **MiniMax-M2.1-lightning** | Deprecated model name. Use `MiniMax-M2.1-highspeed` for new integrations. | $0.3/M tokens | $2.4/M tokens | $0.03/M tokens | $0.375/M tokens | +| **MiniMax-M2.1-highspeed** | High-speed variant of MiniMax M2.1 | $0.6/M tokens | $2.4/M tokens | $0.03/M tokens | $0.375/M tokens | +| **MiniMax-M2.5** | MiniMax M2.5 general-purpose model | $0.3/M tokens | $1.2/M tokens | $0.03/M tokens | $0.375/M tokens | +| **MiniMax-M2.5-lightning** | Deprecated model name. Use `MiniMax-M2.5-highspeed` for new integrations. | $0.3/M tokens | $2.4/M tokens | $0.03/M tokens | $0.375/M tokens | +| **MiniMax-M2.5-highspeed** | High-speed variant of MiniMax M2.5 | $0.6/M tokens | $2.4/M tokens | $0.03/M tokens | $0.375/M tokens | | **MiniMax-M2** | Agentic capabilities, Advanced reasoning | $0.3/M tokens | $1.2/M tokens | $0.03/M tokens | $0.375/M tokens | diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 16752b030b9..bea789af615 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -218,6 +218,10 @@ router_settings: | enforce_user_param | boolean | If true, requires all OpenAI endpoint requests to have a 'user' param. [Doc on call hooks](call_hooks)| | reject_clientside_metadata_tags | boolean | If true, rejects requests that contain client-side 'metadata.tags' to prevent users from influencing budgets by sending different tags. Tags can only be inherited from the API key metadata. | | allowed_routes | array of strings | List of allowed proxy API routes a user can access [Doc on controlling allowed routes](enterprise#control-available-public-private-routes)| +| cors_allow_origins | Union[str, List[str]] | CORS allowlist origins for the proxy. Defaults to `["*"]` when unset. Set this to `[]` to disable CORS for all origins, or provide explicit origins to restrict access. Existing `LITELLM_CORS_*` env vars take precedence over config values. Restart the proxy after changing any CORS setting. | +| cors_allow_credentials | boolean | Allow CORS credentials. Defaults to `false` when `cors_allow_origins` is explicitly configured and this setting is unset. Otherwise it preserves the proxy's existing default behavior. Wildcard origins or patterns disable credentials. | +| cors_allow_methods | Union[str, List[str]] | CORS allowlist methods for the proxy. Defaults to `"*"` when unset. | +| cors_allow_headers | Union[str, List[str]] | CORS allowlist headers for the proxy. Defaults to `"*"` when unset. | | key_management_system | string | Specifies the key management system. [Doc Secret Managers](../secret) | | master_key | string | The master key for the proxy [Set up Virtual Keys](virtual_keys) | | database_url | string | The URL for the database connection [Set up Virtual Keys](virtual_keys) | diff --git a/litellm/integrations/anthropic_cache_control_hook.py b/litellm/integrations/anthropic_cache_control_hook.py index 0e99537d5db..1fad259d79b 100644 --- a/litellm/integrations/anthropic_cache_control_hook.py +++ b/litellm/integrations/anthropic_cache_control_hook.py @@ -98,8 +98,31 @@ def _process_message_injection( targetted_role = point.get("role", None) - # Case 1: Target by specific index - if targetted_index is not None: + # Case 1: Target by role + index (e.g., index=-1 among assistant messages) + if targetted_index is not None and targetted_role is not None: + role_indices = [ + i + for i, msg in enumerate(messages) + if msg.get("role") == targetted_role + ] + if role_indices: + try: + # Negative indices handled by Python's native list indexing (e.g., -1 = last) + actual_idx = role_indices[targetted_index] + except IndexError: + verbose_logger.warning( + f"AnthropicCacheControlHook: Index {targetted_index} is out of bounds " + f"for {len(role_indices)} messages with role '{targetted_role}'. " + f"Skipping cache control injection for this point." + ) + else: + messages[actual_idx] = ( + AnthropicCacheControlHook._safe_insert_cache_control_in_message( + messages[actual_idx], control + ) + ) + # Case 2: Target by index only + elif targetted_index is not None: original_index = targetted_index # Handle negative indices (convert to positive) if targetted_index < 0: @@ -116,7 +139,7 @@ def _process_message_injection( f"AnthropicCacheControlHook: Provided index {original_index} is out of bounds for message list of length {len(messages)}. " f"Targeted index was {targetted_index}. Skipping cache control injection for this point." ) - # Case 2: Target by role + # Case 3: Target by role only elif targetted_role is not None: for msg in messages: if msg.get("role") == targetted_role: diff --git a/litellm/integrations/arize/_utils.py b/litellm/integrations/arize/_utils.py index 8dfaa8b1425..d9af71ee80f 100644 --- a/litellm/integrations/arize/_utils.py +++ b/litellm/integrations/arize/_utils.py @@ -236,13 +236,28 @@ def _set_usage_outputs(span: "Span", response_obj, span_attrs): prompt_tokens = usage.get("prompt_tokens") or usage.get("input_tokens") if prompt_tokens: safe_set_attribute(span, span_attrs.LLM_TOKEN_COUNT_PROMPT, prompt_tokens) - reasoning_tokens = usage.get("output_tokens_details", {}).get("reasoning_tokens") - if reasoning_tokens: - safe_set_attribute( - span, - span_attrs.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING, - reasoning_tokens, - ) + completion_tokens_details = usage.get("completion_tokens_details") or usage.get( + "output_tokens_details" + ) + if completion_tokens_details is not None: + reasoning_tokens = getattr(completion_tokens_details, "reasoning_tokens", None) + if reasoning_tokens: + safe_set_attribute( + span, + span_attrs.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING, + reasoning_tokens, + ) + prompt_tokens_details = usage.get("prompt_tokens_details") or usage.get( + "input_tokens_details" + ) + if prompt_tokens_details is not None: + cached_tokens = getattr(prompt_tokens_details, "cached_tokens", None) + if cached_tokens: + safe_set_attribute( + span, + span_attrs.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ, + cached_tokens, + ) def _infer_open_inference_span_kind(call_type: Optional[str]) -> str: diff --git a/litellm/litellm_core_utils/prompt_templates/factory.py b/litellm/litellm_core_utils/prompt_templates/factory.py index f6004616712..7c9255963b3 100644 --- a/litellm/litellm_core_utils/prompt_templates/factory.py +++ b/litellm/litellm_core_utils/prompt_templates/factory.py @@ -86,6 +86,17 @@ def prompt_injection_detection_default_pt(): ) # similar to autogen. Only used if `litellm.modify_params=True`. +def _get_content_as_str(content: Union[str, list, None]) -> str: + """Extract text from content that may be a string, a list of content blocks, or None.""" + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + return convert_content_list_to_str({"role": "user", "content": content}) + return "" + + def map_system_message_pt(messages: list) -> list: """ Convert 'system' message to 'user' message if provider doesn't support 'system' role. @@ -100,6 +111,7 @@ def map_system_message_pt(messages: list) -> list: new_messages = [] for i, m in enumerate(messages): if m["role"] == "system": + system_text = _get_content_as_str(m["content"]) if i < len(messages) - 1: # Not the last message next_m = messages[i + 1] next_role = next_m["role"] @@ -107,13 +119,16 @@ def map_system_message_pt(messages: list) -> list: next_role == "user" or next_role == "assistant" ): # Next message is a user or assistant message # Merge system prompt into the next message - next_m["content"] = m["content"] + " " + next_m["content"] + # Copy to avoid mutating the caller's original dict + next_m = messages[i + 1] = {**next_m} + next_text = _get_content_as_str(next_m["content"]) + next_m["content"] = " ".join(filter(None, [system_text, next_text])) elif next_role == "system": # Next message is a system message # Append a user message instead of the system message - new_message = {"role": "user", "content": m["content"]} + new_message = {"role": "user", "content": system_text} new_messages.append(new_message) else: # Last message - new_message = {"role": "user", "content": m["content"]} + new_message = {"role": "user", "content": system_text} new_messages.append(new_message) else: # Not a system message new_messages.append(m) @@ -1393,10 +1408,10 @@ def convert_to_gemini_tool_call_invoke( if tool_calls is not None: for idx, tool in enumerate(tool_calls): if "function" in tool: - gemini_function_call: Optional[ - VertexFunctionCall - ] = _gemini_tool_call_invoke_helper( - function_call_params=tool["function"] + gemini_function_call: Optional[VertexFunctionCall] = ( + _gemini_tool_call_invoke_helper( + function_call_params=tool["function"] + ) ) if gemini_function_call is not None: part_dict: VertexPartType = { @@ -1540,9 +1555,7 @@ def convert_to_gemini_tool_call_result( # noqa: PLR0915 file_data = ( file_content.get("file_data", "") if isinstance(file_content, dict) - else file_content - if isinstance(file_content, str) - else "" + else file_content if isinstance(file_content, str) else "" ) if file_data: @@ -2046,9 +2059,9 @@ def _sanitize_empty_text_content( if isinstance(content, str): if not content or not content.strip(): message = cast(AllMessageValues, dict(message)) # Make a copy - message[ - "content" - ] = "[System: Empty message content sanitised to satisfy protocol]" + message["content"] = ( + "[System: Empty message content sanitised to satisfy protocol]" + ) verbose_logger.debug( f"_sanitize_empty_text_content: Replaced empty text content in {message.get('role')} message" ) @@ -2388,9 +2401,9 @@ def anthropic_messages_pt( # noqa: PLR0915 # Convert ChatCompletionImageUrlObject to dict if needed image_url_value = m["image_url"] if isinstance(image_url_value, str): - image_url_input: Union[ - str, dict[str, Any] - ] = image_url_value + image_url_input: Union[str, dict[str, Any]] = ( + image_url_value + ) else: # ChatCompletionImageUrlObject or dict case - convert to dict image_url_input = { @@ -2417,9 +2430,9 @@ def anthropic_messages_pt( # noqa: PLR0915 ) if "cache_control" in _content_element: - _anthropic_content_element[ - "cache_control" - ] = _content_element["cache_control"] + _anthropic_content_element["cache_control"] = ( + _content_element["cache_control"] + ) user_content.append(_anthropic_content_element) elif m.get("type", "") == "text": m = cast(ChatCompletionTextObject, m) @@ -2479,9 +2492,9 @@ def anthropic_messages_pt( # noqa: PLR0915 ) if "cache_control" in _content_element: - _anthropic_content_text_element[ - "cache_control" - ] = _content_element["cache_control"] + _anthropic_content_text_element["cache_control"] = ( + _content_element["cache_control"] + ) user_content.append(_anthropic_content_text_element) @@ -2614,9 +2627,9 @@ def anthropic_messages_pt( # noqa: PLR0915 original_content_element=dict(assistant_content_block), ) if "cache_control" in _content_element: - _anthropic_text_content_element[ - "cache_control" - ] = _content_element["cache_control"] + _anthropic_text_content_element["cache_control"] = ( + _content_element["cache_control"] + ) text_element = _anthropic_text_content_element # Interleave: each thinking block precedes its server tool group. @@ -2776,9 +2789,9 @@ def anthropic_messages_pt( # noqa: PLR0915 ) if "cache_control" in _content_element: - _anthropic_text_content_element[ - "cache_control" - ] = _content_element["cache_control"] + _anthropic_text_content_element["cache_control"] = ( + _content_element["cache_control"] + ) assistant_content.append(_anthropic_text_content_element) @@ -5220,9 +5233,7 @@ def default_response_schema_prompt(response_schema: dict) -> str: prompt_str = """Use this JSON schema: ```json {} - ```""".format( - response_schema - ) + ```""".format(response_schema) return prompt_str diff --git a/litellm/llms/anthropic/chat/handler.py b/litellm/llms/anthropic/chat/handler.py index 9f2ddcae2c7..3d9664aaab7 100644 --- a/litellm/llms/anthropic/chat/handler.py +++ b/litellm/llms/anthropic/chat/handler.py @@ -23,6 +23,7 @@ import litellm.litellm_core_utils import litellm.types import litellm.types.utils +from litellm._logging import verbose_logger from litellm.anthropic_beta_headers_manager import ( update_request_with_filtered_beta, ) @@ -1245,7 +1246,15 @@ def convert_str_chunk_to_generic_chunk(self, chunk: str) -> ModelResponseStream: str_line = str_line[index:] if str_line.startswith("data:"): - data_json = json.loads(str_line[5:]) - return self.chunk_parser(chunk=data_json) - else: - return ModelResponseStream(id=self.response_id) + chunk_str = str_line[5:].strip() + # Models like Deepseek might return "data: [DONE]" here which is not a + # valid JSON input. We can just ignore these chunks. + try: + data_json = json.loads(chunk_str) + return self.chunk_parser(chunk=data_json) + except json.JSONDecodeError: + verbose_logger.debug( + f"Non-JSON SSE chunk received, ignoring: {chunk_str!r}" + ) + + return ModelResponseStream(id=self.response_id) diff --git a/litellm/llms/minimax/chat/transformation.py b/litellm/llms/minimax/chat/transformation.py index 4095e57a8ae..cd19f3bb2e3 100644 --- a/litellm/llms/minimax/chat/transformation.py +++ b/litellm/llms/minimax/chat/transformation.py @@ -16,9 +16,14 @@ class MinimaxChatConfig(OpenAIGPTConfig): - International: https://api.minimax.io/v1 - China: https://api.minimaxi.com/v1 + Note: MiniMax's Claude-compatible `/anthropic/v1/messages` support is implemented + separately in `litellm/llms/minimax/messages/transformation.py`. + Supported models: - MiniMax-M2.1 - - MiniMax-M2.1-lightning + - MiniMax-M2.1-highspeed + - MiniMax-M2.5 + - MiniMax-M2.5-highspeed - MiniMax-M2 """ diff --git a/litellm/llms/minimax/messages/transformation.py b/litellm/llms/minimax/messages/transformation.py index 13ed6ad3863..b97f04adc98 100644 --- a/litellm/llms/minimax/messages/transformation.py +++ b/litellm/llms/minimax/messages/transformation.py @@ -1,5 +1,8 @@ """ -MiniMax Anthropic transformation config - extends AnthropicConfig for MiniMax's Anthropic-compatible API +MiniMax Anthropic-compatible Messages API transformation config. + +MiniMax exposes Claude-compatible `/anthropic/v1/messages` endpoints separately from +its OpenAI-compatible `/v1/chat/completions` endpoint. """ from typing import Optional @@ -19,7 +22,9 @@ class MinimaxMessagesConfig(AnthropicMessagesConfig): Supported models: - MiniMax-M2.1 - - MiniMax-M2.1-lightning + - MiniMax-M2.1-highspeed + - MiniMax-M2.5 + - MiniMax-M2.5-highspeed - MiniMax-M2 """ diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 879dd42be47..262a6423e85 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -32542,6 +32542,20 @@ "supports_vision": true, "supports_web_search": true }, + "zai.glm-5": { + "input_cost_per_token": 1e-06, + "litellm_provider": "bedrock_converse", + "max_input_tokens": 200000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 3.2e-06, + "supports_function_calling": true, + "supports_reasoning": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "source": "https://aws.amazon.com/bedrock/pricing/" + }, "zai.glm-4.7": { "input_cost_per_token": 6e-07, "litellm_provider": "bedrock_converse", diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 91a953c217e..30d704d6e4e 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2236,6 +2236,22 @@ class ConfigGeneralSettings(LiteLLMPydanticObjectBase): allowed_routes: Optional[List] = Field( None, description="Proxy API Endpoints you want users to be able to access" ) + cors_allow_origins: Optional[Union[str, List[str]]] = Field( + None, + description='CORS allowlist origins for the proxy. Defaults to `["*"]` when unset. Set this to `[]` to disable CORS for all origins, or provide explicit origins to restrict access. Existing `LITELLM_CORS_*` env vars take precedence over config values. Restart the proxy after changing any CORS setting.', + ) + cors_allow_credentials: Optional[bool] = Field( + None, + description="Allow CORS credentials. Defaults to False when cors_allow_origins is explicitly configured and this setting is unset. Otherwise it preserves the proxy's existing default behavior. Wildcard origins or patterns disable credentials.", + ) + cors_allow_methods: Optional[Union[str, List[str]]] = Field( + None, + description='CORS allowlist methods for the proxy. Defaults to `"*"` when unset.', + ) + cors_allow_headers: Optional[Union[str, List[str]]] = Field( + None, + description='CORS allowlist headers for the proxy. Defaults to `"*"` when unset.', + ) reject_clientside_metadata_tags: Optional[bool] = Field( None, description="When set to True, rejects requests that contain client-side 'metadata.tags' to prevent users from influencing budgets by sending different tags. Tags can only be inherited from the API key metadata.", diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py index 9ecae363ed7..2a66ca72f2d 100644 --- a/litellm/proxy/common_utils/callback_utils.py +++ b/litellm/proxy/common_utils/callback_utils.py @@ -46,6 +46,33 @@ def initialize_callbacks_on_proxy( # noqa: PLR0915 isinstance(callback, str) and callback in litellm._known_custom_logger_compatible_callbacks ): + # Eagerly instantiate the callback class (e.g. OpenTelemetry) + # so that proxy-level globals like open_telemetry_logger are + # set at startup. Without this, store_model_in_db=true causes + # the async model-loading path to skip callback instantiation + # entirely, leaving the logger as None. + # + # Each event is registered independently so a failure in one + # does not leave the callback half-registered. + from litellm.utils import ( + _add_custom_logger_callback_to_specific_event, + ) + + for event in ("success", "failure"): + try: + _add_custom_logger_callback_to_specific_event(callback, event) + except Exception as e: + verbose_proxy_logger.error( + f"Failed to initialize callback '{callback}' " + f"for {event} event at startup: {e}. " + "Check that the required environment variables " + "are set." + ) + # Always add to imported_list so litellm.callbacks stays in + # sync — it is read by health endpoints, hot-reload, and + # spend tracking. On success the instance is already in the + # success/failure lists; the string here keeps the canonical + # config list complete. imported_list.append(callback) elif isinstance(callback, str) and callback == "presidio": from litellm.proxy.guardrails.guardrail_hooks.presidio import ( diff --git a/litellm/proxy/common_utils/key_rotation_manager.py b/litellm/proxy/common_utils/key_rotation_manager.py index 5a0a1fabc7d..3ee66bff5d5 100644 --- a/litellm/proxy/common_utils/key_rotation_manager.py +++ b/litellm/proxy/common_utils/key_rotation_manager.py @@ -5,7 +5,10 @@ """ from datetime import datetime, timezone -from typing import List +from typing import List, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager from litellm._logging import verbose_proxy_logger from litellm.constants import ( @@ -30,14 +33,39 @@ class KeyRotationManager: Manages automated key rotation based on individual key rotation schedules. """ - def __init__(self, prisma_client: PrismaClient): + KEY_ROTATION_JOB_NAME = "key_rotation_job" + + def __init__( + self, + prisma_client: PrismaClient, + pod_lock_manager: Optional["PodLockManager"] = None, + ): self.prisma_client = prisma_client + self.pod_lock_manager = pod_lock_manager async def process_rotations(self): """ Main entry point - find and rotate keys that are due for rotation """ + # Acquire distributed lock to prevent concurrent rotation across pods. + # Lock TTL is DEFAULT_CRON_JOB_LOCK_TTL_SECONDS (default 60s, configurable + # via env var). For large key sets, increase this to avoid lock expiry + # mid-rotation. + lock_acquired = False try: + if self.pod_lock_manager and self.pod_lock_manager.redis_cache: + lock_acquired = ( + await self.pod_lock_manager.acquire_lock( + cronjob_id=self.KEY_ROTATION_JOB_NAME, + ) + or False + ) + if not lock_acquired: + verbose_proxy_logger.debug( + "Key rotation skipped — another pod holds the lock" + ) + return + verbose_proxy_logger.info("Starting scheduled key rotation check...") # Clean up expired deprecated keys first @@ -74,6 +102,15 @@ async def process_rotations(self): except Exception as e: verbose_proxy_logger.error(f"Key rotation process failed: {e}") + finally: + if ( + lock_acquired + and self.pod_lock_manager + and self.pod_lock_manager.redis_cache + ): + await self.pod_lock_manager.release_lock( + cronjob_id=self.KEY_ROTATION_JOB_NAME, + ) async def _find_keys_needing_rotation(self) -> List[LiteLLM_VerificationToken]: """ diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index c638e294268..186d467d19f 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -10,6 +10,7 @@ import click import httpx +import yaml from dotenv import load_dotenv import litellm @@ -58,6 +59,200 @@ def append_query_params(url: Optional[str], params: dict) -> str: return modified_url # type: ignore +def _normalize_cors_value(value: Any, setting_name: str) -> Optional[str]: + """ + Normalize a CORS config value to a comma-separated string. + + Accepts either a string or a list of strings. Returns None if the + incoming value is None. Raises ValueError for any other type or for + lists containing non-string elements. + """ + if value is None: + return None + if isinstance(value, str): + return ",".join([item.strip() for item in value.split(",") if item.strip()]) + if isinstance(value, list): + if not all(isinstance(item, str) for item in value): + raise ValueError( + f"Invalid CORS setting for '{setting_name}': expected a list of strings." + ) + return ",".join([item.strip() for item in value if item.strip()]) + raise ValueError( + f"Invalid CORS setting for '{setting_name}': expected a string or list of strings, " + f"got {type(value).__name__}." + ) + + +def _set_env_var_if_unset(env_key: str, value: Optional[str]) -> None: + # Respect explicit operator-provided env vars; config.yaml only fills gaps. + if value is not None and os.getenv(env_key) is None: + os.environ[env_key] = value + + +def _apply_cors_settings_from_general_settings(general_settings: dict) -> None: + cors_allow_origins = general_settings.get("cors_allow_origins", None) + cors_allow_credentials = general_settings.get("cors_allow_credentials", None) + cors_allow_methods = general_settings.get("cors_allow_methods", None) + cors_allow_headers = general_settings.get("cors_allow_headers", None) + + try: + normalized_origins = _normalize_cors_value( + cors_allow_origins, "cors_allow_origins" + ) + normalized_credentials = None + if cors_allow_credentials is not None: + if not isinstance(cors_allow_credentials, bool): + raise ValueError( + "Invalid CORS setting for 'cors_allow_credentials': expected " + f"a boolean, got {type(cors_allow_credentials).__name__}." + ) + normalized_credentials = str(cors_allow_credentials).lower() + normalized_methods = _normalize_cors_value( + cors_allow_methods, "cors_allow_methods" + ) + normalized_headers = _normalize_cors_value( + cors_allow_headers, "cors_allow_headers" + ) + except ValueError as e: + raise click.ClickException(f"Invalid CORS configuration: {e}") from e + + _set_env_var_if_unset("LITELLM_CORS_ALLOW_ORIGINS", normalized_origins) + if normalized_credentials is not None: + _set_env_var_if_unset( + "LITELLM_CORS_ALLOW_CREDENTIALS", + normalized_credentials, + ) + _set_env_var_if_unset("LITELLM_CORS_ALLOW_METHODS", normalized_methods) + _set_env_var_if_unset("LITELLM_CORS_ALLOW_HEADERS", normalized_headers) + + +def _process_config_includes(config: dict, base_dir: str) -> dict: + if "include" not in config: + return config + if not isinstance(config["include"], list): + raise click.ClickException("Invalid config file: 'include' must be a list.") + + for include_file in config["include"]: + file_path = os.path.join(base_dir, include_file) + if not os.path.exists(file_path): + raise click.ClickException(f"Included config file not found: {file_path}") + + with open(file_path, "r", encoding="utf-8") as file: + included_config = yaml.safe_load(file) or {} + + if not isinstance(included_config, dict): + raise click.ClickException( + f"Invalid included config file {file_path}: expected a top-level mapping." + ) + + for key, value in included_config.items(): + # Mirror ProxyConfig include behavior while avoiding list.extend() on + # non-list existing values from the main config. + if isinstance(value, list) and isinstance(config.get(key), list): + config[key].extend(value) + else: + config[key] = value + + del config["include"] + return config + + +def _resolve_os_environ_refs(config: dict) -> dict: + for key, value in config.items(): + if isinstance(value, dict): + config[key] = _resolve_os_environ_refs(value) + elif isinstance(value, list): + resolved_list = [] + for item in value: + if isinstance(item, dict): + resolved_list.append(_resolve_os_environ_refs(item)) + elif isinstance(item, str) and item.startswith("os.environ/"): + from litellm import get_secret_str + + resolved_list.append(get_secret_str(item, default_value=None)) + else: + resolved_list.append(item) + config[key] = resolved_list + elif isinstance(value, str) and value.startswith("os.environ/"): + from litellm import get_secret_str + + config[key] = get_secret_str(value, default_value=None) + return config + + +def _load_general_settings_for_early_cors( + config_file_path: Optional[str], +) -> dict: + """ + Load only the CORS-relevant general_settings before importing proxy_server. + + CORSMiddleware is configured from module-level values at import time, so the + startup path needs these settings in env vars before the first proxy_server + import. This early pass supports literal values, os.environ/ references, + and local file-path include files. Secret-manager-backed values and include + directives inside S3/GCS configs are only resolved later in startup and + therefore cannot affect the initial middleware config. + """ + config: Optional[dict] = None + + if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None: + import asyncio + + from litellm.proxy.common_utils.load_config_utils import ( + get_config_file_contents_from_gcs, + get_file_contents_from_s3, + ) + + bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME") + object_key = os.environ.get("LITELLM_CONFIG_BUCKET_OBJECT_KEY") + bucket_type = os.environ.get("LITELLM_CONFIG_BUCKET_TYPE") + + if bucket_type == "gcs": + config = asyncio.run( + get_config_file_contents_from_gcs( + bucket_name=bucket_name, object_key=object_key + ) + ) + else: + config = get_file_contents_from_s3( + bucket_name=bucket_name, object_key=object_key + ) + if config is None: + raise click.ClickException("Unable to load config from given source.") + elif config_file_path is not None: + if not os.path.exists(config_file_path): + raise click.ClickException(f"Config file not found: {config_file_path}") + + with open(config_file_path, "r", encoding="utf-8") as config_file: + config = yaml.safe_load(config_file) + + if config is None: + raise click.ClickException("Config cannot be None or empty.") + + if not isinstance(config, dict): + raise click.ClickException( + "Invalid config file: expected a top-level mapping." + ) + + config = _process_config_includes( + config=config, + base_dir=os.path.dirname(os.path.abspath(config_file_path)), + ) + else: + return {} + + if not isinstance(config, dict): + raise click.ClickException("Invalid config file: expected a top-level mapping.") + + general_settings = config.get("general_settings", {}) or {} + if not isinstance(general_settings, dict): + raise click.ClickException( + "Invalid config file: 'general_settings' must be a mapping." + ) + + return _resolve_os_environ_refs(general_settings) + + class ProxyInitializationHelpers: @staticmethod def _echo_litellm_version(): @@ -625,6 +820,12 @@ def run_server( # noqa: PLR0915 run_setup_wizard() return + if config is not None or os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None: + # CORS is captured when proxy_server is first imported, so load the + # relevant general_settings before importing the app module. + early_general_settings = _load_general_settings_for_early_cors(config) + _apply_cors_settings_from_general_settings(early_general_settings) + args = locals() if local: from proxy_server import ( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index e982c934aa6..c3abc4fae59 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -1130,7 +1130,96 @@ async def openai_exception_handler(request: Request, exc: ProxyException): router = APIRouter() -origins = ["*"] + + +def _get_cors_allow_list(env_key: str) -> Optional[List[str]]: + raw_value = os.getenv(env_key) + if raw_value is None: + return None + if raw_value.strip() == "": + return [] + return [value.strip() for value in raw_value.split(",") if value.strip()] + + +def _get_cors_allow_credentials(origins_were_configured: bool) -> bool: + raw_value = os.getenv("LITELLM_CORS_ALLOW_CREDENTIALS") + if raw_value is None: + return not origins_were_configured + if raw_value.strip() == "": + return False + return raw_value.strip().lower() in {"1", "true", "yes"} + + +def _is_wildcard_cors_origin(origin: str) -> bool: + normalized_origin = origin.strip() + return ( + normalized_origin == "*" + or normalized_origin.startswith("https://*") + or normalized_origin.startswith("http://*") + ) + + +# CORSMiddleware is constructed once at startup, so changing these settings +# requires a full proxy restart to take effect. +configured_cors_allow_origins = _get_cors_allow_list("LITELLM_CORS_ALLOW_ORIGINS") +configured_cors_allow_methods = _get_cors_allow_list("LITELLM_CORS_ALLOW_METHODS") +configured_cors_allow_headers = _get_cors_allow_list("LITELLM_CORS_ALLOW_HEADERS") +cors_credentials_was_configured = ( + os.getenv("LITELLM_CORS_ALLOW_CREDENTIALS", "").strip() != "" +) +cors_origins_were_configured = configured_cors_allow_origins is not None + +cors_allow_origins = ( + ["*"] if configured_cors_allow_origins is None else configured_cors_allow_origins +) +if cors_origins_were_configured and len(cors_allow_origins) == 0: + verbose_proxy_logger.warning( + "CORS config: cors_allow_origins resolved to an empty list. " + "All cross-origin requests will be rejected. " + "Set cors_allow_origins to explicit origins or remove the setting to " + "restore the default wildcard behavior." + ) +cors_allow_credentials = _get_cors_allow_credentials( + origins_were_configured=cors_origins_were_configured +) +cors_allow_methods = ( + ["*"] if configured_cors_allow_methods is None else configured_cors_allow_methods +) +if configured_cors_allow_methods is not None and len(cors_allow_methods) == 0: + verbose_proxy_logger.warning( + "CORS config: cors_allow_methods resolved to an empty list. " + "All CORS preflight requests will be rejected. " + "Set cors_allow_methods to explicit methods or remove the setting to " + "restore the default wildcard behavior." + ) +cors_allow_headers = ( + ["*"] if configured_cors_allow_headers is None else configured_cors_allow_headers +) +if configured_cors_allow_headers is not None and len(cors_allow_headers) == 0: + verbose_proxy_logger.warning( + "CORS config: cors_allow_headers resolved to an empty list. " + "All CORS preflight requests will be rejected. " + "Set cors_allow_headers to explicit headers or remove the setting to " + "restore the default wildcard behavior." + ) + +# Preserve the proxy's existing wildcard+credentials default only when CORS +# origins are completely unconfigured. Setting credentials without origins +# still triggers this guard because origins fall back to ["*"]. +has_wildcard_origin = any( + _is_wildcard_cors_origin(origin) for origin in cors_allow_origins +) +should_validate_cors_credentials = ( + cors_origins_were_configured or cors_credentials_was_configured +) +if should_validate_cors_credentials and has_wildcard_origin and cors_allow_credentials: + verbose_proxy_logger.warning( + "CORS config rejects allow_credentials with wildcard origins or patterns " + "(including subdomain wildcards such as 'https://*.example.com'). " + "Set general_settings.cors_allow_origins to fully-qualified explicit origins " + "to enable credentials." + ) + cors_allow_credentials = False # get current directory @@ -1456,10 +1545,10 @@ def _restructure_ui_html_files(ui_root: str) -> None: app.add_middleware( CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], + allow_origins=cors_allow_origins, + allow_credentials=cors_allow_credentials, + allow_methods=cors_allow_methods, + allow_headers=cors_allow_headers, expose_headers=LITELLM_UI_ALLOW_HEADERS, ) @@ -1524,6 +1613,9 @@ async def root_redirect(): config_agents: Optional[List[AgentConfig]] = None otel_logging = False prisma_client: Optional[PrismaClient] = None +key_rotation_pod_lock_manager: Optional[ + Any +] = None # PodLockManager for key rotation distributed lock shared_aiohttp_session: Optional[ "ClientSession" ] = None # Global shared session for connection reuse @@ -2386,7 +2478,7 @@ def _process_includes(self, config: dict, base_dir: str) -> dict: included_config = self._load_yaml_file(file_path) # Simply update/extend the main config with included config for key, value in included_config.items(): - if isinstance(value, list) and key in config: + if isinstance(value, list) and isinstance(config.get(key), list): config[key].extend(value) else: config[key] = value @@ -2464,11 +2556,19 @@ def _check_for_os_environ_vars( config=value, depth=depth + 1, max_depth=max_depth ) elif isinstance(value, list): + resolved_list = [] for item in value: if isinstance(item, dict): - item = self._check_for_os_environ_vars( - config=item, depth=depth + 1, max_depth=max_depth + resolved_list.append( + self._check_for_os_environ_vars( + config=item, depth=depth + 1, max_depth=max_depth + ) ) + elif isinstance(item, str) and item.startswith("os.environ/"): + resolved_list.append(get_secret(item)) + else: + resolved_list.append(item) + config[key] = resolved_list # if the value is a string and starts with "os.environ/" - then it's an environment variable elif isinstance(value, str) and value.startswith("os.environ/"): config[key] = get_secret(value) @@ -6260,9 +6360,22 @@ async def _initialize_spend_tracking_background_jobs( ) # Get prisma_client from global scope - global prisma_client + global prisma_client, key_rotation_pod_lock_manager if prisma_client is not None: - key_rotation_manager = KeyRotationManager(prisma_client) + from litellm.proxy.db.db_transaction_queue.pod_lock_manager import ( + PodLockManager, + ) + + key_rotation_pod_lock_manager = PodLockManager( + redis_cache=litellm.cache.cache + if litellm.cache is not None + and isinstance(litellm.cache.cache, RedisCache) + else None + ) + key_rotation_manager = KeyRotationManager( + prisma_client, + pod_lock_manager=key_rotation_pod_lock_manager, + ) verbose_proxy_logger.debug( f"Key rotation background job scheduled every {LITELLM_KEY_ROTATION_CHECK_INTERVAL_SECONDS} seconds (LITELLM_KEY_ROTATION_ENABLED=true)" ) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 7954f0b6460..c3a0dd61c39 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -445,6 +445,10 @@ def update_values( self.internal_usage_cache.dual_cache.redis_cache = redis_cache self.db_spend_update_writer.redis_update_buffer.redis_cache = redis_cache self.db_spend_update_writer.pod_lock_manager.redis_cache = redis_cache + from litellm.proxy.proxy_server import key_rotation_pod_lock_manager + + if key_rotation_pod_lock_manager is not None: + key_rotation_pod_lock_manager.redis_cache = redis_cache def _add_proxy_hooks(self, llm_router: Optional[Router] = None): """ diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index bbf9f6d9dc8..faa72d67259 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -21218,6 +21218,21 @@ "max_input_tokens": 1000000, "max_output_tokens": 8192 }, + "minimax/MiniMax-M2.1-highspeed": { + "input_cost_per_token": 6e-07, + "output_cost_per_token": 2.4e-06, + "cache_read_input_token_cost": 3e-08, + "cache_creation_input_token_cost": 3.75e-07, + "litellm_provider": "minimax", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_system_messages": true, + "max_input_tokens": 1000000, + "max_output_tokens": 8192 + }, "minimax/MiniMax-M2.5": { "input_cost_per_token": 3e-07, "output_cost_per_token": 1.2e-06, @@ -21248,6 +21263,21 @@ "max_input_tokens": 1000000, "max_output_tokens": 8192 }, + "minimax/MiniMax-M2.5-highspeed": { + "input_cost_per_token": 6e-07, + "output_cost_per_token": 2.4e-06, + "cache_read_input_token_cost": 3e-08, + "cache_creation_input_token_cost": 3.75e-07, + "litellm_provider": "minimax", + "mode": "chat", + "supports_function_calling": true, + "supports_tool_choice": true, + "supports_prompt_caching": true, + "supports_reasoning": true, + "supports_system_messages": true, + "max_input_tokens": 1000000, + "max_output_tokens": 8192 + }, "minimax/MiniMax-M2": { "input_cost_per_token": 3e-07, "output_cost_per_token": 1.2e-06, @@ -32546,6 +32576,20 @@ "supports_vision": true, "supports_web_search": true }, + "zai.glm-5": { + "input_cost_per_token": 1e-06, + "litellm_provider": "bedrock_converse", + "max_input_tokens": 200000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "output_cost_per_token": 3.2e-06, + "supports_function_calling": true, + "supports_reasoning": true, + "supports_system_messages": true, + "supports_tool_choice": true, + "source": "https://aws.amazon.com/bedrock/pricing/" + }, "zai.glm-4.7": { "input_cost_per_token": 6e-07, "litellm_provider": "bedrock_converse", diff --git a/tests/litellm/integrations/__init__.py b/tests/litellm/integrations/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/litellm/integrations/test_anthropic_cache_control_hook.py b/tests/litellm/integrations/test_anthropic_cache_control_hook.py new file mode 100644 index 00000000000..7737605c4d2 --- /dev/null +++ b/tests/litellm/integrations/test_anthropic_cache_control_hook.py @@ -0,0 +1,204 @@ +""" +Unit tests for AnthropicCacheControlHook fixes: + 1. Negative string indices (e.g. "-1") are now parsed correctly (isdigit bug fix) + 2. Combined role + index filtering targets the Nth message of a given role +""" + +import pytest +from unittest.mock import patch + +from litellm.integrations.anthropic_cache_control_hook import ( + AnthropicCacheControlHook, +) +from litellm.types.llms.openai import ChatCompletionCachedContent + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_messages(): + """ + Build a realistic multi-turn conversation: + 0: system + 1: user (user #0) + 2: assistant (assistant #0) + 3: user (user #1) + 4: assistant (assistant #1) + 5: user (user #2) + """ + return [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "First user message"}, + {"role": "assistant", "content": "First assistant reply"}, + {"role": "user", "content": "Second user message"}, + {"role": "assistant", "content": "Second assistant reply"}, + {"role": "user", "content": "Third user message"}, + ] + + +def _has_cache_control(msg): + """Check whether a message was annotated with cache_control.""" + # String content: cache_control is a top-level key + if "cache_control" in msg: + return True + # List content: cache_control is on the last content block + content = msg.get("content") + if isinstance(content, list): + return any("cache_control" in item for item in content if isinstance(item, dict)) + return False + + +# --------------------------------------------------------------------------- +# 1. Bug fix: negative *string* indices ("-1") were rejected by isdigit() +# --------------------------------------------------------------------------- + +class TestNegativeStringIndexParsing: + """isdigit() returns False for '-1'; the fix uses try/except int().""" + + def test_negative_string_index_targets_last_message(self): + msgs = _make_messages() + point = {"location": "message", "index": "-1"} # string, not int + result = AnthropicCacheControlHook._process_message_injection( + point=point, messages=msgs + ) + assert _has_cache_control(result[-1]), "Last message should have cache_control" + # No other message should be affected + for m in result[:-1]: + assert not _has_cache_control(m) + + def test_negative_string_index_minus_two(self): + msgs = _make_messages() + point = {"location": "message", "index": "-2"} + result = AnthropicCacheControlHook._process_message_injection( + point=point, messages=msgs + ) + assert _has_cache_control(result[-2]) + for i, m in enumerate(result): + if i != len(result) - 2: + assert not _has_cache_control(m) + + def test_positive_string_index_still_works(self): + msgs = _make_messages() + point = {"location": "message", "index": "0"} + result = AnthropicCacheControlHook._process_message_injection( + point=point, messages=msgs + ) + assert _has_cache_control(result[0]) + + def test_non_numeric_string_index_is_ignored(self): + msgs = _make_messages() + point = {"location": "message", "index": "abc"} + result = AnthropicCacheControlHook._process_message_injection( + point=point, messages=msgs + ) + # Nothing should be modified + for m in result: + assert not _has_cache_control(m) + + +# --------------------------------------------------------------------------- +# 2. New feature: combined role + index filtering (Case 1) +# --------------------------------------------------------------------------- + +class TestRolePlusIndexFiltering: + """When both role and index are set, index is relative to the role subset.""" + + def test_last_assistant_message(self): + """role=assistant, index=-1 should target the *last* assistant message.""" + msgs = _make_messages() + point = {"location": "message", "role": "assistant", "index": -1} + result = AnthropicCacheControlHook._process_message_injection( + point=point, messages=msgs + ) + # assistant messages are at indices 2 and 4; last = index 4 + assert _has_cache_control(result[4]) + for i, m in enumerate(result): + if i != 4: + assert not _has_cache_control(m) + + def test_first_user_message(self): + """role=user, index=0 should target only the first user message.""" + msgs = _make_messages() + point = {"location": "message", "role": "user", "index": 0} + result = AnthropicCacheControlHook._process_message_injection( + point=point, messages=msgs + ) + # user messages are at indices 1, 3, 5; first = index 1 + assert _has_cache_control(result[1]) + for i, m in enumerate(result): + if i != 1: + assert not _has_cache_control(m) + + def test_last_user_message(self): + """role=user, index=-1 should target only the last user message.""" + msgs = _make_messages() + point = {"location": "message", "role": "user", "index": -1} + result = AnthropicCacheControlHook._process_message_injection( + point=point, messages=msgs + ) + # user messages at 1, 3, 5; last = index 5 + assert _has_cache_control(result[5]) + for i, m in enumerate(result): + if i != 5: + assert not _has_cache_control(m) + + def test_second_assistant_message_via_index_1(self): + """role=assistant, index=1 should target the second assistant message.""" + msgs = _make_messages() + point = {"location": "message", "role": "assistant", "index": 1} + result = AnthropicCacheControlHook._process_message_injection( + point=point, messages=msgs + ) + # assistant at 2, 4; second = index 4 + assert _has_cache_control(result[4]) + for i, m in enumerate(result): + if i != 4: + assert not _has_cache_control(m) + + def test_role_plus_string_negative_index(self): + """Combined role + negative *string* index exercises both fixes together.""" + msgs = _make_messages() + point = {"location": "message", "role": "user", "index": "-1"} + result = AnthropicCacheControlHook._process_message_injection( + point=point, messages=msgs + ) + assert _has_cache_control(result[5]) + + def test_role_plus_index_out_of_bounds_logs_warning(self): + """Out-of-bounds index within the role subset should warn, not crash.""" + msgs = _make_messages() + point = {"location": "message", "role": "assistant", "index": 10} + with patch( + "litellm.integrations.anthropic_cache_control_hook.verbose_logger" + ) as mock_logger: + result = AnthropicCacheControlHook._process_message_injection( + point=point, messages=msgs + ) + mock_logger.warning.assert_called_once() + warning_msg = mock_logger.warning.call_args[0][0] + assert "out of bounds" in warning_msg.lower() + # Nothing should be modified + for m in result: + assert not _has_cache_control(m) + + def test_role_with_no_matching_messages(self): + """If the target role has zero messages, nothing happens (no crash).""" + msgs = _make_messages() + point = {"location": "message", "role": "tool", "index": 0} + result = AnthropicCacheControlHook._process_message_injection( + point=point, messages=msgs + ) + for m in result: + assert not _has_cache_control(m) + + def test_system_role_with_index_zero(self): + """role=system, index=0 targets the single system message.""" + msgs = _make_messages() + point = {"location": "message", "role": "system", "index": 0} + result = AnthropicCacheControlHook._process_message_injection( + point=point, messages=msgs + ) + assert _has_cache_control(result[0]) + for m in result[1:]: + assert not _has_cache_control(m) diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py index 56f05580cb2..c643f380c80 100644 --- a/tests/llm_translation/test_optional_params.py +++ b/tests/llm_translation/test_optional_params.py @@ -55,6 +55,90 @@ def test_supports_system_message(): assert isinstance(response, litellm.ModelResponse) +def test_supports_system_message_list_content(): + """ + Test map_system_message_pt when content is a list of content blocks + (e.g. from Anthropic pass-through endpoint). + + Fixes: https://github.com/BerriAI/litellm/issues/23757 + """ + # System message with list content (Anthropic format) + messages = [ + {"role": "system", "content": [{"type": "text", "text": "You are helpful."}]}, + {"role": "user", "content": [{"type": "text", "text": "Hello!"}]}, + ] + + new_messages = map_system_message_pt(messages=messages) + + assert len(new_messages) == 1 + assert new_messages[0]["role"] == "user" + assert isinstance(new_messages[0]["content"], str) + assert "You are helpful." in new_messages[0]["content"] + assert "Hello!" in new_messages[0]["content"] + + +def test_supports_system_message_mixed_content(): + """ + Test map_system_message_pt with mixed str and list content types. + """ + messages = [ + {"role": "system", "content": "System prompt"}, + {"role": "user", "content": [{"type": "text", "text": "User message"}]}, + ] + + new_messages = map_system_message_pt(messages=messages) + + assert len(new_messages) == 1 + assert new_messages[0]["role"] == "user" + assert isinstance(new_messages[0]["content"], str) + assert "System prompt" in new_messages[0]["content"] + assert "User message" in new_messages[0]["content"] + + +def test_supports_system_message_list_content_last_message(): + """ + Test map_system_message_pt when system message with list content is the last message. + """ + messages = [ + {"role": "system", "content": [{"type": "text", "text": "Only system"}]}, + ] + + new_messages = map_system_message_pt(messages=messages) + + assert len(new_messages) == 1 + assert new_messages[0]["role"] == "user" + assert new_messages[0]["content"] == "Only system" + + +def test_supports_system_message_none_content(): + """ + Test map_system_message_pt when next message has content=None (e.g. assistant + tool-call messages). Should not produce the literal string 'None'. + """ + messages = [ + {"role": "system", "content": "Be helpful."}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "1", + "type": "function", + "function": {"name": "f", "arguments": "{}"}, + } + ], + }, + ] + + new_messages = map_system_message_pt(messages=messages) + + assert len(new_messages) == 1 + assert new_messages[0]["role"] == "assistant" + # content should start with system text, not contain literal "None" + assert "None" not in new_messages[0]["content"] + assert "Be helpful." in new_messages[0]["content"] + + @pytest.mark.parametrize( "stop_sequence, expected_count", [("\n", 0), (["\n"], 0), (["finish_reason"], 1)] ) diff --git a/tests/proxy_unit_tests/test_cors_config.py b/tests/proxy_unit_tests/test_cors_config.py new file mode 100644 index 00000000000..c839a5dc493 --- /dev/null +++ b/tests/proxy_unit_tests/test_cors_config.py @@ -0,0 +1,515 @@ +import logging +import os +import sys +from pathlib import Path + +import pytest +from click.testing import CliRunner + +REPO_ROOT = Path(__file__).resolve().parents[2] +CORS_ENV_VARS = ( + "LITELLM_CORS_ALLOW_ORIGINS", + "LITELLM_CORS_ALLOW_CREDENTIALS", + "LITELLM_CORS_ALLOW_METHODS", + "LITELLM_CORS_ALLOW_HEADERS", +) +CORS_MODULES = ("litellm.proxy.proxy_server",) + + +@pytest.fixture(autouse=True) +def clear_cors_env(monkeypatch): + for env_var in CORS_ENV_VARS: + monkeypatch.delenv(env_var, raising=False) + yield + # CliRunner-invoked startup code writes directly to os.environ, outside of + # monkeypatch tracking. Remove any leftover values after each test so they + # cannot leak into later imports in the same pytest session. + for env_var in CORS_ENV_VARS: + os.environ.pop(env_var, None) + + +def _reload_local_proxy_server(): + if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + + for module_name in CORS_MODULES: + sys.modules.pop(module_name, None) + + import litellm.proxy.proxy_server as proxy_server + + return proxy_server + + +def test_normalize_cors_value_string(): + from litellm.proxy.proxy_cli import _normalize_cors_value + + assert ( + _normalize_cors_value( + "https://a.com, https://b.com ", "cors_allow_origins" + ) + == "https://a.com,https://b.com" + ) + + +def test_normalize_cors_value_single_string(): + from litellm.proxy.proxy_cli import _normalize_cors_value + + assert ( + _normalize_cors_value("https://a.com", "cors_allow_origins") + == "https://a.com" + ) + + +def test_normalize_cors_value_list(): + from litellm.proxy.proxy_cli import _normalize_cors_value + + assert _normalize_cors_value( + ["https://a.com", " https://b.com "], "cors_allow_origins" + ) == "https://a.com,https://b.com" + + +def test_normalize_cors_value_empty_list(): + from litellm.proxy.proxy_cli import _normalize_cors_value + + assert _normalize_cors_value([], "cors_allow_origins") == "" + + +def test_normalize_cors_value_invalid_type_raises(): + from litellm.proxy.proxy_cli import _normalize_cors_value + + with pytest.raises(ValueError, match="expected a string or list of strings"): + _normalize_cors_value(123, "cors_allow_origins") + + +def test_apply_cors_settings_preserves_existing_env(monkeypatch): + from litellm.proxy.proxy_cli import _apply_cors_settings_from_general_settings + + monkeypatch.setenv("LITELLM_CORS_ALLOW_ORIGINS", "https://env.example.com") + monkeypatch.setenv("LITELLM_CORS_ALLOW_CREDENTIALS", "false") + monkeypatch.setenv("LITELLM_CORS_ALLOW_METHODS", "PATCH") + monkeypatch.setenv("LITELLM_CORS_ALLOW_HEADERS", "X-Env") + + _apply_cors_settings_from_general_settings( + { + "cors_allow_origins": ["https://config.example.com"], + "cors_allow_credentials": True, + "cors_allow_methods": ["GET", "POST"], + "cors_allow_headers": ["Authorization"], + } + ) + + assert os.getenv("LITELLM_CORS_ALLOW_ORIGINS") == "https://env.example.com" + assert os.getenv("LITELLM_CORS_ALLOW_CREDENTIALS") == "false" + assert os.getenv("LITELLM_CORS_ALLOW_METHODS") == "PATCH" + assert os.getenv("LITELLM_CORS_ALLOW_HEADERS") == "X-Env" + + +def test_apply_cors_settings_sets_env_when_unset(monkeypatch): + from litellm.proxy.proxy_cli import _apply_cors_settings_from_general_settings + + _apply_cors_settings_from_general_settings( + { + "cors_allow_origins": ["https://config.example.com"], + "cors_allow_credentials": True, + "cors_allow_methods": ["GET", "POST"], + "cors_allow_headers": ["Authorization"], + } + ) + + assert os.getenv("LITELLM_CORS_ALLOW_ORIGINS") == "https://config.example.com" + assert os.getenv("LITELLM_CORS_ALLOW_CREDENTIALS") == "true" + assert os.getenv("LITELLM_CORS_ALLOW_METHODS") == "GET,POST" + assert os.getenv("LITELLM_CORS_ALLOW_HEADERS") == "Authorization" + + +def test_apply_cors_settings_invalid_type_raises_click_exception(): + from click import ClickException + + from litellm.proxy.proxy_cli import _apply_cors_settings_from_general_settings + + with pytest.raises(ClickException, match="Invalid CORS configuration"): + _apply_cors_settings_from_general_settings({"cors_allow_origins": 42}) + + +def test_apply_cors_settings_invalid_credentials_type_raises_click_exception(): + from click import ClickException + + from litellm.proxy.proxy_cli import _apply_cors_settings_from_general_settings + + with pytest.raises( + ClickException, + match="Invalid CORS setting for 'cors_allow_credentials'", + ): + _apply_cors_settings_from_general_settings( + {"cors_allow_credentials": [True, False]} + ) + + +def test_process_config_includes_replaces_non_list_existing_value(tmp_path): + from litellm.proxy.proxy_cli import _process_config_includes + + included_config_path = tmp_path / "extra.yaml" + included_config_path.write_text( + "\n".join( + [ + "model_list:", + " - model_name: gpt-4", + ] + ), + encoding="utf-8", + ) + + config = { + "include": [included_config_path.name], + "model_list": "some-string", + } + + processed = _process_config_includes(config, str(tmp_path)) + + assert processed == {"model_list": [{"model_name": "gpt-4"}]} + + +def test_load_general_settings_for_early_cors_handles_include_and_env_refs( + monkeypatch, tmp_path +): + from litellm.proxy.proxy_cli import _load_general_settings_for_early_cors + + monkeypatch.setenv("CORS_ORIGIN", "https://env.example.com") + + base_config_path = tmp_path / "config.yaml" + included_config_path = tmp_path / "cors.yaml" + base_config_path.write_text( + "\n".join( + [ + "include:", + f" - {included_config_path.name}", + ] + ), + encoding="utf-8", + ) + included_config_path.write_text( + "\n".join( + [ + "general_settings:", + " cors_allow_origins: os.environ/CORS_ORIGIN", + " cors_allow_credentials: false", + " cors_allow_methods:", + " - GET", + " cors_allow_headers:", + " - Authorization", + ] + ), + encoding="utf-8", + ) + + general_settings = _load_general_settings_for_early_cors(str(base_config_path)) + + assert general_settings == { + "cors_allow_origins": "https://env.example.com", + "cors_allow_credentials": False, + "cors_allow_methods": ["GET"], + "cors_allow_headers": ["Authorization"], + } + + +def test_load_general_settings_for_early_cors_resolves_env_refs_in_list( + monkeypatch, tmp_path +): + from litellm.proxy.proxy_cli import _load_general_settings_for_early_cors + + monkeypatch.setenv("CORS_ORIGIN", "https://env.example.com") + + config_path = tmp_path / "config.yaml" + config_path.write_text( + "\n".join( + [ + "general_settings:", + " cors_allow_origins:", + " - os.environ/CORS_ORIGIN", + " - https://dashboard.example.com", + ] + ), + encoding="utf-8", + ) + + general_settings = _load_general_settings_for_early_cors(str(config_path)) + + assert general_settings["cors_allow_origins"] == [ + "https://env.example.com", + "https://dashboard.example.com", + ] + + +def test_run_server_applies_config_cors_before_proxy_server_import(tmp_path): + from litellm.proxy.proxy_cli import run_server + + config_path = tmp_path / "config.yaml" + config_path.write_text( + "\n".join( + [ + "general_settings:", + " cors_allow_origins:", + " - https://cli.example.com", + " cors_allow_credentials: false", + " cors_allow_methods:", + " - GET", + " - POST", + " cors_allow_headers:", + " - Authorization", + ] + ), + encoding="utf-8", + ) + + for module_name in CORS_MODULES: + sys.modules.pop(module_name, None) + + result = CliRunner().invoke( + run_server, + [ + "--config", + str(config_path), + "--skip_server_startup", + ], + ) + + assert result.exit_code == 0, result.output + + import litellm.proxy.proxy_server as proxy_server + + assert proxy_server.cors_allow_origins == ["https://cli.example.com"] + assert proxy_server.cors_allow_credentials is False + assert proxy_server.cors_allow_methods == ["GET", "POST"] + assert proxy_server.cors_allow_headers == ["Authorization"] + + +def test_run_server_applies_included_config_cors_before_proxy_server_import(tmp_path): + from litellm.proxy.proxy_cli import run_server + + config_path = tmp_path / "config.yaml" + included_config_path = tmp_path / "extra.yaml" + config_path.write_text( + "\n".join( + [ + "include:", + f" - {included_config_path.name}", + 'model_list: "some-string"', + ] + ), + encoding="utf-8", + ) + included_config_path.write_text( + "\n".join( + [ + "model_list:", + " - model_name: gpt-4", + "general_settings:", + " cors_allow_origins:", + " - https://include.example.com", + " cors_allow_credentials: false", + " cors_allow_methods:", + " - GET", + " cors_allow_headers:", + " - Authorization", + ] + ), + encoding="utf-8", + ) + + for module_name in CORS_MODULES: + sys.modules.pop(module_name, None) + + result = CliRunner().invoke( + run_server, + [ + "--config", + str(config_path), + "--skip_server_startup", + ], + ) + + assert result.exit_code == 0, result.output + + import litellm.proxy.proxy_server as proxy_server + + assert proxy_server.cors_allow_origins == ["https://include.example.com"] + assert proxy_server.cors_allow_credentials is False + assert proxy_server.cors_allow_methods == ["GET"] + assert proxy_server.cors_allow_headers == ["Authorization"] + + +def test_cors_defaults_preserve_existing_proxy_behavior(): + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_origins == ["*"] + assert proxy_server.cors_allow_credentials is True + assert proxy_server.cors_allow_methods == ["*"] + assert proxy_server.cors_allow_headers == ["*"] + + +def test_cors_credentials_disabled_with_wildcard(monkeypatch): + monkeypatch.setenv("LITELLM_CORS_ALLOW_ORIGINS", "*") + monkeypatch.setenv("LITELLM_CORS_ALLOW_CREDENTIALS", "true") + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_credentials is False + assert "*" in proxy_server.cors_allow_origins + + +def test_cors_credentials_disabled_with_mixed_wildcard_origins(monkeypatch): + monkeypatch.setenv( + "LITELLM_CORS_ALLOW_ORIGINS", "*, https://dashboard.example.com" + ) + monkeypatch.setenv("LITELLM_CORS_ALLOW_CREDENTIALS", "true") + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_credentials is False + assert proxy_server.cors_allow_origins == ["*", "https://dashboard.example.com"] + + +def test_path_wildcard_origin_does_not_trigger_credentials_guard(monkeypatch): + monkeypatch.setenv("LITELLM_CORS_ALLOW_ORIGINS", "https://example.com/*") + monkeypatch.setenv("LITELLM_CORS_ALLOW_CREDENTIALS", "true") + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_credentials is True + assert proxy_server.cors_allow_origins == ["https://example.com/*"] + + +def test_scheme_wildcard_origin_disables_credentials(monkeypatch): + monkeypatch.setenv("LITELLM_CORS_ALLOW_ORIGINS", "https://*") + monkeypatch.setenv("LITELLM_CORS_ALLOW_CREDENTIALS", "true") + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_credentials is False + assert proxy_server.cors_allow_origins == ["https://*"] + + +def test_cors_credentials_enabled_with_explicit_origins(monkeypatch): + monkeypatch.setenv( + "LITELLM_CORS_ALLOW_ORIGINS", "https://example.com, https://other.com" + ) + monkeypatch.setenv("LITELLM_CORS_ALLOW_CREDENTIALS", "true") + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_credentials is True + assert proxy_server.cors_allow_origins == [ + "https://example.com", + "https://other.com", + ] + + +def test_cors_credentials_default_to_disabled_with_explicit_origins(monkeypatch): + monkeypatch.setenv("LITELLM_CORS_ALLOW_ORIGINS", "https://example.com") + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_credentials is False + assert proxy_server.cors_allow_origins == ["https://example.com"] + + +def test_cors_credentials_only_config_is_rejected_with_wildcard_default(monkeypatch): + monkeypatch.setenv("LITELLM_CORS_ALLOW_CREDENTIALS", "true") + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_credentials is False + assert proxy_server.cors_allow_origins == ["*"] + + +def test_explicit_empty_origins_stay_empty(monkeypatch): + monkeypatch.setenv("LITELLM_CORS_ALLOW_ORIGINS", "") + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_origins == [] + + +def test_explicit_empty_origins_log_warning(monkeypatch, caplog): + caplog.set_level(logging.WARNING, logger="LiteLLM Proxy") + monkeypatch.setenv("LITELLM_CORS_ALLOW_ORIGINS", "") + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_origins == [] + assert any( + "cors_allow_origins resolved to an empty list" in record.message + for record in caplog.records + ) + + +def test_explicit_empty_methods_stay_empty(monkeypatch): + monkeypatch.setenv("LITELLM_CORS_ALLOW_METHODS", "") + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_methods == [] + + +def test_explicit_empty_methods_log_warning(monkeypatch, caplog): + caplog.set_level(logging.WARNING, logger="LiteLLM Proxy") + monkeypatch.setenv("LITELLM_CORS_ALLOW_METHODS", "") + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_methods == [] + assert any( + "cors_allow_methods resolved to an empty list" in record.message + for record in caplog.records + ) + + +def test_explicit_empty_headers_stay_empty(monkeypatch): + monkeypatch.setenv("LITELLM_CORS_ALLOW_HEADERS", "") + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_headers == [] + + +def test_explicit_empty_headers_log_warning(monkeypatch, caplog): + caplog.set_level(logging.WARNING, logger="LiteLLM Proxy") + monkeypatch.setenv("LITELLM_CORS_ALLOW_HEADERS", "") + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_headers == [] + assert any( + "cors_allow_headers resolved to an empty list" in record.message + for record in caplog.records + ) + + +def test_methods_and_headers_are_trimmed(monkeypatch): + monkeypatch.setenv("LITELLM_CORS_ALLOW_METHODS", " GET , POST , OPTIONS ") + monkeypatch.setenv( + "LITELLM_CORS_ALLOW_HEADERS", " Authorization , Content-Type " + ) + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_methods == ["GET", "POST", "OPTIONS"] + assert proxy_server.cors_allow_headers == ["Authorization", "Content-Type"] + + +@pytest.mark.parametrize( + "credential_value,expected", + [ + ("yes", True), + ("1", True), + ("True", True), + ("FALSE", False), + ], +) +def test_cors_credentials_boolean_formats(monkeypatch, credential_value, expected): + monkeypatch.setenv("LITELLM_CORS_ALLOW_ORIGINS", "https://example.com") + monkeypatch.setenv("LITELLM_CORS_ALLOW_CREDENTIALS", credential_value) + + proxy_server = _reload_local_proxy_server() + + assert proxy_server.cors_allow_credentials is expected diff --git a/tests/test_litellm/integrations/arize/test_arize_utils.py b/tests/test_litellm/integrations/arize/test_arize_utils.py index 9a9f3d5afc7..47a420f7b7f 100644 --- a/tests/test_litellm/integrations/arize/test_arize_utils.py +++ b/tests/test_litellm/integrations/arize/test_arize_utils.py @@ -374,3 +374,127 @@ def test_construct_dynamic_arize_headers(): "arize-space-id": "test_space_key", "api_key": "test_api_key" } + + +def test_set_usage_outputs_chat_completion_tokens_details(): + """ + Test that _set_usage_outputs correctly extracts reasoning_tokens from + completion_tokens_details (Chat Completions API) and cached_tokens from + prompt_tokens_details. + """ + from unittest.mock import MagicMock + + from litellm.integrations.arize._utils import _set_usage_outputs + from litellm.types.utils import ( + CompletionTokensDetailsWrapper, + ModelResponse, + PromptTokensDetailsWrapper, + Usage, + ) + + span = MagicMock() + + response_obj = ModelResponse( + usage=Usage( + total_tokens=200, + completion_tokens=120, + prompt_tokens=80, + completion_tokens_details=CompletionTokensDetailsWrapper( + reasoning_tokens=45 + ), + prompt_tokens_details=PromptTokensDetailsWrapper(cached_tokens=30), + ), + choices=[ + Choices( + message={"role": "assistant", "content": "test"}, finish_reason="stop" + ) + ], + model="gpt-4o", + ) + + _set_usage_outputs(span, response_obj, SpanAttributes) + + span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, 200) + span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 120) + span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 80) + span.set_attribute.assert_any_call( + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING, 45 + ) + span.set_attribute.assert_any_call( + SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ, 30 + ) + + +def test_set_usage_outputs_responses_api_output_tokens_details(): + """ + Test that _set_usage_outputs falls back to output_tokens_details (Responses API) + when completion_tokens_details is not present. + """ + from unittest.mock import MagicMock + + from litellm.integrations.arize._utils import _set_usage_outputs + from litellm.types.llms.openai import ( + OutputTokensDetails, + ResponseAPIUsage, + ResponsesAPIResponse, + ) + + span = MagicMock() + + response_obj = ResponsesAPIResponse( + id="response-456", + created_at=1625247600, + output=[], + usage=ResponseAPIUsage( + input_tokens=100, + output_tokens=200, + total_tokens=300, + output_tokens_details=OutputTokensDetails(reasoning_tokens=150), + ), + ) + + _set_usage_outputs(span, response_obj, SpanAttributes) + + span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, 300) + span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 200) + span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 100) + span.set_attribute.assert_any_call( + SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING, 150 + ) + + +def test_set_usage_outputs_no_token_details(): + """ + Test that _set_usage_outputs works when neither completion_tokens_details + nor prompt_tokens_details are present (basic usage without details). + """ + from unittest.mock import MagicMock + + from litellm.integrations.arize._utils import _set_usage_outputs + from litellm.types.utils import ModelResponse, Usage + + span = MagicMock() + + response_obj = ModelResponse( + usage=Usage( + total_tokens=100, + completion_tokens=60, + prompt_tokens=40, + ), + choices=[ + Choices( + message={"role": "assistant", "content": "test"}, finish_reason="stop" + ) + ], + model="gpt-4o", + ) + + _set_usage_outputs(span, response_obj, SpanAttributes) + + span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_TOTAL, 100) + span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, 60) + span.set_attribute.assert_any_call(SpanAttributes.LLM_TOKEN_COUNT_PROMPT, 40) + # reasoning and cached should NOT be set + for call in span.set_attribute.call_args_list: + assert call[0][0] != SpanAttributes.LLM_TOKEN_COUNT_COMPLETION_DETAILS_REASONING + assert call[0][0] != SpanAttributes.LLM_TOKEN_COUNT_PROMPT_DETAILS_CACHE_READ diff --git a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py index 20427e8cc94..51d81ba5846 100644 --- a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py +++ b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_handler.py @@ -7,6 +7,7 @@ ChatCompletionToolCallFunctionChunk, ) from litellm.types.responses.main import OutputCodeInterpreterCall +from litellm.types.utils import StreamingChoices def test_redacted_thinking_content_block_delta(): @@ -1581,3 +1582,167 @@ def test_non_bash_tool_result_skipped(): assert ( len(code_results) == 0 ), f"Expected 0 code_interpreter_results for text_editor result, got {len(code_results)}" + + +def test_convert_str_chunk_to_generic_chunk_valid_json(): + """Test convert_str_chunk_to_generic_chunk with valid JSON SSE chunk.""" + iterator = ModelResponseIterator( + streaming_response=MagicMock(), sync_stream=False, json_mode=False + ) + + # Valid SSE JSON chunk + chunk = 'data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}' + result = iterator.convert_str_chunk_to_generic_chunk(chunk) + + # Should parse successfully and return a ModelResponseStream with content + assert result.choices[0].delta.content == "Hello" + assert result.id == iterator.response_id + + +def test_convert_str_chunk_to_generic_chunk_binary_input(): + """Test convert_str_chunk_to_generic_chunk with binary input.""" + iterator = ModelResponseIterator( + streaming_response=MagicMock(), sync_stream=False, json_mode=False + ) + + # Binary input that needs decoding - use type ignore since function signature says str + # but implementation handles bytes + chunk = b'data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}' + result = iterator.convert_str_chunk_to_generic_chunk(chunk) # type: ignore[arg-type] + + # Should parse successfully + assert result.choices[0].delta.content == "Hello" + assert result.id == iterator.response_id + + +def test_convert_str_chunk_to_generic_chunk_multiline_sse(): + """Test convert_str_chunk_to_generic_chunk with multiline SSE format.""" + iterator = ModelResponseIterator( + streaming_response=MagicMock(), sync_stream=False, json_mode=False + ) + + # Multiline SSE with event prefix + chunk = 'event: message\ndata: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}\n\n' + result = iterator.convert_str_chunk_to_generic_chunk(chunk) + + # Should find "data:" and parse successfully + assert result.choices[0].delta.content == "Hello" + assert result.id == iterator.response_id + + +def test_convert_str_chunk_to_generic_chunk_whitespace_handling(): + """Test convert_str_chunk_to_generic_chunk with whitespace (tests .strip()).""" + iterator = ModelResponseIterator( + streaming_response=MagicMock(), sync_stream=False, json_mode=False + ) + + # With extra whitespace around data + chunk = 'data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}} ' + result = iterator.convert_str_chunk_to_generic_chunk(chunk) + + # Should parse successfully despite whitespace + assert result.choices[0].delta.content == "Hello" + assert result.id == iterator.response_id + + +def test_convert_str_chunk_to_generic_chunk_find_data_in_middle(): + """Test convert_str_chunk_to_generic_chunk where data: appears in middle of string.""" + iterator = ModelResponseIterator( + streaming_response=MagicMock(), sync_stream=False, json_mode=False + ) + + # data: appears after other text + chunk = 'some prefix data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}' + result = iterator.convert_str_chunk_to_generic_chunk(chunk) + + # Should find data: and parse successfully + assert result.choices[0].delta.content == "Hello" + assert result.id == iterator.response_id + + +def test_convert_str_chunk_to_generic_chunk_deepseek_done(): + """Test convert_str_chunk_to_generic_chunk with Deepseek-style [DONE] chunk.""" + iterator = ModelResponseIterator( + streaming_response=MagicMock(), sync_stream=False, json_mode=False + ) + + # Deepseek sends "data: [DONE]" which is not valid JSON + chunk = 'data: [DONE]' + result = iterator.convert_str_chunk_to_generic_chunk(chunk) + + # Should return empty ModelResponseStream (not crash) + assert result.id == iterator.response_id + assert len(result.choices) == 1 + assert result.choices[0] == StreamingChoices() + + +def test_convert_str_chunk_to_generic_chunk_invalid_json(): + """Test convert_str_chunk_to_generic_chunk with invalid JSON.""" + iterator = ModelResponseIterator( + streaming_response=MagicMock(), sync_stream=False, json_mode=False + ) + + # Invalid JSON + chunk = 'data: {invalid json' + result = iterator.convert_str_chunk_to_generic_chunk(chunk) + + # Should return empty ModelResponseStream (not crash) + assert result.id == iterator.response_id + assert len(result.choices) == 1 + assert result.choices[0] == StreamingChoices() + + +def test_convert_str_chunk_to_generic_chunk_missing_data_prefix(): + """Test convert_str_chunk_to_generic_chunk without data: prefix.""" + iterator = ModelResponseIterator( + streaming_response=MagicMock(), sync_stream=False, json_mode=False + ) + + # No data: prefix + chunk = 'something else' + result = iterator.convert_str_chunk_to_generic_chunk(chunk) + + # Should return empty ModelResponseStream (not crash) + assert result.id == iterator.response_id + assert len(result.choices) == 1 + assert result.choices[0] == StreamingChoices() + + +def test_convert_str_chunk_to_generic_chunk_empty_data(): + """Test convert_str_chunk_to_generic_chunk with empty data.""" + iterator = ModelResponseIterator( + streaming_response=MagicMock(), sync_stream=False, json_mode=False + ) + + # Empty data + chunk = 'data:' + result = iterator.convert_str_chunk_to_generic_chunk(chunk) + + # Should return empty ModelResponseStream (not crash) + assert result.id == iterator.response_id + assert len(result.choices) == 1 + assert result.choices[0] == StreamingChoices() + + # Data with only whitespace + chunk2 = 'data: ' + result2 = iterator.convert_str_chunk_to_generic_chunk(chunk2) + # Should return empty ModelResponseStream (not crash) + assert result2.id == iterator.response_id + assert len(result2.choices) == 1 + assert result2.choices[0] == StreamingChoices() + + +def test_convert_str_chunk_to_generic_chunk_response_id_consistency(): + """Test that convert_str_chunk_to_generic_chunk returns consistent response IDs.""" + iterator = ModelResponseIterator( + streaming_response=MagicMock(), sync_stream=False, json_mode=False + ) + + # Multiple calls should return same response ID + chunk1 = 'data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}' + chunk2 = 'data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"World"}}' + + result1 = iterator.convert_str_chunk_to_generic_chunk(chunk1) + result2 = iterator.convert_str_chunk_to_generic_chunk(chunk2) + + assert result1.id == result2.id == iterator.response_id diff --git a/tests/test_litellm/proxy/common_utils/test_callback_utils.py b/tests/test_litellm/proxy/common_utils/test_callback_utils.py index 985e8d20be7..1ab212257e3 100644 --- a/tests/test_litellm/proxy/common_utils/test_callback_utils.py +++ b/tests/test_litellm/proxy/common_utils/test_callback_utils.py @@ -79,5 +79,161 @@ def test_normalize_callback_names_none_returns_empty_list(): def test_normalize_callback_names_lowercases_strings(): - assert normalize_callback_names(["SQS", "S3", "CUSTOM_CALLBACK"]) == ["sqs", "s3", "custom_callback"] + assert normalize_callback_names(["SQS", "S3", "CUSTOM_CALLBACK"]) == [ + "sqs", + "s3", + "custom_callback", + ] + + +def test_initialize_callbacks_on_proxy_instantiates_otel(): + """ + Test that initialize_callbacks_on_proxy() actually instantiates the + OpenTelemetry callback class (not just adding the string "otel"). + + Regression test for: when store_model_in_db=true, OTEL callback was + never instantiated because initialize_callbacks_on_proxy() only added + the string "otel" to litellm.callbacks without creating the instance. + """ + import litellm + from litellm.proxy.common_utils.callback_utils import ( + initialize_callbacks_on_proxy, + ) + from litellm.integrations.opentelemetry import OpenTelemetry + from litellm.litellm_core_utils import litellm_logging + from litellm.proxy import proxy_server + + # Save original state + original_callbacks = litellm.callbacks[:] + original_success = litellm.success_callback[:] + original_async_success = litellm._async_success_callback[:] + original_failure = litellm.failure_callback[:] + original_async_failure = litellm._async_failure_callback[:] + original_otel_logger = getattr(proxy_server, "open_telemetry_logger", None) + original_in_memory_loggers = litellm_logging._in_memory_loggers[:] + + try: + # Clear callbacks + litellm.callbacks = [] + litellm.success_callback = [] + litellm._async_success_callback = [] + litellm.failure_callback = [] + litellm._async_failure_callback = [] + + initialize_callbacks_on_proxy( + value=["otel"], + premium_user=False, + config_file_path="", + litellm_settings={}, + ) + + # Verify an OpenTelemetry instance exists in success callbacks + otel_in_success = any( + isinstance(cb, OpenTelemetry) + for cb in litellm.success_callback + litellm._async_success_callback + ) + assert otel_in_success, ( + "OpenTelemetry instance not found in success callbacks. " + f"success_callback={litellm.success_callback}, " + f"_async_success_callback={litellm._async_success_callback}" + ) + + # Verify an OpenTelemetry instance exists in failure callbacks + otel_in_failure = any( + isinstance(cb, OpenTelemetry) + for cb in litellm.failure_callback + litellm._async_failure_callback + ) + assert otel_in_failure, "OpenTelemetry instance not found in failure callbacks." + + # Verify open_telemetry_logger is set on proxy_server + assert proxy_server.open_telemetry_logger is not None, ( + "proxy_server.open_telemetry_logger should be set after " + "initializing otel callback" + ) + assert isinstance(proxy_server.open_telemetry_logger, OpenTelemetry) + + finally: + # Restore ALL global state to prevent leaking into other tests + litellm.callbacks = original_callbacks + litellm.success_callback = original_success + litellm._async_success_callback = original_async_success + litellm.failure_callback = original_failure + litellm._async_failure_callback = original_async_failure + proxy_server.open_telemetry_logger = original_otel_logger + litellm_logging._in_memory_loggers = original_in_memory_loggers + + +@patch("litellm.utils._add_custom_logger_callback_to_specific_event") +def test_initialize_callbacks_on_proxy_calls_instantiation_for_known_callbacks( + mock_add_callback, +): + """ + Verify that initialize_callbacks_on_proxy calls + _add_custom_logger_callback_to_specific_event for each known callback + in the config, registering it for both success and failure events. + + This is a unit test that mocks the instantiation to avoid requiring + real env vars or creating real logger instances. + """ + from litellm.proxy.common_utils.callback_utils import ( + initialize_callbacks_on_proxy, + ) + + initialize_callbacks_on_proxy( + value=["otel"], + premium_user=False, + config_file_path="", + litellm_settings={}, + ) + + # Should be called twice: once for "success", once for "failure" + assert mock_add_callback.call_count == 2 + mock_add_callback.assert_any_call("otel", "success") + mock_add_callback.assert_any_call("otel", "failure") + + +@patch( + "litellm.utils._add_custom_logger_callback_to_specific_event", + side_effect=Exception("Missing LOGFIRE_TOKEN"), +) +def test_initialize_callbacks_on_proxy_handles_instantiation_failure( + mock_add_callback, +): + """ + Verify that if a callback fails to instantiate (e.g. missing env vars), + the proxy does not crash — it logs the error and falls back to adding + the string for deferred instantiation. + """ + import litellm + from litellm.proxy.common_utils.callback_utils import ( + initialize_callbacks_on_proxy, + ) + original_callbacks = litellm.callbacks[:] + + try: + litellm.callbacks = [] + + initialize_callbacks_on_proxy( + value=["logfire"], + premium_user=False, + config_file_path="", + litellm_settings={}, + ) + + # Verify the mock was called for both events (guards against "logfire" + # being silently removed from _known_custom_logger_compatible_callbacks). + assert mock_add_callback.call_count == 2, ( + "Expected _add_custom_logger_callback_to_specific_event to be " + f"called twice (success + failure), but call_count={mock_add_callback.call_count}" + ) + mock_add_callback.assert_any_call("logfire", "success") + mock_add_callback.assert_any_call("logfire", "failure") + + # The string should be added as a fallback so it can be retried later + assert "logfire" in litellm.callbacks, ( + "Failed callback should be added as string fallback. " + f"callbacks={litellm.callbacks}" + ) + finally: + litellm.callbacks = original_callbacks diff --git a/tests/test_litellm/proxy/common_utils/test_key_rotation_manager.py b/tests/test_litellm/proxy/common_utils/test_key_rotation_manager.py index 24828cdff36..e5f72d2de53 100644 --- a/tests/test_litellm/proxy/common_utils/test_key_rotation_manager.py +++ b/tests/test_litellm/proxy/common_utils/test_key_rotation_manager.py @@ -290,3 +290,101 @@ async def test_rotate_key_passes_grace_period(self): call_args = mock_regenerate.call_args regenerate_request = call_args[1]["data"] assert regenerate_request.grace_period == "48h" + + +class TestKeyRotationManagerDistributedLock: + """Tests for distributed lock behavior in process_rotations.""" + + @pytest.mark.asyncio + async def test_process_rotations_acquires_lock(self): + """Test that process_rotations acquires distributed lock before processing.""" + from unittest.mock import MagicMock + + mock_prisma_client = AsyncMock() + mock_prisma_client.db.litellm_verificationtoken.find_many.return_value = [] + mock_prisma_client.db.litellm_deprecatedverificationtoken.delete_many.return_value = 0 + + mock_pod_lock_manager = MagicMock() + mock_pod_lock_manager.redis_cache = MagicMock() + mock_pod_lock_manager.acquire_lock = AsyncMock(return_value=True) + mock_pod_lock_manager.release_lock = AsyncMock() + + manager = KeyRotationManager(mock_prisma_client, pod_lock_manager=mock_pod_lock_manager) + await manager.process_rotations() + + mock_pod_lock_manager.acquire_lock.assert_called_once_with( + cronjob_id=KeyRotationManager.KEY_ROTATION_JOB_NAME, + ) + + @pytest.mark.asyncio + async def test_process_rotations_skips_when_lock_held(self): + """Test that process_rotations skips when another pod holds the lock.""" + from unittest.mock import MagicMock + + mock_prisma_client = AsyncMock() + mock_pod_lock_manager = MagicMock() + mock_pod_lock_manager.redis_cache = MagicMock() + mock_pod_lock_manager.acquire_lock = AsyncMock(return_value=False) + mock_pod_lock_manager.release_lock = AsyncMock() + + manager = KeyRotationManager(mock_prisma_client, pod_lock_manager=mock_pod_lock_manager) + await manager.process_rotations() + + # Should not attempt any DB queries since lock was not acquired + mock_prisma_client.db.litellm_verificationtoken.find_many.assert_not_called() + mock_pod_lock_manager.release_lock.assert_not_called() + + @pytest.mark.asyncio + async def test_process_rotations_releases_lock_on_success(self): + """Test that lock is released after successful rotation.""" + from unittest.mock import MagicMock + + mock_prisma_client = AsyncMock() + mock_prisma_client.db.litellm_verificationtoken.find_many.return_value = [] + mock_prisma_client.db.litellm_deprecatedverificationtoken.delete_many.return_value = 0 + + mock_pod_lock_manager = MagicMock() + mock_pod_lock_manager.redis_cache = MagicMock() + mock_pod_lock_manager.acquire_lock = AsyncMock(return_value=True) + mock_pod_lock_manager.release_lock = AsyncMock() + + manager = KeyRotationManager(mock_prisma_client, pod_lock_manager=mock_pod_lock_manager) + await manager.process_rotations() + + mock_pod_lock_manager.release_lock.assert_called_once_with( + cronjob_id=KeyRotationManager.KEY_ROTATION_JOB_NAME, + ) + + @pytest.mark.asyncio + async def test_process_rotations_releases_lock_on_failure(self): + """Test that lock is released even when rotation fails.""" + from unittest.mock import MagicMock + + mock_prisma_client = AsyncMock() + mock_prisma_client.db.litellm_deprecatedverificationtoken.delete_many.side_effect = Exception("DB error") + + mock_pod_lock_manager = MagicMock() + mock_pod_lock_manager.redis_cache = MagicMock() + mock_pod_lock_manager.acquire_lock = AsyncMock(return_value=True) + mock_pod_lock_manager.release_lock = AsyncMock() + + manager = KeyRotationManager(mock_prisma_client, pod_lock_manager=mock_pod_lock_manager) + await manager.process_rotations() + + # Lock should still be released via finally block + mock_pod_lock_manager.release_lock.assert_called_once_with( + cronjob_id=KeyRotationManager.KEY_ROTATION_JOB_NAME, + ) + + @pytest.mark.asyncio + async def test_process_rotations_works_without_lock_manager(self): + """Test that process_rotations works when pod_lock_manager is None (single instance).""" + mock_prisma_client = AsyncMock() + mock_prisma_client.db.litellm_verificationtoken.find_many.return_value = [] + mock_prisma_client.db.litellm_deprecatedverificationtoken.delete_many.return_value = 0 + + manager = KeyRotationManager(mock_prisma_client, pod_lock_manager=None) + await manager.process_rotations() + + # Should still execute normally without lock + mock_prisma_client.db.litellm_deprecatedverificationtoken.delete_many.assert_called_once() diff --git a/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/ModelsAndEndpointsView.tsx b/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/ModelsAndEndpointsView.tsx index 514ae673d06..da76280011d 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/ModelsAndEndpointsView.tsx +++ b/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/ModelsAndEndpointsView.tsx @@ -466,7 +466,7 @@ const ModelsAndEndpointsView: React.FC = ({ premiumUser, te onAliasUpdate={setModelGroupAlias} /> - + {all_admin_roles.includes(userRole) && } )}