@@ -850,17 +850,24 @@ async def test_generate_service_account_key_endpoint_validation():
850
850
)
851
851
852
852
# Test case 1: Missing team_id
853
- with pytest .raises (HTTPException ) as exc_info :
854
- await generate_service_account_key_fn (
855
- data = GenerateKeyRequest (team_id = None ),
856
- user_api_key_dict = UserAPIKeyAuth (
857
- user_role = LitellmUserRoles .PROXY_ADMIN , api_key = "sk-1"
858
- ),
859
- litellm_changed_by = None ,
860
- )
853
+ with patch ("litellm.proxy.proxy_server.prisma_client" ) as mock_prisma :
854
+ # Mock prisma_client to be not None so we can reach team_id validation
855
+ mock_prisma_instance = AsyncMock ()
856
+ mock_prisma .return_value = mock_prisma_instance
861
857
862
- assert exc_info .value .status_code == 400
863
- assert "team_id is required for service account keys" in str (exc_info .value .detail )
858
+ with pytest .raises (HTTPException ) as exc_info :
859
+ await generate_service_account_key_fn (
860
+ data = GenerateKeyRequest (team_id = None ),
861
+ user_api_key_dict = UserAPIKeyAuth (
862
+ user_role = LitellmUserRoles .PROXY_ADMIN , api_key = "sk-1"
863
+ ),
864
+ litellm_changed_by = None ,
865
+ )
866
+
867
+ assert exc_info .value .status_code == 400
868
+ assert "team_id is required for service account keys" in str (
869
+ exc_info .value .detail
870
+ )
864
871
865
872
# Test case 2: Team doesn't exist in database
866
873
with patch ("litellm.proxy.proxy_server.prisma_client" ) as mock_prisma :
@@ -1257,6 +1264,8 @@ async def test_check_team_key_limits_no_existing_keys():
1257
1264
data = GenerateKeyRequest (
1258
1265
tpm_limit = 5000 ,
1259
1266
rpm_limit = 500 ,
1267
+ tpm_limit_type = "guaranteed_throughput" ,
1268
+ rpm_limit_type = "guaranteed_throughput" ,
1260
1269
)
1261
1270
1262
1271
# Should not raise any exception
@@ -1333,10 +1342,12 @@ async def test_check_team_key_limits_tpm_overallocation():
1333
1342
existing_key1 = MagicMock ()
1334
1343
existing_key1 .tpm_limit = 6000
1335
1344
existing_key1 .rpm_limit = 100
1345
+ existing_key1 .metadata = {}
1336
1346
1337
1347
existing_key2 = MagicMock ()
1338
1348
existing_key2 .tpm_limit = 3000
1339
1349
existing_key2 .rpm_limit = 200
1350
+ existing_key2 .metadata = {}
1340
1351
1341
1352
mock_prisma_client = AsyncMock ()
1342
1353
mock_prisma_client .db .litellm_verificationtoken .find_many = AsyncMock (
@@ -1360,6 +1371,7 @@ async def test_check_team_key_limits_tpm_overallocation():
1360
1371
data = GenerateKeyRequest (
1361
1372
tpm_limit = 2000 ,
1362
1373
rpm_limit = 100 ,
1374
+ tpm_limit_type = "guaranteed_throughput" ,
1363
1375
)
1364
1376
1365
1377
# Should raise HTTPException for TPM overallocation
@@ -1387,10 +1399,12 @@ async def test_check_team_key_limits_rpm_overallocation():
1387
1399
existing_key1 = MagicMock ()
1388
1400
existing_key1 .tpm_limit = 1000
1389
1401
existing_key1 .rpm_limit = 600
1402
+ existing_key1 .metadata = {}
1390
1403
1391
1404
existing_key2 = MagicMock ()
1392
1405
existing_key2 .tpm_limit = 2000
1393
1406
existing_key2 .rpm_limit = 300
1407
+ existing_key2 .metadata = {}
1394
1408
1395
1409
mock_prisma_client = AsyncMock ()
1396
1410
mock_prisma_client .db .litellm_verificationtoken .find_many = AsyncMock (
@@ -1414,6 +1428,7 @@ async def test_check_team_key_limits_rpm_overallocation():
1414
1428
data = GenerateKeyRequest (
1415
1429
tpm_limit = 1000 ,
1416
1430
rpm_limit = 200 ,
1431
+ rpm_limit_type = "guaranteed_throughput" ,
1417
1432
)
1418
1433
1419
1434
# Should raise HTTPException for RPM overallocation
0 commit comments