Skip to content

Commit fa20abf

Browse files
authored
[Feat] Proxy CLI Auth - Allow re-using cli auth token (#14780)
* fix: cli auth with SSO okta * fix: add LITTELM_CLI_SERVICE_ACCOUNT_NAME * fix: get_litellm_cli_user_api_key_auth * use existing_key CLI * fix: use existing key * test auth commands * test_cli_sso_callback_regenerate_vs_create_flow
1 parent 52a56bd commit fa20abf

File tree

6 files changed

+392
-37
lines changed

6 files changed

+392
-37
lines changed

litellm/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,7 @@
947947
os.getenv("HEALTH_CHECK_TIMEOUT_SECONDS", 60)
948948
) # 60 seconds
949949
LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME = "litellm-internal-health-check"
950+
LITTELM_CLI_SERVICE_ACCOUNT_NAME = "litellm-cli"
950951

951952
UI_SESSION_TOKEN_TEAM_ID = "litellm-dashboard"
952953
LITELLM_PROXY_ADMIN_NAME = "default_user_id"

litellm/proxy/_types.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1915,6 +1915,22 @@ def get_litellm_internal_health_check_user_api_key_auth(cls) -> "UserAPIKeyAuth"
19151915
key_alias=LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME,
19161916
team_alias=LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME,
19171917
)
1918+
1919+
@classmethod
1920+
def get_litellm_cli_user_api_key_auth(cls) -> "UserAPIKeyAuth":
1921+
"""
1922+
Returns a `UserAPIKeyAuth` object for the litellm internal health check service account.
1923+
1924+
This is used to track number of requests/spend for health check calls.
1925+
"""
1926+
from litellm.constants import LITTELM_CLI_SERVICE_ACCOUNT_NAME
1927+
1928+
return cls(
1929+
api_key=LITTELM_CLI_SERVICE_ACCOUNT_NAME,
1930+
team_id=LITTELM_CLI_SERVICE_ACCOUNT_NAME,
1931+
key_alias=LITTELM_CLI_SERVICE_ACCOUNT_NAME,
1932+
team_alias=LITTELM_CLI_SERVICE_ACCOUNT_NAME,
1933+
)
19181934

19191935

19201936
class UserInfoResponse(LiteLLMPydanticObjectBase):

litellm/proxy/client/cli/commands/auth.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,20 @@ def login(ctx: click.Context):
6464

6565
base_url = ctx.obj["base_url"]
6666

67+
# Check if we have an existing key to regenerate
68+
existing_key = get_stored_api_key()
69+
6770
# Generate unique key ID for this login session
6871
key_id = f"sk-{str(uuid.uuid4())}"
6972

7073
try:
7174
# Construct SSO login URL with CLI source and pre-generated key
7275
sso_url = f"{base_url}/sso/key/generate?source={LITELLM_CLI_SOURCE_IDENTIFIER}&key={key_id}"
7376

77+
# If we have an existing key, include it so the server can regenerate it
78+
if existing_key:
79+
sso_url += f"&existing_key={existing_key}"
80+
7481
click.echo(f"Opening browser to: {sso_url}")
7582
click.echo("Please complete the SSO authentication in your browser...")
7683
click.echo(f"Session ID: {key_id}")

litellm/proxy/management_endpoints/ui_sso.py

Lines changed: 116 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def process_sso_jwt_access_token(
114114

115115
@router.get("/sso/key/generate", tags=["experimental"], include_in_schema=False)
116116
async def google_login(
117-
request: Request, source: Optional[str] = None, key: Optional[str] = None
117+
request: Request, source: Optional[str] = None, key: Optional[str] = None, existing_key: Optional[str] = None
118118
): # noqa: PLR0915
119119
"""
120120
Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env
@@ -173,12 +173,14 @@ async def google_login(
173173
redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso(
174174
request=request,
175175
sso_callback_route="sso/callback",
176+
existing_key=existing_key,
176177
)
177178

178179
# Store CLI key in state for OAuth flow
179180
cli_state: Optional[str] = SSOAuthenticationHandler._get_cli_state(
180181
source=source,
181182
key=key,
183+
existing_key=existing_key,
182184
)
183185

184186
# check if user defined a custom auth sso sign in handler, if yes, use it
@@ -590,8 +592,12 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa:
590592
if state and state.startswith(f"{LITELLM_CLI_SESSION_TOKEN_PREFIX}:"):
591593
# Extract the key ID from the state
592594
key_id = state.split(":", 1)[1]
593-
verbose_proxy_logger.info(f"CLI SSO callback detected for key: {key_id}")
594-
return await cli_sso_callback(request, key=key_id)
595+
596+
# Get existing_key from query parameters if provided
597+
existing_key = request.query_params.get("existing_key")
598+
599+
verbose_proxy_logger.info(f"CLI SSO callback detected for key: {key_id}, existing_key: {existing_key}")
600+
return await cli_sso_callback(request, key=key_id, existing_key=existing_key)
595601

596602
from litellm.proxy._types import LiteLLM_JWTAuth
597603
from litellm.proxy.auth.handle_jwt import JWTHandler
@@ -678,13 +684,59 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa:
678684
)
679685

680686

681-
async def cli_sso_callback(request: Request, key: Optional[str] = None):
682-
"""CLI SSO callback - generates the key with pre-specified ID"""
683-
verbose_proxy_logger.info(f"CLI SSO callback for key: {key}")
687+
async def _regenerate_cli_key(existing_key: str, new_key: str) -> None:
688+
"""Regenerate an existing CLI key with a new token"""
689+
from litellm.proxy._types import RegenerateKeyRequest, UserAPIKeyAuth
690+
from litellm.proxy.management_endpoints.key_management_endpoints import (
691+
regenerate_key_fn,
692+
)
693+
694+
verbose_proxy_logger.info(f"Regenerating existing CLI key: {existing_key}")
695+
696+
admin_user_dict = UserAPIKeyAuth.get_litellm_cli_user_api_key_auth()
697+
698+
regenerate_request = RegenerateKeyRequest(
699+
key=existing_key,
700+
new_key=new_key,
701+
duration="24hr"
702+
)
703+
704+
await regenerate_key_fn(
705+
key=existing_key,
706+
data=regenerate_request,
707+
user_api_key_dict=admin_user_dict
708+
)
709+
710+
verbose_proxy_logger.info(f"Regenerated CLI key: {new_key}")
711+
684712

713+
async def _create_new_cli_key(key: str) -> None:
714+
"""Create a new CLI key"""
685715
from litellm.proxy.management_endpoints.key_management_endpoints import (
686716
generate_key_helper_fn,
687717
)
718+
719+
verbose_proxy_logger.info("Creating new CLI key")
720+
721+
await generate_key_helper_fn(
722+
request_type="key",
723+
duration="24hr",
724+
key_max_budget=litellm.max_ui_session_budget,
725+
aliases={},
726+
config={},
727+
spend=0,
728+
team_id="litellm-cli",
729+
table_name="key",
730+
token=key,
731+
)
732+
733+
verbose_proxy_logger.info(f"Created new CLI key: {key}")
734+
735+
736+
async def cli_sso_callback(request: Request, key: Optional[str] = None, existing_key: Optional[str] = None):
737+
"""CLI SSO callback - regenerates existing CLI key or creates new one"""
738+
verbose_proxy_logger.info(f"CLI SSO callback for key: {key}, existing_key: {existing_key}")
739+
688740
from litellm.proxy.proxy_server import prisma_client
689741

690742
if not key or not key.startswith("sk-"):
@@ -698,21 +750,11 @@ async def cli_sso_callback(request: Request, key: Optional[str] = None):
698750
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
699751
)
700752

701-
# Generate a simple key for CLI usage with the pre-specified key ID
702753
try:
703-
await generate_key_helper_fn(
704-
request_type="key",
705-
duration="24hr",
706-
key_max_budget=litellm.max_ui_session_budget,
707-
aliases={},
708-
config={},
709-
spend=0,
710-
team_id="litellm-cli",
711-
table_name="key",
712-
token=key, # Use the pre-specified key ID
713-
)
714-
715-
verbose_proxy_logger.info(f"Generated CLI key: {key}")
754+
if existing_key:
755+
await _regenerate_cli_key(existing_key, key)
756+
else:
757+
await _create_new_cli_key(key)
716758

717759
# Return success page
718760
from fastapi.responses import HTMLResponse
@@ -725,8 +767,8 @@ async def cli_sso_callback(request: Request, key: Optional[str] = None):
725767
return HTMLResponse(content=html_content, status_code=200)
726768

727769
except Exception as e:
728-
verbose_proxy_logger.error(f"Error generating CLI key: {e}")
729-
raise HTTPException(status_code=500, detail=f"Failed to generate key: {str(e)}")
770+
verbose_proxy_logger.error(f"Error with CLI key: {e}")
771+
raise HTTPException(status_code=500, detail=f"Failed to process CLI key: {str(e)}")
730772

731773

732774
@router.get("/sso/cli/poll/{key_id}", tags=["experimental"], include_in_schema=False)
@@ -993,20 +1035,51 @@ async def get_sso_login_redirect(
9931035
# or a cryptographicly signed state that we can verify stateless
9941036
# For simplification we are using a static state, this is not perfect but some
9951037
# SSO providers do not allow stateless verification
996-
redirect_params = {}
997-
state = os.getenv("GENERIC_CLIENT_STATE", None)
998-
999-
if state:
1000-
redirect_params["state"] = state
1001-
elif "okta" in generic_authorization_endpoint:
1002-
redirect_params["state"] = (
1003-
uuid.uuid4().hex
1004-
) # set state param for okta - required
1038+
redirect_params = SSOAuthenticationHandler._get_generic_sso_redirect_params(
1039+
state=state,
1040+
generic_authorization_endpoint=generic_authorization_endpoint
1041+
)
1042+
10051043
return await generic_sso.get_login_redirect(**redirect_params) # type: ignore
10061044
raise ValueError(
10071045
"Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso"
10081046
)
10091047

1048+
@staticmethod
1049+
def _get_generic_sso_redirect_params(
1050+
state: Optional[str] = None,
1051+
generic_authorization_endpoint: Optional[str] = None
1052+
) -> dict:
1053+
"""
1054+
Get redirect parameters for Generic SSO with proper state priority handling.
1055+
1056+
Priority order:
1057+
1. CLI state (if provided)
1058+
2. GENERIC_CLIENT_STATE environment variable
1059+
3. Generated UUID for Okta (if Okta endpoint detected)
1060+
1061+
Args:
1062+
state: Optional state parameter (e.g., CLI state)
1063+
generic_authorization_endpoint: Authorization endpoint URL
1064+
1065+
Returns:
1066+
dict: Redirect parameters for SSO login
1067+
"""
1068+
redirect_params = {}
1069+
1070+
if state:
1071+
# CLI state takes priority
1072+
# the litellm proxy cli sends the "state" parameter to the proxy server for auth. We should maintain the state parameter for the cli if it is provided
1073+
redirect_params["state"] = state
1074+
else:
1075+
generic_client_state = os.getenv("GENERIC_CLIENT_STATE", None)
1076+
if generic_client_state:
1077+
redirect_params["state"] = generic_client_state
1078+
elif generic_authorization_endpoint and "okta" in generic_authorization_endpoint:
1079+
redirect_params["state"] = uuid.uuid4().hex # set state param for okta - required
1080+
1081+
return redirect_params
1082+
10101083
@staticmethod
10111084
def should_use_sso_handler(
10121085
google_client_id: Optional[str] = None,
@@ -1025,6 +1098,7 @@ def should_use_sso_handler(
10251098
def get_redirect_url_for_sso(
10261099
request: Request,
10271100
sso_callback_route: str,
1101+
existing_key: Optional[str] = None,
10281102
) -> str:
10291103
"""
10301104
Get the redirect URL for SSO
@@ -1036,6 +1110,11 @@ def get_redirect_url_for_sso(
10361110
redirect_url += sso_callback_route
10371111
else:
10381112
redirect_url += "/" + sso_callback_route
1113+
1114+
# Append existing_key as query parameter if provided
1115+
if existing_key:
1116+
redirect_url += f"?existing_key={existing_key}"
1117+
10391118
return redirect_url
10401119

10411120
@staticmethod
@@ -1218,7 +1297,7 @@ def _cast_and_deepcopy_litellm_default_team_params(
12181297
return team_request
12191298

12201299
@staticmethod
1221-
def _get_cli_state(source: Optional[str], key: Optional[str]) -> Optional[str]:
1300+
def _get_cli_state(source: Optional[str], key: Optional[str], existing_key: Optional[str] = None) -> Optional[str]:
12221301
"""
12231302
Checks the request 'source' if a cli state token was passed in
12241303
@@ -1229,11 +1308,11 @@ def _get_cli_state(source: Optional[str], key: Optional[str]) -> Optional[str]:
12291308
LITELLM_CLI_SOURCE_IDENTIFIER,
12301309
)
12311310

1232-
return (
1233-
f"{LITELLM_CLI_SESSION_TOKEN_PREFIX}:{key}"
1234-
if source == LITELLM_CLI_SOURCE_IDENTIFIER and key
1235-
else None
1236-
)
1311+
if source == LITELLM_CLI_SOURCE_IDENTIFIER and key:
1312+
# Just use the key - existing_key will be passed separately via query params
1313+
return f"{LITELLM_CLI_SESSION_TOKEN_PREFIX}:{key}"
1314+
else:
1315+
return None
12371316

12381317
@staticmethod
12391318
async def get_redirect_response_from_openid( # noqa: PLR0915

0 commit comments

Comments
 (0)