 import pytest
 from fastapi import Request

+import litellm
 from litellm.proxy._types import TeamCallbackMetadata, UserAPIKeyAuth
 from litellm.proxy.litellm_pre_call_utils import (
     KeyAndTeamLoggingSettings,
@@ -935,3 +936,126 @@ def test_add_headers_to_llm_call_by_model_group_existing_headers_in_data():
     finally:
         # Restore original model_group_settings
         litellm.model_group_settings = original_model_group_settings
+
+import asyncio
+import json
+from typing import Optional
+from unittest.mock import MagicMock
+
+from fastapi.responses import Response
+
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
+from litellm.proxy.utils import ProxyLogging
+from litellm.types.utils import StandardLoggingPayload
+
+
+class TestCustomLogger(CustomLogger):
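+    """Captures the standard_logging_object handed to litellm's async success callback."""
+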
+    def __init__(self):
+        self.standard_logging_object: Optional[StandardLoggingPayload] = None
+        super().__init__()
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"SUCCESS CALLBACK CALLED! kwargs keys: {list(kwargs.keys())}")
+        self.standard_logging_object = kwargs.get("standard_logging_object")
+        print(f"Captured standard_logging_object: {self.standard_logging_object}")
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"FAILURE CALLBACK CALLED! kwargs keys: {list(kwargs.keys())}")
+
+
+@pytest.mark.asyncio
+async def test_add_litellm_metadata_from_request_headers():
+    """
+    Test that add_litellm_metadata_from_request_headers adds litellm metadata from
+    request headers: run an LLM request through base_process_llm_request, wait for
+    the async logger to fire, and check that standard_logging_payload carries the
+    spend_logs_metadata supplied in the headers.
+
+    Relevant issue: https://github.com/BerriAI/litellm/issues/14008
+    """
+    # Set up test logger
+    litellm._turn_on_debug()
+    test_logger = TestCustomLogger()
+    litellm.callbacks = [test_logger]
+
+    # Prepare test data (no streaming; mock_response and api_key route the request to litellm.acompletion)
+    headers = {
+        "x-litellm-spend-logs-metadata": '{"user_id": "12345", "project_id": "proj_abc", "request_type": "chat_completion", "timestamp": "2025-09-02T10:30:00Z"}'
+    }
+    data = {
+        "model": "gpt-4",
+        "messages": [{"role": "user", "content": "Hello"}],
+        "stream": False,
+        "mock_response": "Hi",
+        "api_key": "fake-key",
+    }
+
+    # Create mock request with headers
+    mock_request = MagicMock(spec=Request)
+    mock_request.headers = headers
+    mock_request.url.path = "/chat/completions"
+
+    # Create mock response
+    mock_fastapi_response = MagicMock(spec=Response)
+
+    # Create mock user API key dict
+    mock_user_api_key_dict = UserAPIKeyAuth(
+        api_key="test-key",
+        user_id="test-user",
+        org_id="test-org",
+    )
+
+    # Create mock proxy logging object
+    mock_proxy_logging_obj = MagicMock(spec=ProxyLogging)
+
+    # Create async functions for the proxy hooks
+    async def mock_during_call_hook(*args, **kwargs):
+        return None
+
+    async def mock_pre_call_hook(*args, **kwargs):
+        return data
+
+    async def mock_post_call_success_hook(*args, **kwargs):
+        # Return the response unchanged
+        return kwargs.get("response", args[2] if len(args) > 2 else None)
+
+    mock_proxy_logging_obj.during_call_hook = mock_during_call_hook
+    mock_proxy_logging_obj.pre_call_hook = mock_pre_call_hook
+    mock_proxy_logging_obj.post_call_success_hook = mock_post_call_success_hook
+
+    # Create mock proxy config
+    mock_proxy_config = MagicMock()
+
+    # Create mock general settings
+    general_settings = {}
+
+    # Create mock select_data_generator with the correct signature (only exercised for streaming responses)
+    def mock_select_data_generator(response=None, user_api_key_dict=None, request_data=None):
+        async def mock_generator():
+            yield "data: " + json.dumps({"choices": [{"delta": {"content": "Hello"}}]}) + "\n\n"
+            yield "data: [DONE]\n\n"
+
+        return mock_generator()
+
+    # Create the processor
+    processor = ProxyBaseLLMRequestProcessing(data=data)
+
+    # Call base_process_llm_request (it will use the mock_response="Hi" parameter)
+    result = await processor.base_process_llm_request(
+        request=mock_request,
+        fastapi_response=mock_fastapi_response,
+        user_api_key_dict=mock_user_api_key_dict,
+        route_type="acompletion",
+        proxy_logging_obj=mock_proxy_logging_obj,
+        general_settings=general_settings,
+        proxy_config=mock_proxy_config,
+        select_data_generator=mock_select_data_generator,
+        llm_router=None,
+        model="gpt-4",
+        is_streaming_request=False,
+    )
+
+    # Sleep for 3 seconds so the async success callback can complete
+    await asyncio.sleep(3)
+
+    # Check that standard_logging_object was set
+    assert (
+        test_logger.standard_logging_object is not None
+    ), "standard_logging_object should be populated after the LLM request"
+
+    # Verify the logging object contains the metadata passed via the headers
+    standard_logging_obj = test_logger.standard_logging_object
+
+    print(f"Standard logging object captured: {json.dumps(standard_logging_obj, indent=4, default=str)}")
+
+    spend_logs_metadata = standard_logging_obj["metadata"]["spend_logs_metadata"]
+    assert spend_logs_metadata == json.loads(
+        headers["x-litellm-spend-logs-metadata"]
+    ), "spend_logs_metadata should match the value sent in the headers"
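
For context, the header exercised by this test is one a client attaches per request. Below is a minimal sketch of that client side, assuming a litellm proxy listening at http://localhost:4000 and the OpenAI SDK's extra_headers parameter; the URL, key, and metadata values are placeholders:

import json

from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-placeholder")

response = client.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "Hello"}],
    # JSON-string header; per the test above, the proxy parses it into
    # standard_logging_payload["metadata"]["spend_logs_metadata"]
    extra_headers={
        "x-litellm-spend-logs-metadata": json.dumps(
            {"user_id": "12345", "project_id": "proj_abc"}
        )
    },
)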