@@ -90,10 +90,10 @@ def _get_user_in_team(
90
90
def _calculate_key_rotation_time (rotation_interval : str ) -> datetime :
91
91
"""
92
92
Helper function to calculate the next rotation time for a key based on the rotation interval.
93
-
93
+
94
94
Args:
95
95
rotation_interval: String representing the rotation interval (e.g., '30d', '90d', '1h')
96
-
96
+
97
97
Returns:
98
98
datetime: The calculated next rotation time in UTC
99
99
"""
@@ -102,28 +102,34 @@ def _calculate_key_rotation_time(rotation_interval: str) -> datetime:
102
102
return now + timedelta (seconds = interval_seconds )
103
103
104
104
105
- def _set_key_rotation_fields (data : dict , auto_rotate : bool , rotation_interval : Optional [str ]) -> None :
105
+ def _set_key_rotation_fields (
106
+ data : dict , auto_rotate : bool , rotation_interval : Optional [str ]
107
+ ) -> None :
106
108
"""
107
109
Helper function to set rotation fields in key data if auto_rotate is enabled.
108
-
110
+
109
111
Args:
110
112
data: Dictionary to update with rotation fields
111
113
auto_rotate: Whether auto rotation is enabled
112
114
rotation_interval: The rotation interval string (required if auto_rotate is True)
113
115
"""
114
116
if auto_rotate and rotation_interval :
115
- data .update ({
116
- "auto_rotate" : auto_rotate ,
117
- "rotation_interval" : rotation_interval ,
118
- "key_rotation_at" : _calculate_key_rotation_time (rotation_interval )
119
- })
117
+ data .update (
118
+ {
119
+ "auto_rotate" : auto_rotate ,
120
+ "rotation_interval" : rotation_interval ,
121
+ "key_rotation_at" : _calculate_key_rotation_time (rotation_interval ),
122
+ }
123
+ )
120
124
121
125
122
126
def _is_allowed_to_make_key_request (
123
- user_api_key_dict : UserAPIKeyAuth , user_id : Optional [str ], team_id : Optional [str ]
127
+ user_api_key_dict : UserAPIKeyAuth ,
128
+ user_id : Optional [str ],
129
+ team_id : Optional [str ],
124
130
) -> bool :
125
131
"""
126
- Assert user only creates keys for themselves
132
+ Assert user only creates/updates keys for themselves
127
133
128
134
Relevant issue: https://github.com/BerriAI/litellm/issues/7336
129
135
"""
@@ -332,14 +338,15 @@ def common_key_access_checks(
332
338
data : Union [GenerateKeyRequest , UpdateKeyRequest ],
333
339
llm_router : Optional [Router ],
334
340
premium_user : bool ,
341
+ user_id : Optional [str ] = None ,
335
342
) -> Literal [True ]:
336
343
"""
337
344
Check if user is allowed to make a key request, for this key
338
345
"""
339
346
try :
340
347
_is_allowed_to_make_key_request (
341
348
user_api_key_dict = user_api_key_dict ,
342
- user_id = data .user_id ,
349
+ user_id = user_id or data .user_id ,
343
350
team_id = data .team_id ,
344
351
)
345
352
except AssertionError as e :
@@ -1136,13 +1143,6 @@ async def update_key_fn(
1136
1143
if prisma_client is None :
1137
1144
raise Exception ("Not connected to DB!" )
1138
1145
1139
- common_key_access_checks (
1140
- user_api_key_dict = user_api_key_dict ,
1141
- data = data ,
1142
- llm_router = llm_router ,
1143
- premium_user = premium_user ,
1144
- )
1145
-
1146
1146
existing_key_row = await prisma_client .get_data (
1147
1147
token = data .key , table_name = "key" , query_type = "find_unique"
1148
1148
)
@@ -1153,6 +1153,14 @@ async def update_key_fn(
1153
1153
detail = {"error" : f"Team not found, passed team_id={ data .team_id } " },
1154
1154
)
1155
1155
1156
+ common_key_access_checks (
1157
+ user_api_key_dict = user_api_key_dict ,
1158
+ data = data ,
1159
+ user_id = existing_key_row .user_id ,
1160
+ llm_router = llm_router ,
1161
+ premium_user = premium_user ,
1162
+ )
1163
+
1156
1164
# check if user has permission to update key
1157
1165
await TeamMemberPermissionChecks .can_team_member_execute_key_management_endpoint (
1158
1166
user_api_key_dict = user_api_key_dict ,
@@ -1198,9 +1206,9 @@ async def update_key_fn(
1198
1206
1199
1207
# Handle rotation fields if auto_rotate is being enabled
1200
1208
_set_key_rotation_fields (
1201
- non_default_values ,
1202
- non_default_values .get ("auto_rotate" , False ),
1203
- non_default_values .get ("rotation_interval" )
1209
+ non_default_values ,
1210
+ non_default_values .get ("auto_rotate" , False ),
1211
+ non_default_values .get ("rotation_interval" ),
1204
1212
)
1205
1213
1206
1214
_data = {** non_default_values , "token" : key }
@@ -1602,8 +1610,6 @@ def _check_model_access_group(
1602
1610
return True
1603
1611
1604
1612
1605
-
1606
-
1607
1613
async def generate_key_helper_fn ( # noqa: PLR0915
1608
1614
request_type : Literal [
1609
1615
"user" , "key"
@@ -1766,12 +1772,12 @@ async def generate_key_helper_fn( # noqa: PLR0915
1766
1772
"allowed_routes" : allowed_routes or [],
1767
1773
"object_permission_id" : object_permission_id ,
1768
1774
}
1769
-
1775
+
1770
1776
# Add rotation fields if auto_rotate is enabled
1771
1777
_set_key_rotation_fields (
1772
1778
data = key_data ,
1773
1779
auto_rotate = auto_rotate or False ,
1774
- rotation_interval = rotation_interval
1780
+ rotation_interval = rotation_interval ,
1775
1781
)
1776
1782
1777
1783
if (
0 commit comments