Skip to content

Commit 31c95bb

Browse files
authored
Support modifying groups claim for social auth (ansible#640)
By default social auth hard codes the groups claim to `Group`. In general IDPs use `groups` for returning the user group membership. This PR support setting the groups claim field. It also defaults the users list of groups to an empty list(instead of `None`) if no group claim is found.
1 parent 845b3e1 commit 31c95bb

File tree

3 files changed

+71
-2
lines changed

3 files changed

+71
-2
lines changed

ansible_base/authentication/authenticator_plugins/oidc.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,14 @@ class OpenIdConnectConfiguration(BaseAuthenticatorConfiguration):
201201
ui_field_label=_("Username Key"),
202202
)
203203

204+
GROUPS_CLAIM = CharField(
205+
help_text=_("The JSON key used to extract the user's groups from the ID token or userinfo endpoint."),
206+
required=False,
207+
allow_null=True,
208+
default="Group",
209+
ui_field_label=_("Groups Claim"),
210+
)
211+
204212

205213
class AuthenticatorPlugin(SocialAuthMixin, OpenIdConnectAuth, AbstractAuthenticatorPlugin):
206214
configuration_class = OpenIdConnectConfiguration
@@ -209,6 +217,10 @@ class AuthenticatorPlugin(SocialAuthMixin, OpenIdConnectAuth, AbstractAuthentica
209217
category = "sso"
210218
configuration_encrypted_fields = ['SECRET']
211219

220+
@property
221+
def groups_claim(self):
222+
return self.setting('GROUPS_CLAIM')
223+
212224
def extra_data(self, user, backend, response, *args, **kwargs):
213225
for perm in ["is_superuser", get_setting('ANSIBLE_BASE_SOCIAL_AUDITOR_FLAG')]:
214226
if perm in response:

ansible_base/authentication/social_auth.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def __init__(self, storage, request=None, tpl=None, additional_settings={}):
135135
class SocialAuthMixin:
136136
configuration_encrypted_fields = []
137137
logger = None
138+
groups_claim = "Group"
138139

139140
def __init__(self, *args, **kwargs):
140141
# social auth expects the first arg to be a strategy instance. Since this has
@@ -190,7 +191,9 @@ def validate(self, serializer, data):
190191
def create_user_claims_pipeline(*args, backend, response, **kwargs):
191192
from ansible_base.authentication.utils.claims import update_user_claims
192193

193-
extra_groups = response["Group"] if "Group" in response else None
194+
groups_claim = backend.groups_claim if backend.groups_claim is not None else "Group"
195+
196+
extra_groups = response[groups_claim] if groups_claim in response else []
194197
user = update_user_claims(kwargs["user"], backend.database_instance, backend.get_user_groups(extra_groups))
195198
if user is None:
196199
return SOCIAL_AUTH_PIPELINE_FAILED_STATUS

test_app/tests/authentication/test_social_auth.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
from django.conf import settings
55
from django.test import override_settings
66

7-
from ansible_base.authentication.social_auth import AuthenticatorStorage, AuthenticatorStrategy, SocialAuthValidateCallbackMixin
7+
from ansible_base.authentication.social_auth import (
8+
AuthenticatorStorage,
9+
AuthenticatorStrategy,
10+
SocialAuthMixin,
11+
SocialAuthValidateCallbackMixin,
12+
create_user_claims_pipeline,
13+
)
814

915

1016
@mock.patch("ansible_base.authentication.social_auth.logger")
@@ -75,3 +81,51 @@ def test_social_auth_validate_callback_mixin(mocked_generate_slug, mocked_revers
7581
# should always call reverse if no callback url
7682
if has_instance and 'configuration' in test_data and not test_data.get('configuration', {}).get('CALLBACK_URL'):
7783
assert mocked_reverse.called
84+
85+
86+
@pytest.mark.parametrize(
87+
"groups_claim,returned_groups,expected_groups",
88+
[
89+
(None, ["mygroup"], ["mygroup"]),
90+
("groups", ["mygroup"], ["mygroup"]),
91+
(None, None, []),
92+
("groups", None, []),
93+
],
94+
)
95+
@mock.patch("ansible_base.authentication.utils.claims.update_user_claims")
96+
def test_create_user_claims_pipeline(mock_update_user_claims, groups_claim, returned_groups, expected_groups):
97+
'''
98+
We are testing to see if extracting groups from a claim is working correctly
99+
'''
100+
101+
class MockBackend(SocialAuthMixin):
102+
database_instance = None
103+
104+
def __init__(self, groups_claim=None):
105+
if groups_claim is not None:
106+
self.groups_claim = groups_claim
107+
108+
def get_user_groups(self, extra_groups=[]):
109+
return extra_groups
110+
111+
backend = MockBackend(groups_claim=groups_claim)
112+
113+
rData = {}
114+
if returned_groups is not None:
115+
rData[backend.groups_claim] = returned_groups
116+
117+
user = {
118+
'auth_time': "2024-11-07T05:19:08.224936Z",
119+
'id_token': "asdf",
120+
'refresh_token': None,
121+
'id': "ccd2cf13-d927-41ad-cd8c-adb18b2e5f78",
122+
'access_token': "asdf",
123+
'token_type': "Bearer",
124+
}
125+
126+
create_user_claims_pipeline(backend=backend, response=rData, user=user)
127+
128+
assert mock_update_user_claims.called
129+
call_args = mock_update_user_claims.call_args
130+
131+
assert call_args == ((user, None, expected_groups),)

0 commit comments

Comments
 (0)