diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index 007c0164be40..e5141eb0951f 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -90,10 +90,10 @@ def _get_user_in_team( def _calculate_key_rotation_time(rotation_interval: str) -> datetime: """ Helper function to calculate the next rotation time for a key based on the rotation interval. - + Args: rotation_interval: String representing the rotation interval (e.g., '30d', '90d', '1h') - + Returns: datetime: The calculated next rotation time in UTC """ @@ -102,21 +102,25 @@ def _calculate_key_rotation_time(rotation_interval: str) -> datetime: return now + timedelta(seconds=interval_seconds) -def _set_key_rotation_fields(data: dict, auto_rotate: bool, rotation_interval: Optional[str]) -> None: +def _set_key_rotation_fields( + data: dict, auto_rotate: bool, rotation_interval: Optional[str] +) -> None: """ Helper function to set rotation fields in key data if auto_rotate is enabled. - + Args: data: Dictionary to update with rotation fields auto_rotate: Whether auto rotation is enabled rotation_interval: The rotation interval string (required if auto_rotate is True) """ if auto_rotate and rotation_interval: - data.update({ - "auto_rotate": auto_rotate, - "rotation_interval": rotation_interval, - "key_rotation_at": _calculate_key_rotation_time(rotation_interval) - }) + data.update( + { + "auto_rotate": auto_rotate, + "rotation_interval": rotation_interval, + "key_rotation_at": _calculate_key_rotation_time(rotation_interval), + } + ) def _is_allowed_to_make_key_request( @@ -598,9 +602,9 @@ async def _common_key_generation_helper( # noqa: PLR0915 request_type="key", **data_json, table_name="key" ) - response["soft_budget"] = ( - data.soft_budget - ) # include the user-input soft budget in the response + response[ + "soft_budget" + ] = data.soft_budget # include the user-input soft budget in the response response = GenerateKeyResponse(**response) @@ -934,9 +938,9 @@ async def _set_object_permission( data=data_json["object_permission"], ) ) - data_json["object_permission_id"] = ( - created_object_permission.object_permission_id - ) + data_json[ + "object_permission_id" + ] = created_object_permission.object_permission_id # delete the object_permission from the data_json data_json.pop("object_permission") return data_json @@ -946,7 +950,6 @@ async def prepare_key_update_data( data: Union[UpdateKeyRequest, RegenerateKeyRequest], existing_key_row: LiteLLM_VerificationToken, ): - data_json: dict = data.model_dump(exclude_unset=True) data_json.pop("key", None) data_json.pop("new_key", None) @@ -1198,9 +1201,9 @@ async def update_key_fn( # Handle rotation fields if auto_rotate is being enabled _set_key_rotation_fields( - non_default_values, - non_default_values.get("auto_rotate", False), - non_default_values.get("rotation_interval") + non_default_values, + non_default_values.get("auto_rotate", False), + non_default_values.get("rotation_interval"), ) _data = {**non_default_values, "token": key} @@ -1602,8 +1605,6 @@ def _check_model_access_group( return True - - async def generate_key_helper_fn( # noqa: PLR0915 request_type: Literal[ "user", "key" @@ -1676,7 +1677,8 @@ async def generate_key_helper_fn( # noqa: PLR0915 if duration is None: # allow tokens that never expire expires = None else: - expires = get_budget_reset_time(budget_duration=duration) + duration_s = duration_in_seconds(duration=duration) + expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) if key_budget_duration is None: # one-time budget key_reset_at = None @@ -1766,12 +1768,12 @@ async def generate_key_helper_fn( # noqa: PLR0915 "allowed_routes": allowed_routes or [], "object_permission_id": object_permission_id, } - + # Add rotation fields if auto_rotate is enabled _set_key_rotation_fields( data=key_data, auto_rotate=auto_rotate or False, - rotation_interval=rotation_interval + rotation_interval=rotation_interval, ) if ( @@ -1968,10 +1970,10 @@ async def delete_verification_tokens( try: if prisma_client: tokens = [_hash_token_if_needed(token=key) for key in tokens] - _keys_being_deleted: List[LiteLLM_VerificationToken] = ( - await prisma_client.db.litellm_verificationtoken.find_many( - where={"token": {"in": tokens}} - ) + _keys_being_deleted: List[ + LiteLLM_VerificationToken + ] = await prisma_client.db.litellm_verificationtoken.find_many( + where={"token": {"in": tokens}} ) if len(_keys_being_deleted) == 0: @@ -2079,9 +2081,9 @@ async def _rotate_master_key( from litellm.proxy.proxy_server import proxy_config try: - models: Optional[List] = ( - await prisma_client.db.litellm_proxymodeltable.find_many() - ) + models: Optional[ + List + ] = await prisma_client.db.litellm_proxymodeltable.find_many() except Exception: models = None # 2. process model table @@ -2392,11 +2394,11 @@ async def validate_key_list_check( param="user_id", code=status.HTTP_403_FORBIDDEN, ) - complete_user_info_db_obj: Optional[BaseModel] = ( - await prisma_client.db.litellm_usertable.find_unique( - where={"user_id": user_api_key_dict.user_id}, - include={"organization_memberships": True}, - ) + complete_user_info_db_obj: Optional[ + BaseModel + ] = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_api_key_dict.user_id}, + include={"organization_memberships": True}, ) if complete_user_info_db_obj is None: @@ -2482,10 +2484,10 @@ async def get_admin_team_ids( if complete_user_info is None: return [] # Get all teams that user is an admin of - teams: Optional[List[BaseModel]] = ( - await prisma_client.db.litellm_teamtable.find_many( - where={"team_id": {"in": complete_user_info.teams}} - ) + teams: Optional[ + List[BaseModel] + ] = await prisma_client.db.litellm_teamtable.find_many( + where={"team_id": {"in": complete_user_info.teams}} ) if teams is None: return [] @@ -3049,7 +3051,7 @@ async def key_health( Checks: - If key based logging is configured correctly - sends a test log - Usage + Usage Pass the key in the request header diff --git a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py index e3aa7d588720..a69af2a4f4b7 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py @@ -269,8 +269,10 @@ async def test_key_token_handling(monkeypatch): @pytest.mark.asyncio async def test_budget_reset_and_expires_at_first_of_month(monkeypatch): """ - Test that when budget_duration, duration, and key_budget_duration are "1mo", budget_reset_at and expires are set to first of next month + Test that budget reset fields are standardized to 1st of next month. """ + from datetime import datetime, timezone + mock_prisma_client = AsyncMock() mock_insert_data = AsyncMock( return_value=MagicMock(token="hashed_token_123", litellm_budget_table=None) @@ -289,52 +291,37 @@ async def test_budget_reset_and_expires_at_first_of_month(monkeypatch): return_value=MagicMock(token="hashed_token_123", litellm_budget_table=None) ) - from datetime import datetime, timezone - - import pytest - from litellm.proxy.management_endpoints.key_management_endpoints import ( generate_key_helper_fn, ) - from litellm.proxy.proxy_server import prisma_client - # Use monkeypatch to set the prisma_client monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + monkeypatch.setattr("litellm.proxy.proxy_server.premium_user", False) - # Test key generation with budget_duration="1mo", duration="1mo", key_budget_duration="1mo" + test_start = datetime.now(timezone.utc) + + # Generate key with monthly budget duration response = await generate_key_helper_fn( request_type="user", budget_duration="1mo", - duration="1mo", key_budget_duration="1mo", user_id="test_user", ) - print(f"response: {response}\n") - # Get the current date - now = datetime.now(timezone.utc) + # Verify budget_reset_at is standardized to 1st of next month at midnight + budget_reset_at = response.get("budget_reset_at") + assert budget_reset_at is not None + assert budget_reset_at.day == 1, "budget_reset_at should be on 1st of month" + assert budget_reset_at.hour == 0, "budget_reset_at should be at midnight" + assert budget_reset_at.minute == 0, "budget_reset_at should be at midnight" - # Calculate expected reset date (first of next month) - if now.month == 12: - expected_month = 1 - expected_year = now.year + 1 + # Verify it's next month + if test_start.month == 12: + assert budget_reset_at.month == 1 + assert budget_reset_at.year == test_start.year + 1 else: - expected_month = now.month + 1 - expected_year = now.year - - # Verify budget_reset_at, expires is set to first of next month - for key in ["budget_reset_at", "expires"]: - response_date = response.get(key) - assert response_date is not None, f"{key} not found in response" - assert ( - response_date.year == expected_year - ), f"Expected year {expected_year}, got {response_date.year} for {key}" - assert ( - response_date.month == expected_month - ), f"Expected month {expected_month}, got {response_date.month} for {key}" - assert ( - response_date.day == 1 - ), f"Expected day 1, got {response_date.day} for {key}" + assert budget_reset_at.month == test_start.month + 1 + assert budget_reset_at.year == test_start.year @pytest.mark.asyncio @@ -1040,7 +1027,7 @@ async def test_unblock_key_invalid_key_format(monkeypatch): def test_validate_key_team_change_with_member_permissions(): """ Test validate_key_team_change function with team member permissions. - + This test covers the new logic that allows team members with specific permissions to update keys, not just team admins. """ @@ -1054,111 +1041,107 @@ def test_validate_key_team_change_with_member_permissions(): mock_key.models = ["gpt-4"] mock_key.tpm_limit = None mock_key.rpm_limit = None - + mock_team = MagicMock() - mock_team.team_id = "test-team-456" + mock_team.team_id = "test-team-456" mock_team.members_with_roles = [] mock_team.tpm_limit = None mock_team.rpm_limit = None - + mock_change_initiator = MagicMock() mock_change_initiator.user_id = "test-user-123" - + mock_router = MagicMock() - + # Mock the member object returned by _get_user_in_team mock_member_object = MagicMock() - - with patch('litellm.proxy.management_endpoints.key_management_endpoints.can_team_access_model'): - with patch('litellm.proxy.management_endpoints.key_management_endpoints._get_user_in_team') as mock_get_user: - with patch('litellm.proxy.management_endpoints.key_management_endpoints._is_user_team_admin') as mock_is_admin: - with patch('litellm.proxy.management_endpoints.key_management_endpoints.TeamMemberPermissionChecks.does_team_member_have_permissions_for_endpoint') as mock_has_perms: - + + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.can_team_access_model" + ): + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints._get_user_in_team" + ) as mock_get_user: + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints._is_user_team_admin" + ) as mock_is_admin: + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.TeamMemberPermissionChecks.does_team_member_have_permissions_for_endpoint" + ) as mock_has_perms: + mock_get_user.return_value = mock_member_object mock_is_admin.return_value = False mock_has_perms.return_value = True - + # This should not raise an exception due to member permissions validate_key_team_change( key=mock_key, team=mock_team, change_initiated_by=mock_change_initiator, - llm_router=mock_router + llm_router=mock_router, ) - + # Verify the permission check was called with correct parameters mock_has_perms.assert_called_once_with( team_member_object=mock_member_object, team_table=mock_team, - route=KeyManagementRoutes.KEY_UPDATE.value + route=KeyManagementRoutes.KEY_UPDATE.value, ) def test_key_rotation_fields_helper(): """ Test the key data update logic for rotation fields. - + This test focuses on the core logic that adds rotation fields to key_data when auto_rotate is enabled, without the complexity of full key generation. """ # Test Case 1: With rotation enabled - key_data = { - "models": ["gpt-3.5-turbo"], - "user_id": "test-user" - } - + key_data = {"models": ["gpt-3.5-turbo"], "user_id": "test-user"} + auto_rotate = True rotation_interval = "30d" - + # Simulate the rotation logic from generate_key_helper_fn if auto_rotate and rotation_interval: - key_data.update({ - "auto_rotate": auto_rotate, - "rotation_interval": rotation_interval - }) - + key_data.update( + {"auto_rotate": auto_rotate, "rotation_interval": rotation_interval} + ) + # Verify rotation fields are added assert key_data["auto_rotate"] == True assert key_data["rotation_interval"] == "30d" assert key_data["models"] == ["gpt-3.5-turbo"] # Original fields preserved - + # Test Case 2: Without rotation enabled - key_data2 = { - "models": ["gpt-4"], - "user_id": "test-user" - } - + key_data2 = {"models": ["gpt-4"], "user_id": "test-user"} + auto_rotate2 = False rotation_interval2 = None - + # Simulate the rotation logic if auto_rotate2 and rotation_interval2: - key_data2.update({ - "auto_rotate": auto_rotate2, - "rotation_interval": rotation_interval2 - }) - + key_data2.update( + {"auto_rotate": auto_rotate2, "rotation_interval": rotation_interval2} + ) + # Verify rotation fields are NOT added assert "auto_rotate" not in key_data2 assert "rotation_interval" not in key_data2 assert key_data2["models"] == ["gpt-4"] # Original fields preserved - + # Test Case 3: auto_rotate=True but no interval - key_data3 = { - "models": ["claude-3"], - "user_id": "test-user" - } - + key_data3 = {"models": ["claude-3"], "user_id": "test-user"} + auto_rotate3 = True rotation_interval3 = None - + # Simulate the rotation logic if auto_rotate3 and rotation_interval3: - key_data3.update({ - "auto_rotate": auto_rotate3, - "rotation_interval": rotation_interval3 - }) - + key_data3.update( + {"auto_rotate": auto_rotate3, "rotation_interval": rotation_interval3} + ) + # Verify rotation fields are NOT added (missing interval) assert "auto_rotate" not in key_data3 assert "rotation_interval" not in key_data3 @@ -1181,27 +1164,24 @@ async def test_update_key_fn_auto_rotate_enable(): team_id=None, auto_rotate=False, rotation_interval=None, - metadata={} + metadata={}, ) - + # Test enabling auto rotation update_request = UpdateKeyRequest( - key="test-token", - auto_rotate=True, - rotation_interval="30d" + key="test-token", auto_rotate=True, rotation_interval="30d" ) - + result = await prepare_key_update_data( - data=update_request, - existing_key_row=existing_key + data=update_request, existing_key_row=existing_key ) - + # Verify rotation fields are included assert result["auto_rotate"] is True assert result["rotation_interval"] == "30d" -@pytest.mark.asyncio +@pytest.mark.asyncio async def test_update_key_fn_auto_rotate_disable(): """Test that update_key_fn properly handles disabling auto rotation.""" from litellm.proxy._types import LiteLLM_VerificationToken, UpdateKeyRequest @@ -1218,19 +1198,222 @@ async def test_update_key_fn_auto_rotate_disable(): team_id=None, auto_rotate=True, rotation_interval="30d", - metadata={} + metadata={}, ) - + # Test disabling auto rotation - update_request = UpdateKeyRequest( - key="test-token", - auto_rotate=False - ) - + update_request = UpdateKeyRequest(key="test-token", auto_rotate=False) + result = await prepare_key_update_data( - data=update_request, - existing_key_row=existing_key + data=update_request, existing_key_row=existing_key ) - + # Verify auto_rotate is set to False assert result["auto_rotate"] is False + + +@pytest.mark.asyncio +async def test_key_expiration_calculated_from_current_time(monkeypatch): + """ + Test that key expiration is calculated as duration from current time. + + For duration="1mo" on Oct 15th: + - expires: Should be Nov 15th (1 month from creation) + - budget_reset_at: Should be Nov 1st (standardized monthly reset) + """ + from datetime import datetime, timedelta, timezone + from unittest.mock import AsyncMock, MagicMock + + from litellm.litellm_core_utils.duration_parser import duration_in_seconds + from litellm.proxy.management_endpoints.key_management_endpoints import ( + generate_key_helper_fn, + ) + + # Set up mock prisma client + mock_prisma_client = AsyncMock() + mock_insert_data = AsyncMock( + return_value=MagicMock( + token="hashed_token_123", + litellm_budget_table=None, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + ) + mock_prisma_client.insert_data = mock_insert_data + mock_prisma_client.jsonify_object = lambda data: data + + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + monkeypatch.setattr("litellm.proxy.proxy_server.premium_user", False) + + test_start_time = datetime.now(timezone.utc) + + # Generate key with monthly duration + response = await generate_key_helper_fn( + request_type="user", + duration="1mo", + budget_duration="1mo", + user_id="test_user", + ) + + print(f"\nTest time: {test_start_time}") + print(f"expires: {response.get('expires')}") + print(f"budget_reset_at: {response.get('budget_reset_at')}") + + # Calculate expected values + duration_seconds = duration_in_seconds("1mo") + expected_expires = test_start_time + timedelta(seconds=duration_seconds) + + if test_start_time.month == 12: + expected_budget_reset = datetime( + test_start_time.year + 1, 1, 1, 0, 0, 0, tzinfo=timezone.utc + ) + else: + expected_budget_reset = datetime( + test_start_time.year, + test_start_time.month + 1, + 1, + 0, + 0, + 0, + tzinfo=timezone.utc, + ) + + # Verify budget_reset_at is standardized to 1st of next month + budget_reset_at = response.get("budget_reset_at") + assert budget_reset_at is not None + assert budget_reset_at.day == 1, "budget_reset_at should be 1st of month" + assert budget_reset_at.hour == 0 + + # Verify expires is calculated from current time + expires = response.get("expires") + assert expires is not None + + time_diff = abs((expires - expected_expires).total_seconds()) + assert ( + time_diff < 5 + ), f"expires should be 1 month from creation time. Expected: {expected_expires}, Got: {expires}" + + # expires and budget_reset_at should differ when test runs on non-1st day + if test_start_time.day != 1: + assert ( + expires != budget_reset_at + ), "expires and budget_reset_at should have different values" + assert ( + expires.day != 1 + ), f"expires should not be on 1st when created on day {test_start_time.day}" + + +@pytest.mark.asyncio +async def test_key_expiration_with_various_durations(monkeypatch): + """ + Test key expiration calculation for various duration units. + + Verify that expires is always calculated as current_time + duration. + """ + from datetime import datetime, timedelta, timezone + from unittest.mock import AsyncMock, MagicMock + + from litellm.proxy.management_endpoints.key_management_endpoints import ( + generate_key_helper_fn, + ) + + # Set up mock prisma client + mock_prisma_client = AsyncMock() + mock_insert_data = AsyncMock( + return_value=MagicMock( + token="hashed_token_123", + litellm_budget_table=None, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + ) + mock_prisma_client.insert_data = mock_insert_data + mock_prisma_client.jsonify_object = lambda data: data + + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + monkeypatch.setattr("litellm.proxy.proxy_server.premium_user", False) + + # Test cases: (duration_string, expected_seconds_from_now) + test_cases = [ + ("30s", 30), + ("5m", 300), + ("2h", 7200), + ("7d", 604800), + ] + + for duration_str, expected_seconds in test_cases: + test_start = datetime.now(timezone.utc) + + response = await generate_key_helper_fn( + request_type="key", + duration=duration_str, + user_id="test_user", + ) + + expires = response.get("expires") + assert expires is not None, f"expires should be set for duration={duration_str}" + + # Calculate expected expiration + expected_expires = test_start + timedelta(seconds=expected_seconds) + + # Verify within 2 seconds tolerance + time_diff = abs((expires - expected_expires).total_seconds()) + assert ( + time_diff < 2 + ), f"duration={duration_str}: Expected {expected_seconds}s from now, got diff of {time_diff}s" + + +@pytest.mark.asyncio +async def test_key_budget_reset_uses_standardized_time(monkeypatch): + """ + Test that budget resets are standardized to predictable intervals. + + For budget_duration="1mo", budget should reset on 1st of next month. + """ + from datetime import datetime, timezone + from unittest.mock import AsyncMock, MagicMock + + from litellm.proxy.management_endpoints.key_management_endpoints import ( + generate_key_helper_fn, + ) + + # Set up mock prisma client + mock_prisma_client = AsyncMock() + mock_insert_data = AsyncMock( + return_value=MagicMock( + token="hashed_token_123", + litellm_budget_table=None, + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + ) + mock_prisma_client.insert_data = mock_insert_data + mock_prisma_client.jsonify_object = lambda data: data + + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client) + monkeypatch.setattr("litellm.proxy.proxy_server.premium_user", False) + + test_start = datetime.now(timezone.utc) + + # Generate key with monthly budget duration + response = await generate_key_helper_fn( + request_type="user", + budget_duration="1mo", + user_id="test_user", + ) + + budget_reset_at = response.get("budget_reset_at") + assert budget_reset_at is not None + + # Verify standardized reset: 1st of next month at midnight + assert budget_reset_at.day == 1 + assert budget_reset_at.hour == 0 + assert budget_reset_at.minute == 0 + + # Verify it's next month + if test_start.month == 12: + assert budget_reset_at.month == 1 + assert budget_reset_at.year == test_start.year + 1 + else: + assert budget_reset_at.month == test_start.month + 1 + assert budget_reset_at.year == test_start.year