Skip to content

Commit 0c8b311

Browse files
test: add unit testing for both flows on key unblock
1 parent 0f6898a commit 0c8b311

File tree

1 file changed

+202
-43
lines changed

1 file changed

+202
-43
lines changed

tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py

Lines changed: 202 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,9 @@ async def test_budget_reset_and_expires_at_first_of_month(monkeypatch):
183183
assert (
184184
response_date.month == expected_month
185185
), f"Expected month {expected_month}, got {response_date.month} for {key}"
186-
assert response_date.day == 1, f"Expected day 1, got {response_date.day} for {key}"
186+
assert (
187+
response_date.day == 1
188+
), f"Expected day 1, got {response_date.day} for {key}"
187189

188190

189191
@pytest.mark.asyncio
@@ -507,7 +509,6 @@ def test_get_new_token_with_invalid_key():
507509
assert "New key must start with 'sk-'" in str(exc_info.value.detail)
508510

509511

510-
511512
@pytest.mark.asyncio
512513
async def test_generate_service_account_requires_team_id():
513514
with pytest.raises(HTTPException):
@@ -529,11 +530,12 @@ async def test_generate_service_account_works_with_team_id():
529530
from unittest.mock import patch
530531

531532
# Mock the database and router dependencies from proxy_server
532-
with patch('litellm.proxy.proxy_server.prisma_client') as mock_prisma, \
533-
patch('litellm.proxy.proxy_server.llm_router') as mock_router, \
534-
patch('litellm.proxy.proxy_server.premium_user', False), \
535-
patch('litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn') as mock_generate_key:
536-
533+
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma, patch(
534+
"litellm.proxy.proxy_server.llm_router"
535+
) as mock_router, patch("litellm.proxy.proxy_server.premium_user", False), patch(
536+
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn"
537+
) as mock_generate_key:
538+
537539
# Configure mocks
538540
mock_prisma.return_value = AsyncMock()
539541
mock_router.return_value = None
@@ -542,9 +544,9 @@ async def test_generate_service_account_works_with_team_id():
542544
"key": "sk-test-key",
543545
"expires": None,
544546
"user_id": "test-user",
545-
"team_id": "IJ"
547+
"team_id": "IJ",
546548
}
547-
549+
548550
# This should not raise an exception since team_id is provided
549551
await _common_key_generation_helper(
550552
data=GenerateKeyRequest(
@@ -559,7 +561,6 @@ async def test_generate_service_account_works_with_team_id():
559561
)
560562

561563

562-
563564
@pytest.mark.asyncio
564565
async def test_update_service_account_requires_team_id():
565566
data = UpdateKeyRequest(key="sk-1", metadata={"service_account_id": "sa"})
@@ -571,7 +572,9 @@ async def test_update_service_account_requires_team_id():
571572

572573
@pytest.mark.asyncio
573574
async def test_update_service_account_works_with_team_id():
574-
data = UpdateKeyRequest(key="sk-1", metadata={"service_account_id": "sa"}, team_id="IJ")
575+
data = UpdateKeyRequest(
576+
key="sk-1", metadata={"service_account_id": "sa"}, team_id="IJ"
577+
)
575578
existing_key = LiteLLM_VerificationToken(token="hashed")
576579

577580
await prepare_key_update_data(data=data, existing_key_row=existing_key)
@@ -580,30 +583,30 @@ async def test_update_service_account_works_with_team_id():
580583
@pytest.mark.asyncio
581584
async def test_validate_team_id_used_in_service_account_request_requires_team_id():
582585
"""
583-
Test that validate_team_id_used_in_service_account_request raises HTTPException
586+
Test that validate_team_id_used_in_service_account_request raises HTTPException
584587
when team_id is None for service account key generation.
585588
"""
586589
from litellm.proxy.management_endpoints.key_management_endpoints import (
587590
validate_team_id_used_in_service_account_request,
588591
)
589-
592+
590593
mock_prisma_client = AsyncMock()
591-
594+
592595
# Test that HTTPException is raised when team_id is None
593596
with pytest.raises(HTTPException) as exc_info:
594597
await validate_team_id_used_in_service_account_request(
595598
team_id=None,
596599
prisma_client=mock_prisma_client,
597600
)
598-
601+
599602
assert exc_info.value.status_code == 400
600603
assert "team_id is required for service account keys" in str(exc_info.value.detail)
601604

602605

603606
@pytest.mark.asyncio
604607
async def test_validate_team_id_used_in_service_account_request_requires_prisma_client():
605608
"""
606-
Test that validate_team_id_used_in_service_account_request raises HTTPException
609+
Test that validate_team_id_used_in_service_account_request raises HTTPException
607610
when prisma_client is None for service account key generation.
608611
"""
609612
from litellm.proxy.management_endpoints.key_management_endpoints import (
@@ -616,78 +619,76 @@ async def test_validate_team_id_used_in_service_account_request_requires_prisma_
616619
team_id="test-team-id",
617620
prisma_client=None,
618621
)
619-
622+
620623
assert exc_info.value.status_code == 400
621-
assert "prisma_client is required for service account keys" in str(exc_info.value.detail)
624+
assert "prisma_client is required for service account keys" in str(
625+
exc_info.value.detail
626+
)
622627

623628

624629
@pytest.mark.asyncio
625630
async def test_validate_team_id_used_in_service_account_request_checks_team_exists():
626631
"""
627-
Test that validate_team_id_used_in_service_account_request validates that
632+
Test that validate_team_id_used_in_service_account_request validates that
628633
the team_id exists in the database for service account key generation.
629634
"""
630635
from litellm.proxy.management_endpoints.key_management_endpoints import (
631636
validate_team_id_used_in_service_account_request,
632637
)
633-
638+
634639
mock_prisma_client = AsyncMock()
635-
640+
636641
# Mock the database query to return None (team doesn't exist)
637642
mock_find_unique = AsyncMock(return_value=None)
638643
mock_prisma_client.db.litellm_teamtable.find_unique = mock_find_unique
639-
644+
640645
# Test that HTTPException is raised when team doesn't exist in DB
641646
with pytest.raises(HTTPException) as exc_info:
642647
await validate_team_id_used_in_service_account_request(
643648
team_id="non-existent-team-id",
644649
prisma_client=mock_prisma_client,
645650
)
646-
651+
647652
assert exc_info.value.status_code == 400
648653
assert "team_id does not exist in the database" in str(exc_info.value.detail)
649-
654+
650655
# Verify the database was queried with the correct parameters
651-
mock_find_unique.assert_called_once_with(
652-
where={"team_id": "non-existent-team-id"}
653-
)
656+
mock_find_unique.assert_called_once_with(where={"team_id": "non-existent-team-id"})
654657

655658

656659
@pytest.mark.asyncio
657660
async def test_validate_team_id_used_in_service_account_request_success():
658661
"""
659-
Test that validate_team_id_used_in_service_account_request returns True
662+
Test that validate_team_id_used_in_service_account_request returns True
660663
when team_id exists in the database for service account key generation.
661664
"""
662665
from litellm.proxy.management_endpoints.key_management_endpoints import (
663666
validate_team_id_used_in_service_account_request,
664667
)
665-
668+
666669
mock_prisma_client = AsyncMock()
667-
670+
668671
# Mock the database query to return a team object (team exists)
669672
mock_team = {"team_id": "existing-team-id", "team_name": "Test Team"}
670673
mock_find_unique = AsyncMock(return_value=mock_team)
671674
mock_prisma_client.db.litellm_teamtable.find_unique = mock_find_unique
672-
675+
673676
# Test that function returns True when team exists
674677
result = await validate_team_id_used_in_service_account_request(
675678
team_id="existing-team-id",
676679
prisma_client=mock_prisma_client,
677680
)
678-
681+
679682
assert result is True
680-
683+
681684
# Verify the database was queried with the correct parameters
682-
mock_find_unique.assert_called_once_with(
683-
where={"team_id": "existing-team-id"}
684-
)
685+
mock_find_unique.assert_called_once_with(where={"team_id": "existing-team-id"})
685686

686687

687688
@pytest.mark.asyncio
688689
async def test_generate_service_account_key_endpoint_validation():
689690
"""
690-
Test that the /key/service-account/generate endpoint properly validates
691+
Test that the /key/service-account/generate endpoint properly validates
691692
team_id requirement and team existence in database.
692693
"""
693694
from unittest.mock import patch
@@ -705,16 +706,16 @@ async def test_generate_service_account_key_endpoint_validation():
705706
),
706707
litellm_changed_by=None,
707708
)
708-
709+
709710
assert exc_info.value.status_code == 400
710711
assert "team_id is required for service account keys" in str(exc_info.value.detail)
711-
712-
# Test case 2: Team doesn't exist in database
713-
with patch('litellm.proxy.proxy_server.prisma_client') as mock_prisma:
712+
713+
# Test case 2: Team doesn't exist in database
714+
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
714715
# Mock team not found
715716
mock_find_unique = AsyncMock(return_value=None)
716717
mock_prisma.db.litellm_teamtable.find_unique = mock_find_unique
717-
718+
718719
with pytest.raises(HTTPException) as exc_info:
719720
await generate_service_account_key_fn(
720721
data=GenerateKeyRequest(team_id="non-existent-team"),
@@ -723,7 +724,165 @@ async def test_generate_service_account_key_endpoint_validation():
723724
),
724725
litellm_changed_by=None,
725726
)
726-
727+
727728
assert exc_info.value.status_code == 400
728729
assert "team_id does not exist in the database" in str(exc_info.value.detail)
729730

731+
732+
@pytest.mark.asyncio
733+
async def test_unblock_key_supports_both_sk_and_hashed_tokens(monkeypatch):
734+
"""
735+
Test that the unblock_key endpoint correctly handles both sk- prefixed tokens
736+
and hashed tokens by properly converting sk- tokens to hashed format before
737+
database operations.
738+
"""
739+
from unittest.mock import AsyncMock, MagicMock
740+
741+
from litellm.proxy._types import BlockKeyRequest
742+
from litellm.proxy.management_endpoints.key_management_endpoints import unblock_key
743+
744+
# Mock dependencies
745+
mock_prisma_client = AsyncMock()
746+
mock_user_api_key_cache = MagicMock()
747+
mock_proxy_logging_obj = MagicMock()
748+
749+
# Use a proper 64-character hex hash for testing
750+
test_hashed_token = (
751+
"a1b2c3d4e5f6789012345678901234567890123456789012345678901234abcd"
752+
)
753+
754+
# Mock the key record that will be returned from database
755+
mock_key_record = MagicMock()
756+
mock_key_record.token = test_hashed_token
757+
mock_key_record.blocked = False
758+
mock_key_record.model_dump_json.return_value = (
759+
f'{{"token": "{test_hashed_token}", "blocked": false}}'
760+
)
761+
762+
# Mock database operations
763+
mock_prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock(
764+
return_value=mock_key_record
765+
)
766+
mock_prisma_client.db.litellm_verificationtoken.update = AsyncMock(
767+
return_value=mock_key_record
768+
)
769+
770+
# Mock get_key_object and _cache_key_object functions
771+
mock_key_object = MagicMock()
772+
mock_key_object.blocked = True # Initially blocked
773+
774+
# Mock hash_token function
775+
def mock_hash_token(token):
776+
if token == "sk-test123456789":
777+
return test_hashed_token
778+
return token
779+
780+
# Apply monkeypatch
781+
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
782+
monkeypatch.setattr(
783+
"litellm.proxy.proxy_server.user_api_key_cache", mock_user_api_key_cache
784+
)
785+
monkeypatch.setattr(
786+
"litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging_obj
787+
)
788+
monkeypatch.setattr("litellm.proxy.proxy_server.hash_token", mock_hash_token)
789+
monkeypatch.setattr(
790+
"litellm.store_audit_logs", False
791+
) # Disable audit logs for simpler test
792+
793+
# Mock get_key_object and _cache_key_object
794+
async def mock_get_key_object(**kwargs):
795+
return mock_key_object
796+
797+
async def mock_cache_key_object(**kwargs):
798+
pass
799+
800+
monkeypatch.setattr(
801+
"litellm.proxy.management_endpoints.key_management_endpoints.get_key_object",
802+
mock_get_key_object,
803+
)
804+
monkeypatch.setattr(
805+
"litellm.proxy.management_endpoints.key_management_endpoints._cache_key_object",
806+
mock_cache_key_object,
807+
)
808+
809+
# Create mock request and user auth
810+
mock_request = MagicMock()
811+
user_api_key_dict = UserAPIKeyAuth(
812+
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-admin", user_id="admin_user"
813+
)
814+
815+
# Test Case 1: Using sk- prefixed token
816+
sk_token_request = BlockKeyRequest(key="sk-test123456789")
817+
818+
result = await unblock_key(
819+
data=sk_token_request,
820+
http_request=mock_request,
821+
user_api_key_dict=user_api_key_dict,
822+
litellm_changed_by=None,
823+
)
824+
825+
# Verify that the database update was called with hashed token
826+
mock_prisma_client.db.litellm_verificationtoken.update.assert_called_with(
827+
where={"token": test_hashed_token}, data={"blocked": False}
828+
)
829+
830+
assert result == mock_key_record
831+
assert mock_key_object.blocked == False # Should be updated to unblocked
832+
833+
# Reset mocks for second test
834+
mock_prisma_client.db.litellm_verificationtoken.update.reset_mock()
835+
mock_key_object.blocked = True # Reset to blocked state
836+
837+
# Test Case 2: Using already hashed token
838+
hashed_token_request = BlockKeyRequest(key=test_hashed_token)
839+
840+
result = await unblock_key(
841+
data=hashed_token_request,
842+
http_request=mock_request,
843+
user_api_key_dict=user_api_key_dict,
844+
litellm_changed_by=None,
845+
)
846+
847+
# Verify that the database update was called with the same hashed token
848+
mock_prisma_client.db.litellm_verificationtoken.update.assert_called_with(
849+
where={"token": test_hashed_token}, data={"blocked": False}
850+
)
851+
852+
assert result == mock_key_record
853+
assert mock_key_object.blocked == False # Should be updated to unblocked
854+
855+
856+
@pytest.mark.asyncio
857+
async def test_unblock_key_invalid_key_format(monkeypatch):
858+
"""
859+
Test that unblock_key properly validates key format and raises appropriate errors
860+
for invalid keys.
861+
"""
862+
from litellm.proxy._types import BlockKeyRequest
863+
from litellm.proxy.management_endpoints.key_management_endpoints import unblock_key
864+
from litellm.proxy.utils import ProxyException
865+
866+
# Mock prisma_client to avoid DB connection error
867+
mock_prisma_client = AsyncMock()
868+
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
869+
870+
# Mock request and user auth
871+
mock_request = MagicMock()
872+
user_api_key_dict = UserAPIKeyAuth(
873+
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-admin", user_id="admin_user"
874+
)
875+
876+
# Test with invalid key format
877+
invalid_key_request = BlockKeyRequest(key="invalid-key-format")
878+
879+
with pytest.raises(ProxyException) as exc_info:
880+
await unblock_key(
881+
data=invalid_key_request,
882+
http_request=mock_request,
883+
user_api_key_dict=user_api_key_dict,
884+
litellm_changed_by=None,
885+
)
886+
887+
assert exc_info.value.code == "400"
888+
assert "Invalid key format" in str(exc_info.value.message)

0 commit comments

Comments
 (0)