|
4 | 4 | from django.conf import settings |
5 | 5 | from django.test import override_settings |
6 | 6 |
|
7 | | -from ansible_base.authentication.social_auth import AuthenticatorStorage, AuthenticatorStrategy, SocialAuthValidateCallbackMixin |
| 7 | +from ansible_base.authentication.social_auth import ( |
| 8 | + AuthenticatorStorage, |
| 9 | + AuthenticatorStrategy, |
| 10 | + SocialAuthMixin, |
| 11 | + SocialAuthValidateCallbackMixin, |
| 12 | + create_user_claims_pipeline, |
| 13 | +) |
8 | 14 |
|
9 | 15 |
|
10 | 16 | @mock.patch("ansible_base.authentication.social_auth.logger") |
@@ -75,3 +81,51 @@ def test_social_auth_validate_callback_mixin(mocked_generate_slug, mocked_revers |
75 | 81 | # should always call reverse if no callback url |
76 | 82 | if has_instance and 'configuration' in test_data and not test_data.get('configuration', {}).get('CALLBACK_URL'): |
77 | 83 | assert mocked_reverse.called |
| 84 | + |
| 85 | + |
| 86 | +@pytest.mark.parametrize( |
| 87 | + "groups_claim,returned_groups,expected_groups", |
| 88 | + [ |
| 89 | + (None, ["mygroup"], ["mygroup"]), |
| 90 | + ("groups", ["mygroup"], ["mygroup"]), |
| 91 | + (None, None, []), |
| 92 | + ("groups", None, []), |
| 93 | + ], |
| 94 | +) |
| 95 | +@mock.patch("ansible_base.authentication.utils.claims.update_user_claims") |
| 96 | +def test_create_user_claims_pipeline(mock_update_user_claims, groups_claim, returned_groups, expected_groups): |
| 97 | + ''' |
| 98 | + We are testing to see if extracting groups from a claim is working correctly |
| 99 | + ''' |
| 100 | + |
| 101 | + class MockBackend(SocialAuthMixin): |
| 102 | + database_instance = None |
| 103 | + |
| 104 | + def __init__(self, groups_claim=None): |
| 105 | + if groups_claim is not None: |
| 106 | + self.groups_claim = groups_claim |
| 107 | + |
| 108 | + def get_user_groups(self, extra_groups=[]): |
| 109 | + return extra_groups |
| 110 | + |
| 111 | + backend = MockBackend(groups_claim=groups_claim) |
| 112 | + |
| 113 | + rData = {} |
| 114 | + if returned_groups is not None: |
| 115 | + rData[backend.groups_claim] = returned_groups |
| 116 | + |
| 117 | + user = { |
| 118 | + 'auth_time': "2024-11-07T05:19:08.224936Z", |
| 119 | + 'id_token': "asdf", |
| 120 | + 'refresh_token': None, |
| 121 | + 'id': "ccd2cf13-d927-41ad-cd8c-adb18b2e5f78", |
| 122 | + 'access_token': "asdf", |
| 123 | + 'token_type': "Bearer", |
| 124 | + } |
| 125 | + |
| 126 | + create_user_claims_pipeline(backend=backend, response=rData, user=user) |
| 127 | + |
| 128 | + assert mock_update_user_claims.called |
| 129 | + call_args = mock_update_user_claims.call_args |
| 130 | + |
| 131 | + assert call_args == ((user, None, expected_groups),) |
0 commit comments