Commit 75e3f4e

Merge pull request #15120 from BerriAI/litellm_dev_10_01_2025_p2
(feat) Support 'guaranteed_throughput' when setting limits on keys belonging to a team
2 parents b188f76 + 691a1fa commit 75e3f4e

File tree

6 files changed: +938 -86 lines changed

litellm/proxy/_types.py

Lines changed: 9 additions & 0 deletions
@@ -731,6 +731,7 @@ class GenerateRequestBase(LiteLLMPydanticObjectBase):
     metadata: Optional[dict] = {}
     tpm_limit: Optional[int] = None
     rpm_limit: Optional[int] = None
+
     budget_duration: Optional[str] = None
     allowed_cache_controls: Optional[list] = []
     config: Optional[dict] = {}
@@ -755,6 +756,12 @@ class KeyRequestBase(GenerateRequestBase):
     tags: Optional[List[str]] = None
     enforced_params: Optional[List[str]] = None
     allowed_routes: Optional[list] = []
+    rpm_limit_type: Optional[
+        Literal["guaranteed_throughput", "best_effort_throughput"]
+    ] = None  # raise an error if 'guaranteed_throughput' is set and we're overallocating rpm
+    tpm_limit_type: Optional[
+        Literal["guaranteed_throughput", "best_effort_throughput"]
+    ] = None  # raise an error if 'guaranteed_throughput' is set and we're overallocating tpm
 
 
 class LiteLLMKeyType(str, enum.Enum):
@@ -3056,6 +3063,8 @@ class PassThroughEndpointLoggingTypedDict(TypedDict):
 LiteLLM_ManagementEndpoint_MetadataFields = [
     "model_rpm_limit",
     "model_tpm_limit",
+    "rpm_limit_type",
+    "tpm_limit_type",
     "guardrails",
     "tags",
     "enforced_params",

litellm/proxy/management_endpoints/key_management_endpoints.py

Lines changed: 215 additions & 4 deletions
@@ -27,6 +27,7 @@
 from litellm.constants import LENGTH_OF_LITELLM_GENERATED_KEY, UI_SESSION_TOKEN_TEAM_ID
 from litellm.litellm_core_utils.duration_parser import duration_in_seconds
 from litellm.proxy._types import *
+from litellm.proxy._types import LiteLLM_VerificationToken
 from litellm.proxy.auth.auth_checks import (
     _cache_key_object,
     _delete_cache_key_object,
@@ -549,6 +550,15 @@ async def _common_key_generation_helper(  # noqa: PLR0915
                 value=getattr(data, field),
             )
 
+    for field in LiteLLM_ManagementEndpoint_MetadataFields:
+        if getattr(data, field, None) is not None:
+            _set_object_metadata_field(
+                object_data=data,
+                field_name=field,
+                value=getattr(data, field),
+            )
+            delattr(data, field)
+
     data_json = data.model_dump(exclude_unset=True, exclude_none=True)  # type: ignore
 
     data_json = handle_key_type(data, data_json)
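
The new loop moves each management-endpoint field off the request object and into the key's metadata blob before the row is written, so rpm_limit_type / tpm_limit_type are persisted inside the metadata JSON column rather than as top-level key columns. A hedged sketch of the pattern (the helper name is real; its body here is assumed for illustration):

    from typing import Any, Optional

    def _set_object_metadata_field(object_data: Any, field_name: str, value: Any) -> None:
        # Assumed behavior: stash the value inside the object's metadata dict.
        if object_data.metadata is None:
            object_data.metadata = {}
        object_data.metadata[field_name] = value

    class FakeKeyRequest:
        def __init__(self) -> None:
            self.metadata: Optional[dict] = None
            self.rpm_limit_type = "guaranteed_throughput"

    data = FakeKeyRequest()
    for field in ["rpm_limit_type"]:  # stand-in for LiteLLM_ManagementEndpoint_MetadataFields
        if getattr(data, field, None) is not None:
            _set_object_metadata_field(object_data=data, field_name=field, value=getattr(data, field))
            delattr(data, field)  # keep the field out of the top-level column dump

    print(data.metadata)  # {'rpm_limit_type': 'guaranteed_throughput'}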
@@ -627,6 +637,153 @@ async def _common_key_generation_helper(  # noqa: PLR0915
     return response
 
 
+def check_team_key_model_specific_limits(
+    keys: List[LiteLLM_VerificationToken],
+    team_table: LiteLLM_TeamTableCachedObj,
+    data: Union[GenerateKeyRequest, UpdateKeyRequest],
+) -> None:
+    """
+    Check if the team key is allocating model specific limits. If so, raise an error if we're overallocating.
+    """
+    if data.model_rpm_limit is None and data.model_tpm_limit is None:
+        return
+    # get total model specific tpm/rpm limit
+    model_specific_rpm_limit: Dict[str, int] = {}
+    model_specific_tpm_limit: Dict[str, int] = {}
+
+    for key in keys:
+        if key.metadata.get("model_rpm_limit", None) is not None:
+            for model, rpm_limit in key.metadata.get("model_rpm_limit", {}).items():
+                model_specific_rpm_limit[model] = (
+                    model_specific_rpm_limit.get(model, 0) + rpm_limit
+                )
+        if key.metadata.get("model_tpm_limit", None) is not None:
+            for model, tpm_limit in key.metadata.get("model_tpm_limit", {}).items():
+                model_specific_tpm_limit[model] = (
+                    model_specific_tpm_limit.get(model, 0) + tpm_limit
+                )
+    if data.model_rpm_limit is not None:
+        for model, rpm_limit in data.model_rpm_limit.items():
+            if (
+                model_specific_rpm_limit.get(model, 0) + rpm_limit
+                > team_table.rpm_limit
+            ):
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"Allocated RPM limit={model_specific_rpm_limit.get(model, 0)} + Key RPM limit={rpm_limit} is greater than team RPM limit={team_table.rpm_limit}",
+                )
+            elif team_table.metadata and team_table.metadata.get("model_rpm_limit"):
+                team_model_specific_rpm_limit_dict = team_table.metadata.get(
+                    "model_rpm_limit", {}
+                )
+                team_model_specific_rpm_limit = team_model_specific_rpm_limit_dict.get(
+                    model
+                )
+                if (
+                    model_specific_rpm_limit.get(model, 0) + rpm_limit
+                    > team_model_specific_rpm_limit
+                ):
+                    raise HTTPException(
+                        status_code=400,
+                        detail=f"Allocated RPM limit={model_specific_rpm_limit.get(model, 0)} + Key RPM limit={rpm_limit} is greater than team RPM limit={team_model_specific_rpm_limit.get(model, 0)}",
+                    )
+    if data.model_tpm_limit is not None:
+        for model, tpm_limit in data.model_tpm_limit.items():
+            if (
+                team_table.tpm_limit is not None
+                and model_specific_tpm_limit.get(model, 0) + tpm_limit
+                > team_table.tpm_limit
+            ):
+                raise HTTPException(
+                    status_code=400,
+                    detail=f"Allocated TPM limit={model_specific_tpm_limit.get(model, 0)} + Key TPM limit={tpm_limit} is greater than team TPM limit={team_table.tpm_limit}",
+                )
+            elif team_table.metadata and team_table.metadata.get("model_tpm_limit"):
+                team_model_specific_tpm_limit_dict = team_table.metadata.get(
+                    "model_tpm_limit", {}
+                )
+                team_model_specific_tpm_limit = team_model_specific_tpm_limit_dict.get(
+                    model
+                )
+                if (
+                    team_model_specific_tpm_limit
+                    and model_specific_tpm_limit.get(model, 0) + tpm_limit
+                    > team_model_specific_tpm_limit
+                ):
+                    raise HTTPException(
+                        status_code=400,
+                        detail=f"Allocated TPM limit={model_specific_tpm_limit.get(model, 0)} + Key TPM limit={tpm_limit} is greater than team TPM limit={team_model_specific_tpm_limit}",
+                    )
+
+
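check_team_key_model_specific_limits sums, per model, the model-specific limits already granted to the team's existing keys, adds the limit requested for the new key, and raises a 400 if the total exceeds the team's limit (or, in the elif branch, the team's own per-model limit stored in team metadata). A plain-Python walk-through of the arithmetic with illustrative numbers:

    # Two existing team keys already reserve gpt-4 RPM (illustrative values).
    existing_key_metadata = [
        {"model_rpm_limit": {"gpt-4": 40}},
        {"model_rpm_limit": {"gpt-4": 30}},
    ]

    model_specific_rpm_limit = {}
    for md in existing_key_metadata:
        for model, rpm in md.get("model_rpm_limit", {}).items():
            model_specific_rpm_limit[model] = model_specific_rpm_limit.get(model, 0) + rpm

    team_rpm_limit = 100
    requested = {"gpt-4": 50}  # model_rpm_limit on the key being created

    for model, rpm in requested.items():
        if model_specific_rpm_limit.get(model, 0) + rpm > team_rpm_limit:
            # the proxy raises HTTPException(status_code=400) at this point
            print(f"overallocated: {model_specific_rpm_limit.get(model, 0)} + {rpm} > {team_rpm_limit}")
    # prints: overallocated: 70 + 50 > 100
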
+def check_team_key_rpm_tpm_limits(
+    keys: List[LiteLLM_VerificationToken],
+    team_table: LiteLLM_TeamTableCachedObj,
+    data: Union[GenerateKeyRequest, UpdateKeyRequest],
+) -> None:
+    """
+    Check if the team key is allocating rpm/tpm limits. If so, raise an error if we're overallocating.
+    """
+    if keys is not None and len(keys) > 0:
+        allocated_tpm = sum(key.tpm_limit for key in keys if key.tpm_limit is not None)
+        allocated_rpm = sum(key.rpm_limit for key in keys if key.rpm_limit is not None)
+    else:
+        allocated_tpm = 0
+        allocated_rpm = 0
+    if (
+        data.tpm_limit is not None
+        and team_table.tpm_limit is not None
+        and data.tpm_limit + allocated_tpm > team_table.tpm_limit
+    ):
+        raise HTTPException(
+            status_code=400,
+            detail=f"Allocated TPM limit={allocated_tpm} + Key TPM limit={data.tpm_limit} is greater than team TPM limit={team_table.tpm_limit}",
+        )
+    if (
+        data.rpm_limit is not None
+        and team_table.rpm_limit is not None
+        and data.rpm_limit + allocated_rpm > team_table.rpm_limit
+    ):
+        raise HTTPException(
+            status_code=400,
+            detail=f"Allocated RPM limit={allocated_rpm} + Key RPM limit={data.rpm_limit} is greater than team RPM limit={team_table.rpm_limit}",
+        )
+
+
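check_team_key_rpm_tpm_limits does the same bookkeeping for the aggregate tpm_limit / rpm_limit: it sums the limits of all existing team keys (keys with no limit contribute nothing to the sum) and rejects the request when existing + requested would exceed the team ceiling. With illustrative numbers:

    existing_key_rpm = [50, None, 30]  # rpm_limit values on the team's current keys
    allocated_rpm = sum(r for r in existing_key_rpm if r is not None)  # 80

    team_rpm_limit = 100
    new_key_rpm = 25

    if new_key_rpm + allocated_rpm > team_rpm_limit:
        # mirrors: "Allocated RPM limit=80 + Key RPM limit=25 is greater than team RPM limit=100"
        print(f"Allocated RPM limit={allocated_rpm} + Key RPM limit={new_key_rpm} "
              f"is greater than team RPM limit={team_rpm_limit}")
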
+async def _check_team_key_limits(
+    team_table: LiteLLM_TeamTableCachedObj,
+    data: Union[GenerateKeyRequest, UpdateKeyRequest],
+    prisma_client: PrismaClient,
+) -> None:
+    """
+    Check if the team key is allocating guaranteed throughput limits. If so, raise an error if we're overallocating.
+
+    Only runs check if tpm_limit_type or rpm_limit_type is "guaranteed_throughput"
+    """
+    if (
+        data.tpm_limit_type != "guaranteed_throughput"
+        and data.rpm_limit_type != "guaranteed_throughput"
+    ):
+        return
+    # get all team keys
+    # calculate allocated tpm/rpm limit
+    # check if specified tpm/rpm limit is greater than allocated tpm/rpm limit
+
+    keys = await prisma_client.db.litellm_verificationtoken.find_many(
+        where={"team_id": team_table.team_id},
+    )
+    check_team_key_model_specific_limits(
+        keys=keys,
+        team_table=team_table,
+        data=data,
+    )
+    check_team_key_rpm_tpm_limits(
+        keys=keys,
+        team_table=team_table,
+        data=data,
+    )
+
+
 @router.post(
     "/key/generate",
     tags=["key management"],
@@ -668,6 +825,8 @@ async def generate_key_fn(
     - model_max_budget: Optional[Dict[str, BudgetConfig]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}}. IF null or {} then no model specific budget.
     - model_rpm_limit: Optional[dict] - key-specific model rpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific rpm limit.
     - model_tpm_limit: Optional[dict] - key-specific model tpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific tpm limit.
+    - tpm_limit_type: Optional[str] - Type of tpm limit. Options: "best_effort_throughput" (no error if we're overallocating tpm), "guaranteed_throughput" (raise an error if we're overallocating tpm). Defaults to "best_effort_throughput".
+    - rpm_limit_type: Optional[str] - Type of rpm limit. Options: "best_effort_throughput" (no error if we're overallocating rpm), "guaranteed_throughput" (raise an error if we're overallocating rpm). Defaults to "best_effort_throughput".
     - allowed_cache_controls: Optional[list] - List of allowed cache control values. Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request
     - blocked: Optional[bool] - Whether the key is blocked.
     - rpm_limit: Optional[int] - Specify rpm limit for a given key (Requests per minute)
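
In practice, opting a key into guaranteed throughput is an ordinary /key/generate call with the new field set. A hedged usage sketch (proxy URL, master key, and limits are placeholders):

    import requests

    resp = requests.post(
        "http://localhost:4000/key/generate",  # hypothetical local proxy
        headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
        json={
            "team_id": "team-1",
            "rpm_limit": 25,
            "rpm_limit_type": "guaranteed_throughput",  # fail fast on overallocation
        },
    )
    print(resp.status_code)
    # 200 -> key created; 400 -> detail like
    # "Allocated RPM limit=80 + Key RPM limit=25 is greater than team RPM limit=100"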
@@ -703,12 +862,19 @@ async def generate_key_fn(
     - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
     """
     try:
+        from litellm.proxy._types import CommonProxyErrors
         from litellm.proxy.proxy_server import (
             prisma_client,
             user_api_key_cache,
             user_custom_key_generate,
         )
 
+        if prisma_client is None:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail={"error": CommonProxyErrors.db_not_connected_error.value},
+            )
+
         verbose_proxy_logger.debug("entered /key/generate")
 
         if user_custom_key_generate is not None:
@@ -736,7 +902,6 @@ async def generate_key_fn(
             verbose_proxy_logger.debug(
                 f"Error getting team object in `/key/generate`: {e}"
             )
-            team_table = None
 
         key_generation_check(
             team_table=team_table,
@@ -745,12 +910,21 @@ async def generate_key_fn(
             route=KeyManagementRoutes.KEY_GENERATE,
         )
 
+        if team_table is not None:
+            await _check_team_key_limits(
+                team_table=team_table,
+                data=data,
+                prisma_client=prisma_client,
+            )
+
         return await _common_key_generation_helper(
             data=data,
             user_api_key_dict=user_api_key_dict,
             litellm_changed_by=litellm_changed_by,
             team_table=team_table,
         )
+    except HTTPException as e:
+        raise e
     except Exception as e:
         verbose_proxy_logger.exception(
             "litellm.proxy.proxy_server.generate_key_fn(): Exception occured - {}".format(
@@ -804,6 +978,8 @@ async def generate_service_account_key_fn(
     - model_max_budget: Optional[Dict[str, BudgetConfig]] - Model-specific budgets {"gpt-4": {"budget_limit": 0.0005, "time_period": "30d"}}}. IF null or {} then no model specific budget.
     - model_rpm_limit: Optional[dict] - key-specific model rpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific rpm limit.
     - model_tpm_limit: Optional[dict] - key-specific model tpm limit. Example - {"text-davinci-002": 1000, "gpt-3.5-turbo": 1000}. IF null or {} then no model specific tpm limit.
+    - tpm_limit_type: Optional[str] - TPM rate limit type - "best_effort_throughput" or "guaranteed_throughput"
+    - rpm_limit_type: Optional[str] - RPM rate limit type - "best_effort_throughput" or "guaranteed_throughput"
     - allowed_cache_controls: Optional[list] - List of allowed cache control values. Example - ["no-cache", "no-store"]. See all values - https://docs.litellm.ai/docs/proxy/caching#turn-on--off-caching-per-request
     - blocked: Optional[bool] - Whether the key is blocked.
     - rpm_limit: Optional[int] - Specify rpm limit for a given key (Requests per minute)
@@ -832,12 +1008,19 @@ async def generate_service_account_key_fn(
     - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
 
     """
+    from litellm.proxy._types import CommonProxyErrors
     from litellm.proxy.proxy_server import (
         prisma_client,
         user_api_key_cache,
         user_custom_key_generate,
     )
 
+    if prisma_client is None:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail={"error": CommonProxyErrors.db_not_connected_error.value},
+        )
+
     await validate_team_id_used_in_service_account_request(
         team_id=data.team_id,
         prisma_client=prisma_client,
@@ -870,6 +1053,13 @@ async def generate_service_account_key_fn(
         )
         team_table = None
 
+    if team_table is not None:
+        await _check_team_key_limits(
+            team_table=team_table,
+            data=data,
+            prisma_client=prisma_client,
+        )
+
     key_generation_check(
         team_table=team_table,
         user_api_key_dict=user_api_key_dict,
@@ -1096,6 +1286,8 @@ async def update_key_fn(
     - rpm_limit: Optional[int] - Requests per minute limit
     - model_rpm_limit: Optional[dict] - Model-specific RPM limits {"gpt-4": 100, "claude-v1": 200}
    - model_tpm_limit: Optional[dict] - Model-specific TPM limits {"gpt-4": 100000, "claude-v1": 200000}
+    - tpm_limit_type: Optional[str] - TPM rate limit type - "best_effort_throughput" or "guaranteed_throughput"
+    - rpm_limit_type: Optional[str] - RPM rate limit type - "best_effort_throughput" or "guaranteed_throughput"
     - allowed_cache_controls: Optional[list] - List of allowed cache control values
     - duration: Optional[str] - Key validity duration ("30d", "1h", etc.)
     - permissions: Optional[dict] - Key-specific permissions
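
/key/update accepts the same fields, and (per the handler change below) re-runs the team allocation check whenever the key carries a team_id. A hedged sketch (placeholder URL, master key, and key value):

    import requests

    resp = requests.post(
        "http://localhost:4000/key/update",  # hypothetical local proxy
        headers={"Authorization": "Bearer sk-1234"},  # placeholder master key
        json={
            "key": "sk-existing-key",  # placeholder for the key being updated
            "tpm_limit": 50000,
            "tpm_limit_type": "guaranteed_throughput",
        },
    )
    print(resp.status_code, resp.json())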
@@ -1170,21 +1362,40 @@ async def update_key_fn(
             user_api_key_cache=user_api_key_cache,
         )
 
-        # if team change - check if this is possible
-        if is_different_team(data=data, existing_key_row=existing_key_row):
+        # Only check team limits if key has a team_id
+        team_obj: Optional[LiteLLM_TeamTableCachedObj] = None
+        if data.team_id is not None:
             team_obj = await get_team_object(
-                team_id=cast(str, data.team_id),
+                team_id=data.team_id,
                 prisma_client=prisma_client,
                 user_api_key_cache=user_api_key_cache,
                 check_db_only=True,
             )
+
+            if team_obj is not None:
+                await _check_team_key_limits(
+                    team_table=team_obj,
+                    data=data,
+                    prisma_client=prisma_client,
+                )
+
+        # if team change - check if this is possible
+        if is_different_team(data=data, existing_key_row=existing_key_row):
             if llm_router is None:
                 raise HTTPException(
                     status_code=400,
                     detail={
                         "error": "LLM router not found. Please set it up by passing in a valid config.yaml or adding models via the UI."
                     },
                 )
+            # team_obj should be set since is_different_team() returns True only when data.team_id is not None
+            if team_obj is None:
+                raise HTTPException(
+                    status_code=500,
+                    detail={
+                        "error": "Team object not found for team change validation"
+                    },
+                )
             validate_key_team_change(
                 key=existing_key_row,
                 team=team_obj,