Skip to content

Commit 814e49d

Browse files
test: update tests
1 parent f500a86 commit 814e49d

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

tests/test_litellm/proxy/management_endpoints/test_key_management_endpoints.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -850,17 +850,24 @@ async def test_generate_service_account_key_endpoint_validation():
850850
)
851851

852852
# Test case 1: Missing team_id
853-
with pytest.raises(HTTPException) as exc_info:
854-
await generate_service_account_key_fn(
855-
data=GenerateKeyRequest(team_id=None),
856-
user_api_key_dict=UserAPIKeyAuth(
857-
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1"
858-
),
859-
litellm_changed_by=None,
860-
)
853+
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
854+
# Mock prisma_client to be not None so we can reach team_id validation
855+
mock_prisma_instance = AsyncMock()
856+
mock_prisma.return_value = mock_prisma_instance
861857

862-
assert exc_info.value.status_code == 400
863-
assert "team_id is required for service account keys" in str(exc_info.value.detail)
858+
with pytest.raises(HTTPException) as exc_info:
859+
await generate_service_account_key_fn(
860+
data=GenerateKeyRequest(team_id=None),
861+
user_api_key_dict=UserAPIKeyAuth(
862+
user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-1"
863+
),
864+
litellm_changed_by=None,
865+
)
866+
867+
assert exc_info.value.status_code == 400
868+
assert "team_id is required for service account keys" in str(
869+
exc_info.value.detail
870+
)
864871

865872
# Test case 2: Team doesn't exist in database
866873
with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma:
@@ -1257,6 +1264,8 @@ async def test_check_team_key_limits_no_existing_keys():
12571264
data = GenerateKeyRequest(
12581265
tpm_limit=5000,
12591266
rpm_limit=500,
1267+
tpm_limit_type="guaranteed_throughput",
1268+
rpm_limit_type="guaranteed_throughput",
12601269
)
12611270

12621271
# Should not raise any exception
@@ -1333,10 +1342,12 @@ async def test_check_team_key_limits_tpm_overallocation():
13331342
existing_key1 = MagicMock()
13341343
existing_key1.tpm_limit = 6000
13351344
existing_key1.rpm_limit = 100
1345+
existing_key1.metadata = {}
13361346

13371347
existing_key2 = MagicMock()
13381348
existing_key2.tpm_limit = 3000
13391349
existing_key2.rpm_limit = 200
1350+
existing_key2.metadata = {}
13401351

13411352
mock_prisma_client = AsyncMock()
13421353
mock_prisma_client.db.litellm_verificationtoken.find_many = AsyncMock(
@@ -1360,6 +1371,7 @@ async def test_check_team_key_limits_tpm_overallocation():
13601371
data = GenerateKeyRequest(
13611372
tpm_limit=2000,
13621373
rpm_limit=100,
1374+
tpm_limit_type="guaranteed_throughput",
13631375
)
13641376

13651377
# Should raise HTTPException for TPM overallocation
@@ -1387,10 +1399,12 @@ async def test_check_team_key_limits_rpm_overallocation():
13871399
existing_key1 = MagicMock()
13881400
existing_key1.tpm_limit = 1000
13891401
existing_key1.rpm_limit = 600
1402+
existing_key1.metadata = {}
13901403

13911404
existing_key2 = MagicMock()
13921405
existing_key2.tpm_limit = 2000
13931406
existing_key2.rpm_limit = 300
1407+
existing_key2.metadata = {}
13941408

13951409
mock_prisma_client = AsyncMock()
13961410
mock_prisma_client.db.litellm_verificationtoken.find_many = AsyncMock(
@@ -1414,6 +1428,7 @@ async def test_check_team_key_limits_rpm_overallocation():
14141428
data = GenerateKeyRequest(
14151429
tpm_limit=1000,
14161430
rpm_limit=200,
1431+
rpm_limit_type="guaranteed_throughput",
14171432
)
14181433

14191434
# Should raise HTTPException for RPM overallocation

0 commit comments

Comments
 (0)