Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 43 additions & 41 deletions litellm/proxy/management_endpoints/key_management_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,10 @@ def _get_user_in_team(
def _calculate_key_rotation_time(rotation_interval: str) -> datetime:
"""
Helper function to calculate the next rotation time for a key based on the rotation interval.

Args:
rotation_interval: String representing the rotation interval (e.g., '30d', '90d', '1h')

Returns:
datetime: The calculated next rotation time in UTC
"""
Expand All @@ -102,21 +102,25 @@ def _calculate_key_rotation_time(rotation_interval: str) -> datetime:
return now + timedelta(seconds=interval_seconds)


def _set_key_rotation_fields(data: dict, auto_rotate: bool, rotation_interval: Optional[str]) -> None:
def _set_key_rotation_fields(
data: dict, auto_rotate: bool, rotation_interval: Optional[str]
) -> None:
"""
Helper function to set rotation fields in key data if auto_rotate is enabled.

Args:
data: Dictionary to update with rotation fields
auto_rotate: Whether auto rotation is enabled
rotation_interval: The rotation interval string (required if auto_rotate is True)
"""
if auto_rotate and rotation_interval:
data.update({
"auto_rotate": auto_rotate,
"rotation_interval": rotation_interval,
"key_rotation_at": _calculate_key_rotation_time(rotation_interval)
})
data.update(
{
"auto_rotate": auto_rotate,
"rotation_interval": rotation_interval,
"key_rotation_at": _calculate_key_rotation_time(rotation_interval),
}
)


def _is_allowed_to_make_key_request(
Expand Down Expand Up @@ -598,9 +602,9 @@ async def _common_key_generation_helper( # noqa: PLR0915
request_type="key", **data_json, table_name="key"
)

response["soft_budget"] = (
data.soft_budget
) # include the user-input soft budget in the response
response[
"soft_budget"
] = data.soft_budget # include the user-input soft budget in the response

response = GenerateKeyResponse(**response)

Expand Down Expand Up @@ -934,9 +938,9 @@ async def _set_object_permission(
data=data_json["object_permission"],
)
)
data_json["object_permission_id"] = (
created_object_permission.object_permission_id
)
data_json[
"object_permission_id"
] = created_object_permission.object_permission_id
# delete the object_permission from the data_json
data_json.pop("object_permission")
return data_json
Expand All @@ -946,7 +950,6 @@ async def prepare_key_update_data(
data: Union[UpdateKeyRequest, RegenerateKeyRequest],
existing_key_row: LiteLLM_VerificationToken,
):

data_json: dict = data.model_dump(exclude_unset=True)
data_json.pop("key", None)
data_json.pop("new_key", None)
Expand Down Expand Up @@ -1198,9 +1201,9 @@ async def update_key_fn(

# Handle rotation fields if auto_rotate is being enabled
_set_key_rotation_fields(
non_default_values,
non_default_values.get("auto_rotate", False),
non_default_values.get("rotation_interval")
non_default_values,
non_default_values.get("auto_rotate", False),
non_default_values.get("rotation_interval"),
)

_data = {**non_default_values, "token": key}
Expand Down Expand Up @@ -1602,8 +1605,6 @@ def _check_model_access_group(
return True




async def generate_key_helper_fn( # noqa: PLR0915
request_type: Literal[
"user", "key"
Expand Down Expand Up @@ -1676,7 +1677,8 @@ async def generate_key_helper_fn( # noqa: PLR0915
if duration is None: # allow tokens that never expire
expires = None
else:
expires = get_budget_reset_time(budget_duration=duration)
duration_s = duration_in_seconds(duration=duration)
expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s)

if key_budget_duration is None: # one-time budget
key_reset_at = None
Expand Down Expand Up @@ -1766,12 +1768,12 @@ async def generate_key_helper_fn( # noqa: PLR0915
"allowed_routes": allowed_routes or [],
"object_permission_id": object_permission_id,
}

# Add rotation fields if auto_rotate is enabled
_set_key_rotation_fields(
data=key_data,
auto_rotate=auto_rotate or False,
rotation_interval=rotation_interval
rotation_interval=rotation_interval,
)

if (
Expand Down Expand Up @@ -1968,10 +1970,10 @@ async def delete_verification_tokens(
try:
if prisma_client:
tokens = [_hash_token_if_needed(token=key) for key in tokens]
_keys_being_deleted: List[LiteLLM_VerificationToken] = (
await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)
_keys_being_deleted: List[
LiteLLM_VerificationToken
] = await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)

if len(_keys_being_deleted) == 0:
Expand Down Expand Up @@ -2079,9 +2081,9 @@ async def _rotate_master_key(
from litellm.proxy.proxy_server import proxy_config

try:
models: Optional[List] = (
await prisma_client.db.litellm_proxymodeltable.find_many()
)
models: Optional[
List
] = await prisma_client.db.litellm_proxymodeltable.find_many()
except Exception:
models = None
# 2. process model table
Expand Down Expand Up @@ -2392,11 +2394,11 @@ async def validate_key_list_check(
param="user_id",
code=status.HTTP_403_FORBIDDEN,
)
complete_user_info_db_obj: Optional[BaseModel] = (
await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)
complete_user_info_db_obj: Optional[
BaseModel
] = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)

if complete_user_info_db_obj is None:
Expand Down Expand Up @@ -2482,10 +2484,10 @@ async def get_admin_team_ids(
if complete_user_info is None:
return []
# Get all teams that user is an admin of
teams: Optional[List[BaseModel]] = (
await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
teams: Optional[
List[BaseModel]
] = await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
if teams is None:
return []
Expand Down Expand Up @@ -3049,7 +3051,7 @@ async def key_health(
Checks:
- If key based logging is configured correctly - sends a test log

Usage
Usage

Pass the key in the request header

Expand Down
Loading
Loading