Skip to content

Commit d28ecbc

Browse files
feat(ui_sso.py): support mapping app roles from azure entra id to litellm user roles
Closes LIT-1228
1 parent 1c56a0d commit d28ecbc

File tree

2 files changed

+115
-5
lines changed

2 files changed

+115
-5
lines changed

litellm/proxy/management_endpoints/types.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,48 @@
44
Might include fastapi/proxy requirements.txt related imports
55
"""
66

7-
from typing import List
7+
from typing import List, Optional, cast
88

99
from fastapi_sso.sso.base import OpenID
1010

11+
from litellm.proxy._types import LitellmUserRoles
12+
13+
14+
def is_valid_litellm_user_role(role_str: str) -> bool:
15+
"""
16+
Check if a string is a valid LitellmUserRoles enum value (case-insensitive).
17+
18+
Args:
19+
role_str: String to validate (e.g., "proxy_admin", "PROXY_ADMIN", "internal_user")
20+
21+
Returns:
22+
True if the string matches a valid LitellmUserRoles value, False otherwise
23+
"""
24+
try:
25+
# Use _value2member_map_ for O(1) lookup, case-insensitive
26+
return role_str.lower() in LitellmUserRoles._value2member_map_
27+
except Exception:
28+
return False
29+
30+
31+
def get_litellm_user_role(role_str: str) -> Optional[LitellmUserRoles]:
32+
"""
33+
Convert a string to a LitellmUserRoles enum if valid (case-insensitive).
34+
35+
Args:
36+
role_str: String to convert (e.g., "proxy_admin", "PROXY_ADMIN", "internal_user")
37+
38+
Returns:
39+
LitellmUserRoles enum if valid, None otherwise
40+
"""
41+
try:
42+
# Use _value2member_map_ for O(1) lookup, case-insensitive
43+
result = LitellmUserRoles._value2member_map_.get(role_str.lower())
44+
return cast(Optional[LitellmUserRoles], result)
45+
except Exception:
46+
return None
47+
1148

1249
class CustomOpenID(OpenID):
1350
team_ids: List[str]
51+
user_role: Optional[LitellmUserRoles] = None

litellm/proxy/management_endpoints/ui_sso.py

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
has_admin_ui_access,
5959
)
6060
from litellm.proxy.management_endpoints.team_endpoints import new_team, team_member_add
61-
from litellm.proxy.management_endpoints.types import CustomOpenID
61+
from litellm.proxy.management_endpoints.types import CustomOpenID, get_litellm_user_role
6262
from litellm.proxy.utils import (
6363
PrismaClient,
6464
ProxyLogging,
@@ -277,6 +277,7 @@ def generic_response_convertor(
277277
last_name=response.get(generic_user_last_name_attribute_name),
278278
provider=response.get(generic_provider_attribute_name),
279279
team_ids=all_teams,
280+
user_role=None,
280281
)
281282

282283

@@ -1145,7 +1146,7 @@ def get_redirect_url_for_sso(
11451146
) -> str:
11461147
"""
11471148
Get the redirect URL for SSO
1148-
1149+
11491150
Note: existing_key is not added to the URL to avoid changing the callback URL.
11501151
It should be passed via the state parameter instead.
11511152
"""
@@ -1348,7 +1349,7 @@ def _get_cli_state(
13481349
Checks the request 'source' if a cli state token was passed in
13491350
13501351
This is used to authenticate through the CLI login flow.
1351-
1352+
13521353
The state parameter format is: {PREFIX}:{key}:{existing_key}
13531354
- If existing_key is provided, it's included in the state
13541355
- The state parameter is used to pass data through the OAuth flow without changing the callback URL
@@ -1673,22 +1674,49 @@ async def get_microsoft_callback_response(
16731674
access_token=microsoft_sso.access_token
16741675
)
16751676

1677+
# Extract app roles from the id_token JWT
1678+
app_roles = MicrosoftSSOHandler.get_app_roles_from_id_token(
1679+
id_token=microsoft_sso.id_token
1680+
)
1681+
verbose_proxy_logger.debug(f"Extracted app roles from id_token: {app_roles}")
1682+
1683+
# Combine groups and app roles
1684+
user_role: Optional[LitellmUserRoles] = None
1685+
if app_roles:
1686+
# Check if any app role is a valid LitellmUserRoles
1687+
for role_str in app_roles:
1688+
role = get_litellm_user_role(role_str)
1689+
if role is not None:
1690+
user_role = role
1691+
verbose_proxy_logger.debug(
1692+
f"Found valid LitellmUserRoles '{role.value}' in app_roles"
1693+
)
1694+
break
1695+
1696+
verbose_proxy_logger.debug(
1697+
f"Combined team_ids (groups + app roles): {user_team_ids}"
1698+
)
1699+
16761700
# if user is trying to get the raw sso response for debugging, return the raw sso response
16771701
if return_raw_sso_response:
16781702
original_msft_result[MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY] = (
16791703
user_team_ids
16801704
)
1705+
original_msft_result["app_roles"] = app_roles
16811706
return original_msft_result or {}
16821707

16831708
result = MicrosoftSSOHandler.openid_from_response(
16841709
response=original_msft_result,
16851710
team_ids=user_team_ids,
1711+
user_role=user_role,
16861712
)
16871713
return result
16881714

16891715
@staticmethod
16901716
def openid_from_response(
1691-
response: Optional[dict], team_ids: List[str]
1717+
response: Optional[dict],
1718+
team_ids: List[str],
1719+
user_role: Optional[LitellmUserRoles],
16921720
) -> CustomOpenID:
16931721
response = response or {}
16941722
verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}")
@@ -1700,10 +1728,54 @@ def openid_from_response(
17001728
first_name=response.get("givenName"),
17011729
last_name=response.get("surname"),
17021730
team_ids=team_ids,
1731+
user_role=user_role,
17031732
)
17041733
verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}")
17051734
return openid_response
17061735

1736+
@staticmethod
1737+
def get_app_roles_from_id_token(id_token: Optional[str]) -> List[str]:
1738+
"""
1739+
Extract app roles from the Microsoft Entra ID (Azure AD) id_token JWT.
1740+
1741+
App roles are assigned in the Azure AD Enterprise Application and appear
1742+
in the 'roles' claim of the id_token.
1743+
1744+
Args:
1745+
id_token (Optional[str]): The JWT id_token from Microsoft SSO
1746+
1747+
Returns:
1748+
List[str]: List of app role names assigned to the user
1749+
"""
1750+
if not id_token:
1751+
verbose_proxy_logger.debug("No id_token provided for app role extraction")
1752+
return []
1753+
1754+
try:
1755+
import jwt
1756+
1757+
# Decode the JWT without signature verification
1758+
# (signature is already verified by fastapi_sso)
1759+
decoded_token = jwt.decode(id_token, options={"verify_signature": False})
1760+
1761+
# Extract roles claim from the token
1762+
roles = decoded_token.get("roles", [])
1763+
1764+
if roles and isinstance(roles, list):
1765+
verbose_proxy_logger.debug(
1766+
f"Found {len(roles)} app role(s) in id_token: {roles}"
1767+
)
1768+
return roles
1769+
else:
1770+
verbose_proxy_logger.debug(
1771+
"No app roles found in id_token or roles claim is not a list"
1772+
)
1773+
return []
1774+
1775+
except Exception as e:
1776+
verbose_proxy_logger.error(f"Error extracting app roles from id_token: {e}")
1777+
return []
1778+
17071779
@staticmethod
17081780
async def get_user_groups_from_graph_api(
17091781
access_token: Optional[str] = None,

0 commit comments

Comments
 (0)