@@ -114,7 +114,7 @@ def process_sso_jwt_access_token(
114
114
115
115
@router .get ("/sso/key/generate" , tags = ["experimental" ], include_in_schema = False )
116
116
async def google_login (
117
- request : Request , source : Optional [str ] = None , key : Optional [str ] = None
117
+ request : Request , source : Optional [str ] = None , key : Optional [str ] = None , existing_key : Optional [ str ] = None
118
118
): # noqa: PLR0915
119
119
"""
120
120
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
@@ -173,12 +173,14 @@ async def google_login(
173
173
redirect_url = SSOAuthenticationHandler .get_redirect_url_for_sso (
174
174
request = request ,
175
175
sso_callback_route = "sso/callback" ,
176
+ existing_key = existing_key ,
176
177
)
177
178
178
179
# Store CLI key in state for OAuth flow
179
180
cli_state : Optional [str ] = SSOAuthenticationHandler ._get_cli_state (
180
181
source = source ,
181
182
key = key ,
183
+ existing_key = existing_key ,
182
184
)
183
185
184
186
# check if user defined a custom auth sso sign in handler, if yes, use it
@@ -590,8 +592,12 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa:
590
592
if state and state .startswith (f"{ LITELLM_CLI_SESSION_TOKEN_PREFIX } :" ):
591
593
# Extract the key ID from the state
592
594
key_id = state .split (":" , 1 )[1 ]
593
- verbose_proxy_logger .info (f"CLI SSO callback detected for key: { key_id } " )
594
- return await cli_sso_callback (request , key = key_id )
595
+
596
+ # Get existing_key from query parameters if provided
597
+ existing_key = request .query_params .get ("existing_key" )
598
+
599
+ verbose_proxy_logger .info (f"CLI SSO callback detected for key: { key_id } , existing_key: { existing_key } " )
600
+ return await cli_sso_callback (request , key = key_id , existing_key = existing_key )
595
601
596
602
from litellm .proxy ._types import LiteLLM_JWTAuth
597
603
from litellm .proxy .auth .handle_jwt import JWTHandler
@@ -678,13 +684,59 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa:
678
684
)
679
685
680
686
681
- async def cli_sso_callback (request : Request , key : Optional [str ] = None ):
682
- """CLI SSO callback - generates the key with pre-specified ID"""
683
- verbose_proxy_logger .info (f"CLI SSO callback for key: { key } " )
687
+ async def _regenerate_cli_key (existing_key : str , new_key : str ) -> None :
688
+ """Regenerate an existing CLI key with a new token"""
689
+ from litellm .proxy ._types import RegenerateKeyRequest , UserAPIKeyAuth
690
+ from litellm .proxy .management_endpoints .key_management_endpoints import (
691
+ regenerate_key_fn ,
692
+ )
693
+
694
+ verbose_proxy_logger .info (f"Regenerating existing CLI key: { existing_key } " )
695
+
696
+ admin_user_dict = UserAPIKeyAuth .get_litellm_cli_user_api_key_auth ()
697
+
698
+ regenerate_request = RegenerateKeyRequest (
699
+ key = existing_key ,
700
+ new_key = new_key ,
701
+ duration = "24hr"
702
+ )
703
+
704
+ await regenerate_key_fn (
705
+ key = existing_key ,
706
+ data = regenerate_request ,
707
+ user_api_key_dict = admin_user_dict
708
+ )
709
+
710
+ verbose_proxy_logger .info (f"Regenerated CLI key: { new_key } " )
711
+
684
712
713
+ async def _create_new_cli_key (key : str ) -> None :
714
+ """Create a new CLI key"""
685
715
from litellm .proxy .management_endpoints .key_management_endpoints import (
686
716
generate_key_helper_fn ,
687
717
)
718
+
719
+ verbose_proxy_logger .info ("Creating new CLI key" )
720
+
721
+ await generate_key_helper_fn (
722
+ request_type = "key" ,
723
+ duration = "24hr" ,
724
+ key_max_budget = litellm .max_ui_session_budget ,
725
+ aliases = {},
726
+ config = {},
727
+ spend = 0 ,
728
+ team_id = "litellm-cli" ,
729
+ table_name = "key" ,
730
+ token = key ,
731
+ )
732
+
733
+ verbose_proxy_logger .info (f"Created new CLI key: { key } " )
734
+
735
+
736
+ async def cli_sso_callback (request : Request , key : Optional [str ] = None , existing_key : Optional [str ] = None ):
737
+ """CLI SSO callback - regenerates existing CLI key or creates new one"""
738
+ verbose_proxy_logger .info (f"CLI SSO callback for key: { key } , existing_key: { existing_key } " )
739
+
688
740
from litellm .proxy .proxy_server import prisma_client
689
741
690
742
if not key or not key .startswith ("sk-" ):
@@ -698,21 +750,11 @@ async def cli_sso_callback(request: Request, key: Optional[str] = None):
698
750
status_code = 500 , detail = CommonProxyErrors .db_not_connected_error .value
699
751
)
700
752
701
- # Generate a simple key for CLI usage with the pre-specified key ID
702
753
try :
703
- await generate_key_helper_fn (
704
- request_type = "key" ,
705
- duration = "24hr" ,
706
- key_max_budget = litellm .max_ui_session_budget ,
707
- aliases = {},
708
- config = {},
709
- spend = 0 ,
710
- team_id = "litellm-cli" ,
711
- table_name = "key" ,
712
- token = key , # Use the pre-specified key ID
713
- )
714
-
715
- verbose_proxy_logger .info (f"Generated CLI key: { key } " )
754
+ if existing_key :
755
+ await _regenerate_cli_key (existing_key , key )
756
+ else :
757
+ await _create_new_cli_key (key )
716
758
717
759
# Return success page
718
760
from fastapi .responses import HTMLResponse
@@ -725,8 +767,8 @@ async def cli_sso_callback(request: Request, key: Optional[str] = None):
725
767
return HTMLResponse (content = html_content , status_code = 200 )
726
768
727
769
except Exception as e :
728
- verbose_proxy_logger .error (f"Error generating CLI key: { e } " )
729
- raise HTTPException (status_code = 500 , detail = f"Failed to generate key: { str (e )} " )
770
+ verbose_proxy_logger .error (f"Error with CLI key: { e } " )
771
+ raise HTTPException (status_code = 500 , detail = f"Failed to process CLI key: { str (e )} " )
730
772
731
773
732
774
@router .get ("/sso/cli/poll/{key_id}" , tags = ["experimental" ], include_in_schema = False )
@@ -993,20 +1035,51 @@ async def get_sso_login_redirect(
993
1035
# or a cryptographicly signed state that we can verify stateless
994
1036
# For simplification we are using a static state, this is not perfect but some
995
1037
# SSO providers do not allow stateless verification
996
- redirect_params = {}
997
- state = os .getenv ("GENERIC_CLIENT_STATE" , None )
998
-
999
- if state :
1000
- redirect_params ["state" ] = state
1001
- elif "okta" in generic_authorization_endpoint :
1002
- redirect_params ["state" ] = (
1003
- uuid .uuid4 ().hex
1004
- ) # set state param for okta - required
1038
+ redirect_params = SSOAuthenticationHandler ._get_generic_sso_redirect_params (
1039
+ state = state ,
1040
+ generic_authorization_endpoint = generic_authorization_endpoint
1041
+ )
1042
+
1005
1043
return await generic_sso .get_login_redirect (** redirect_params ) # type: ignore
1006
1044
raise ValueError (
1007
1045
"Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso"
1008
1046
)
1009
1047
1048
+ @staticmethod
1049
+ def _get_generic_sso_redirect_params (
1050
+ state : Optional [str ] = None ,
1051
+ generic_authorization_endpoint : Optional [str ] = None
1052
+ ) -> dict :
1053
+ """
1054
+ Get redirect parameters for Generic SSO with proper state priority handling.
1055
+
1056
+ Priority order:
1057
+ 1. CLI state (if provided)
1058
+ 2. GENERIC_CLIENT_STATE environment variable
1059
+ 3. Generated UUID for Okta (if Okta endpoint detected)
1060
+
1061
+ Args:
1062
+ state: Optional state parameter (e.g., CLI state)
1063
+ generic_authorization_endpoint: Authorization endpoint URL
1064
+
1065
+ Returns:
1066
+ dict: Redirect parameters for SSO login
1067
+ """
1068
+ redirect_params = {}
1069
+
1070
+ if state :
1071
+ # CLI state takes priority
1072
+ # the litellm proxy cli sends the "state" parameter to the proxy server for auth. We should maintain the state parameter for the cli if it is provided
1073
+ redirect_params ["state" ] = state
1074
+ else :
1075
+ generic_client_state = os .getenv ("GENERIC_CLIENT_STATE" , None )
1076
+ if generic_client_state :
1077
+ redirect_params ["state" ] = generic_client_state
1078
+ elif generic_authorization_endpoint and "okta" in generic_authorization_endpoint :
1079
+ redirect_params ["state" ] = uuid .uuid4 ().hex # set state param for okta - required
1080
+
1081
+ return redirect_params
1082
+
1010
1083
@staticmethod
1011
1084
def should_use_sso_handler (
1012
1085
google_client_id : Optional [str ] = None ,
@@ -1025,6 +1098,7 @@ def should_use_sso_handler(
1025
1098
def get_redirect_url_for_sso (
1026
1099
request : Request ,
1027
1100
sso_callback_route : str ,
1101
+ existing_key : Optional [str ] = None ,
1028
1102
) -> str :
1029
1103
"""
1030
1104
Get the redirect URL for SSO
@@ -1036,6 +1110,11 @@ def get_redirect_url_for_sso(
1036
1110
redirect_url += sso_callback_route
1037
1111
else :
1038
1112
redirect_url += "/" + sso_callback_route
1113
+
1114
+ # Append existing_key as query parameter if provided
1115
+ if existing_key :
1116
+ redirect_url += f"?existing_key={ existing_key } "
1117
+
1039
1118
return redirect_url
1040
1119
1041
1120
@staticmethod
@@ -1218,7 +1297,7 @@ def _cast_and_deepcopy_litellm_default_team_params(
1218
1297
return team_request
1219
1298
1220
1299
@staticmethod
1221
- def _get_cli_state (source : Optional [str ], key : Optional [str ]) -> Optional [str ]:
1300
+ def _get_cli_state (source : Optional [str ], key : Optional [str ], existing_key : Optional [ str ] = None ) -> Optional [str ]:
1222
1301
"""
1223
1302
Checks the request 'source' if a cli state token was passed in
1224
1303
@@ -1229,11 +1308,11 @@ def _get_cli_state(source: Optional[str], key: Optional[str]) -> Optional[str]:
1229
1308
LITELLM_CLI_SOURCE_IDENTIFIER ,
1230
1309
)
1231
1310
1232
- return (
1233
- f" { LITELLM_CLI_SESSION_TOKEN_PREFIX } : { key } "
1234
- if source == LITELLM_CLI_SOURCE_IDENTIFIER and key
1235
- else None
1236
- )
1311
+ if source == LITELLM_CLI_SOURCE_IDENTIFIER and key :
1312
+ # Just use the key - existing_key will be passed separately via query params
1313
+ return f" { LITELLM_CLI_SESSION_TOKEN_PREFIX } : { key } "
1314
+ else :
1315
+ return None
1237
1316
1238
1317
@staticmethod
1239
1318
async def get_redirect_response_from_openid ( # noqa: PLR0915
0 commit comments