@@ -183,7 +183,9 @@ async def test_budget_reset_and_expires_at_first_of_month(monkeypatch):
183
183
assert (
184
184
response_date .month == expected_month
185
185
), f"Expected month { expected_month } , got { response_date .month } for { key } "
186
- assert response_date .day == 1 , f"Expected day 1, got { response_date .day } for { key } "
186
+ assert (
187
+ response_date .day == 1
188
+ ), f"Expected day 1, got { response_date .day } for { key } "
187
189
188
190
189
191
@pytest .mark .asyncio
@@ -507,7 +509,6 @@ def test_get_new_token_with_invalid_key():
507
509
assert "New key must start with 'sk-'" in str (exc_info .value .detail )
508
510
509
511
510
-
511
512
@pytest .mark .asyncio
512
513
async def test_generate_service_account_requires_team_id ():
513
514
with pytest .raises (HTTPException ):
@@ -529,11 +530,12 @@ async def test_generate_service_account_works_with_team_id():
529
530
from unittest .mock import patch
530
531
531
532
# Mock the database and router dependencies from proxy_server
532
- with patch ('litellm.proxy.proxy_server.prisma_client' ) as mock_prisma , \
533
- patch ('litellm.proxy.proxy_server.llm_router' ) as mock_router , \
534
- patch ('litellm.proxy.proxy_server.premium_user' , False ), \
535
- patch ('litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn' ) as mock_generate_key :
536
-
533
+ with patch ("litellm.proxy.proxy_server.prisma_client" ) as mock_prisma , patch (
534
+ "litellm.proxy.proxy_server.llm_router"
535
+ ) as mock_router , patch ("litellm.proxy.proxy_server.premium_user" , False ), patch (
536
+ "litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn"
537
+ ) as mock_generate_key :
538
+
537
539
# Configure mocks
538
540
mock_prisma .return_value = AsyncMock ()
539
541
mock_router .return_value = None
@@ -542,9 +544,9 @@ async def test_generate_service_account_works_with_team_id():
542
544
"key" : "sk-test-key" ,
543
545
"expires" : None ,
544
546
"user_id" : "test-user" ,
545
- "team_id" : "IJ"
547
+ "team_id" : "IJ" ,
546
548
}
547
-
549
+
548
550
# This should not raise an exception since team_id is provided
549
551
await _common_key_generation_helper (
550
552
data = GenerateKeyRequest (
@@ -559,7 +561,6 @@ async def test_generate_service_account_works_with_team_id():
559
561
)
560
562
561
563
562
-
563
564
@pytest .mark .asyncio
564
565
async def test_update_service_account_requires_team_id ():
565
566
data = UpdateKeyRequest (key = "sk-1" , metadata = {"service_account_id" : "sa" })
@@ -571,7 +572,9 @@ async def test_update_service_account_requires_team_id():
571
572
572
573
@pytest .mark .asyncio
573
574
async def test_update_service_account_works_with_team_id ():
574
- data = UpdateKeyRequest (key = "sk-1" , metadata = {"service_account_id" : "sa" }, team_id = "IJ" )
575
+ data = UpdateKeyRequest (
576
+ key = "sk-1" , metadata = {"service_account_id" : "sa" }, team_id = "IJ"
577
+ )
575
578
existing_key = LiteLLM_VerificationToken (token = "hashed" )
576
579
577
580
await prepare_key_update_data (data = data , existing_key_row = existing_key )
@@ -580,30 +583,30 @@ async def test_update_service_account_works_with_team_id():
580
583
@pytest .mark .asyncio
581
584
async def test_validate_team_id_used_in_service_account_request_requires_team_id ():
582
585
"""
583
- Test that validate_team_id_used_in_service_account_request raises HTTPException
586
+ Test that validate_team_id_used_in_service_account_request raises HTTPException
584
587
when team_id is None for service account key generation.
585
588
"""
586
589
from litellm .proxy .management_endpoints .key_management_endpoints import (
587
590
validate_team_id_used_in_service_account_request ,
588
591
)
589
-
592
+
590
593
mock_prisma_client = AsyncMock ()
591
-
594
+
592
595
# Test that HTTPException is raised when team_id is None
593
596
with pytest .raises (HTTPException ) as exc_info :
594
597
await validate_team_id_used_in_service_account_request (
595
598
team_id = None ,
596
599
prisma_client = mock_prisma_client ,
597
600
)
598
-
601
+
599
602
assert exc_info .value .status_code == 400
600
603
assert "team_id is required for service account keys" in str (exc_info .value .detail )
601
604
602
605
603
606
@pytest .mark .asyncio
604
607
async def test_validate_team_id_used_in_service_account_request_requires_prisma_client ():
605
608
"""
606
- Test that validate_team_id_used_in_service_account_request raises HTTPException
609
+ Test that validate_team_id_used_in_service_account_request raises HTTPException
607
610
when prisma_client is None for service account key generation.
608
611
"""
609
612
from litellm .proxy .management_endpoints .key_management_endpoints import (
@@ -616,78 +619,76 @@ async def test_validate_team_id_used_in_service_account_request_requires_prisma_
616
619
team_id = "test-team-id" ,
617
620
prisma_client = None ,
618
621
)
619
-
622
+
620
623
assert exc_info .value .status_code == 400
621
- assert "prisma_client is required for service account keys" in str (exc_info .value .detail )
624
+ assert "prisma_client is required for service account keys" in str (
625
+ exc_info .value .detail
626
+ )
622
627
623
628
624
629
@pytest .mark .asyncio
625
630
async def test_validate_team_id_used_in_service_account_request_checks_team_exists ():
626
631
"""
627
- Test that validate_team_id_used_in_service_account_request validates that
632
+ Test that validate_team_id_used_in_service_account_request validates that
628
633
the team_id exists in the database for service account key generation.
629
634
"""
630
635
from litellm .proxy .management_endpoints .key_management_endpoints import (
631
636
validate_team_id_used_in_service_account_request ,
632
637
)
633
-
638
+
634
639
mock_prisma_client = AsyncMock ()
635
-
640
+
636
641
# Mock the database query to return None (team doesn't exist)
637
642
mock_find_unique = AsyncMock (return_value = None )
638
643
mock_prisma_client .db .litellm_teamtable .find_unique = mock_find_unique
639
-
644
+
640
645
# Test that HTTPException is raised when team doesn't exist in DB
641
646
with pytest .raises (HTTPException ) as exc_info :
642
647
await validate_team_id_used_in_service_account_request (
643
648
team_id = "non-existent-team-id" ,
644
649
prisma_client = mock_prisma_client ,
645
650
)
646
-
651
+
647
652
assert exc_info .value .status_code == 400
648
653
assert "team_id does not exist in the database" in str (exc_info .value .detail )
649
-
654
+
650
655
# Verify the database was queried with the correct parameters
651
- mock_find_unique .assert_called_once_with (
652
- where = {"team_id" : "non-existent-team-id" }
653
- )
656
+ mock_find_unique .assert_called_once_with (where = {"team_id" : "non-existent-team-id" })
654
657
655
658
656
659
@pytest .mark .asyncio
657
660
async def test_validate_team_id_used_in_service_account_request_success ():
658
661
"""
659
- Test that validate_team_id_used_in_service_account_request returns True
662
+ Test that validate_team_id_used_in_service_account_request returns True
660
663
when team_id exists in the database for service account key generation.
661
664
"""
662
665
from litellm .proxy .management_endpoints .key_management_endpoints import (
663
666
validate_team_id_used_in_service_account_request ,
664
667
)
665
-
668
+
666
669
mock_prisma_client = AsyncMock ()
667
-
670
+
668
671
# Mock the database query to return a team object (team exists)
669
672
mock_team = {"team_id" : "existing-team-id" , "team_name" : "Test Team" }
670
673
mock_find_unique = AsyncMock (return_value = mock_team )
671
674
mock_prisma_client .db .litellm_teamtable .find_unique = mock_find_unique
672
-
675
+
673
676
# Test that function returns True when team exists
674
677
result = await validate_team_id_used_in_service_account_request (
675
678
team_id = "existing-team-id" ,
676
679
prisma_client = mock_prisma_client ,
677
680
)
678
-
681
+
679
682
assert result is True
680
-
683
+
681
684
# Verify the database was queried with the correct parameters
682
- mock_find_unique .assert_called_once_with (
683
- where = {"team_id" : "existing-team-id" }
684
- )
685
+ mock_find_unique .assert_called_once_with (where = {"team_id" : "existing-team-id" })
685
686
686
687
687
688
@pytest .mark .asyncio
688
689
async def test_generate_service_account_key_endpoint_validation ():
689
690
"""
690
- Test that the /key/service-account/generate endpoint properly validates
691
+ Test that the /key/service-account/generate endpoint properly validates
691
692
team_id requirement and team existence in database.
692
693
"""
693
694
from unittest .mock import patch
@@ -705,16 +706,16 @@ async def test_generate_service_account_key_endpoint_validation():
705
706
),
706
707
litellm_changed_by = None ,
707
708
)
708
-
709
+
709
710
assert exc_info .value .status_code == 400
710
711
assert "team_id is required for service account keys" in str (exc_info .value .detail )
711
-
712
- # Test case 2: Team doesn't exist in database
713
- with patch (' litellm.proxy.proxy_server.prisma_client' ) as mock_prisma :
712
+
713
+ # Test case 2: Team doesn't exist in database
714
+ with patch (" litellm.proxy.proxy_server.prisma_client" ) as mock_prisma :
714
715
# Mock team not found
715
716
mock_find_unique = AsyncMock (return_value = None )
716
717
mock_prisma .db .litellm_teamtable .find_unique = mock_find_unique
717
-
718
+
718
719
with pytest .raises (HTTPException ) as exc_info :
719
720
await generate_service_account_key_fn (
720
721
data = GenerateKeyRequest (team_id = "non-existent-team" ),
@@ -723,7 +724,165 @@ async def test_generate_service_account_key_endpoint_validation():
723
724
),
724
725
litellm_changed_by = None ,
725
726
)
726
-
727
+
727
728
assert exc_info .value .status_code == 400
728
729
assert "team_id does not exist in the database" in str (exc_info .value .detail )
729
730
731
+
732
+ @pytest .mark .asyncio
733
+ async def test_unblock_key_supports_both_sk_and_hashed_tokens (monkeypatch ):
734
+ """
735
+ Test that the unblock_key endpoint correctly handles both sk- prefixed tokens
736
+ and hashed tokens by properly converting sk- tokens to hashed format before
737
+ database operations.
738
+ """
739
+ from unittest .mock import AsyncMock , MagicMock
740
+
741
+ from litellm .proxy ._types import BlockKeyRequest
742
+ from litellm .proxy .management_endpoints .key_management_endpoints import unblock_key
743
+
744
+ # Mock dependencies
745
+ mock_prisma_client = AsyncMock ()
746
+ mock_user_api_key_cache = MagicMock ()
747
+ mock_proxy_logging_obj = MagicMock ()
748
+
749
+ # Use a proper 64-character hex hash for testing
750
+ test_hashed_token = (
751
+ "a1b2c3d4e5f6789012345678901234567890123456789012345678901234abcd"
752
+ )
753
+
754
+ # Mock the key record that will be returned from database
755
+ mock_key_record = MagicMock ()
756
+ mock_key_record .token = test_hashed_token
757
+ mock_key_record .blocked = False
758
+ mock_key_record .model_dump_json .return_value = (
759
+ f'{{"token": "{ test_hashed_token } ", "blocked": false}}'
760
+ )
761
+
762
+ # Mock database operations
763
+ mock_prisma_client .db .litellm_verificationtoken .find_unique = AsyncMock (
764
+ return_value = mock_key_record
765
+ )
766
+ mock_prisma_client .db .litellm_verificationtoken .update = AsyncMock (
767
+ return_value = mock_key_record
768
+ )
769
+
770
+ # Mock get_key_object and _cache_key_object functions
771
+ mock_key_object = MagicMock ()
772
+ mock_key_object .blocked = True # Initially blocked
773
+
774
+ # Mock hash_token function
775
+ def mock_hash_token (token ):
776
+ if token == "sk-test123456789" :
777
+ return test_hashed_token
778
+ return token
779
+
780
+ # Apply monkeypatch
781
+ monkeypatch .setattr ("litellm.proxy.proxy_server.prisma_client" , mock_prisma_client )
782
+ monkeypatch .setattr (
783
+ "litellm.proxy.proxy_server.user_api_key_cache" , mock_user_api_key_cache
784
+ )
785
+ monkeypatch .setattr (
786
+ "litellm.proxy.proxy_server.proxy_logging_obj" , mock_proxy_logging_obj
787
+ )
788
+ monkeypatch .setattr ("litellm.proxy.proxy_server.hash_token" , mock_hash_token )
789
+ monkeypatch .setattr (
790
+ "litellm.store_audit_logs" , False
791
+ ) # Disable audit logs for simpler test
792
+
793
+ # Mock get_key_object and _cache_key_object
794
+ async def mock_get_key_object (** kwargs ):
795
+ return mock_key_object
796
+
797
+ async def mock_cache_key_object (** kwargs ):
798
+ pass
799
+
800
+ monkeypatch .setattr (
801
+ "litellm.proxy.management_endpoints.key_management_endpoints.get_key_object" ,
802
+ mock_get_key_object ,
803
+ )
804
+ monkeypatch .setattr (
805
+ "litellm.proxy.management_endpoints.key_management_endpoints._cache_key_object" ,
806
+ mock_cache_key_object ,
807
+ )
808
+
809
+ # Create mock request and user auth
810
+ mock_request = MagicMock ()
811
+ user_api_key_dict = UserAPIKeyAuth (
812
+ user_role = LitellmUserRoles .PROXY_ADMIN , api_key = "sk-admin" , user_id = "admin_user"
813
+ )
814
+
815
+ # Test Case 1: Using sk- prefixed token
816
+ sk_token_request = BlockKeyRequest (key = "sk-test123456789" )
817
+
818
+ result = await unblock_key (
819
+ data = sk_token_request ,
820
+ http_request = mock_request ,
821
+ user_api_key_dict = user_api_key_dict ,
822
+ litellm_changed_by = None ,
823
+ )
824
+
825
+ # Verify that the database update was called with hashed token
826
+ mock_prisma_client .db .litellm_verificationtoken .update .assert_called_with (
827
+ where = {"token" : test_hashed_token }, data = {"blocked" : False }
828
+ )
829
+
830
+ assert result == mock_key_record
831
+ assert mock_key_object .blocked == False # Should be updated to unblocked
832
+
833
+ # Reset mocks for second test
834
+ mock_prisma_client .db .litellm_verificationtoken .update .reset_mock ()
835
+ mock_key_object .blocked = True # Reset to blocked state
836
+
837
+ # Test Case 2: Using already hashed token
838
+ hashed_token_request = BlockKeyRequest (key = test_hashed_token )
839
+
840
+ result = await unblock_key (
841
+ data = hashed_token_request ,
842
+ http_request = mock_request ,
843
+ user_api_key_dict = user_api_key_dict ,
844
+ litellm_changed_by = None ,
845
+ )
846
+
847
+ # Verify that the database update was called with the same hashed token
848
+ mock_prisma_client .db .litellm_verificationtoken .update .assert_called_with (
849
+ where = {"token" : test_hashed_token }, data = {"blocked" : False }
850
+ )
851
+
852
+ assert result == mock_key_record
853
+ assert mock_key_object .blocked == False # Should be updated to unblocked
854
+
855
+
856
+ @pytest .mark .asyncio
857
+ async def test_unblock_key_invalid_key_format (monkeypatch ):
858
+ """
859
+ Test that unblock_key properly validates key format and raises appropriate errors
860
+ for invalid keys.
861
+ """
862
+ from litellm .proxy ._types import BlockKeyRequest
863
+ from litellm .proxy .management_endpoints .key_management_endpoints import unblock_key
864
+ from litellm .proxy .utils import ProxyException
865
+
866
+ # Mock prisma_client to avoid DB connection error
867
+ mock_prisma_client = AsyncMock ()
868
+ monkeypatch .setattr ("litellm.proxy.proxy_server.prisma_client" , mock_prisma_client )
869
+
870
+ # Mock request and user auth
871
+ mock_request = MagicMock ()
872
+ user_api_key_dict = UserAPIKeyAuth (
873
+ user_role = LitellmUserRoles .PROXY_ADMIN , api_key = "sk-admin" , user_id = "admin_user"
874
+ )
875
+
876
+ # Test with invalid key format
877
+ invalid_key_request = BlockKeyRequest (key = "invalid-key-format" )
878
+
879
+ with pytest .raises (ProxyException ) as exc_info :
880
+ await unblock_key (
881
+ data = invalid_key_request ,
882
+ http_request = mock_request ,
883
+ user_api_key_dict = user_api_key_dict ,
884
+ litellm_changed_by = None ,
885
+ )
886
+
887
+ assert exc_info .value .code == "400"
888
+ assert "Invalid key format" in str (exc_info .value .message )
0 commit comments