Skip to content

Commit 00096f0

Browse files
committed
updates
1 parent 08e46f5 commit 00096f0

File tree

1 file changed

+44
-49
lines changed

1 file changed

+44
-49
lines changed

djangosaml2/backends.py

Lines changed: 44 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515

1616
import logging
17+
from typing import Any, Optional, Tuple
1718

1819
from django.conf import settings
1920
from django.contrib import auth
@@ -37,18 +38,19 @@ def get_model(model_path):
3738

3839

3940
def get_saml_user_model():
41+
''' Returns the user model specified in the settings, or the default one from this Django installation '''
4042
if hasattr(settings, 'SAML_USER_MODEL'):
4143
return get_model(settings.SAML_USER_MODEL)
4244
return auth.get_user_model()
4345

4446

45-
def get_django_user_lookup_attribute(userModel):
47+
def get_django_user_lookup_attribute(userModel) -> str:
4648
if hasattr(settings, 'SAML_DJANGO_USER_MAIN_ATTRIBUTE'):
4749
return settings.SAML_DJANGO_USER_MAIN_ATTRIBUTE
4850
return getattr(userModel, 'USERNAME_FIELD', 'username')
4951

5052

51-
def set_attribute(obj, attr, value):
53+
def set_attribute(obj, attr, value) -> bool:
5254
""" Set an attribute of an object to a specific value, if it wasn't that already.
5355
Return True if the attribute was changed and False otherwise.
5456
"""
@@ -61,28 +63,42 @@ def set_attribute(obj, attr, value):
6163
return False
6264

6365

66+
UserModel = get_saml_user_model()
67+
68+
6469
class Saml2Backend(ModelBackend):
65-
def __init__(self):
66-
super().__init__()
67-
self.UserModel = get_saml_user_model()
70+
def is_authorized(self, attributes, attribute_mapping) -> bool:
71+
""" Hook to allow custom authorization policies based on SAML attributes. """
72+
return True
6873

69-
def _extract_user_identifier_value(self, session_info, attributes, attribute_mapping):
70-
""" Extract the user identifier value from the saml attributes.
71-
Returns None if no identifier could be extracted from the saml payload.
74+
def clean_attributes(self, attributes: dict) -> dict:
75+
"""Hook to clean attributes from the SAML response. """
76+
return attributes
77+
78+
def clean_user_main_attribute(self, main_attribute):
79+
""" Clean the extracted user identifying value. No-op by default. """
80+
return main_attribute
81+
82+
def _extract_user_identifier_params(self, session_info, attributes, attribute_mapping) -> Tuple[str, Optionaly[Any]]:
83+
""" Returns the attribute to perform a user lookup on, and the value to use for it.
84+
The value could be the name_id, or any other saml attribute from the request.
7285
"""
86+
# Lookup key
87+
user_lookup_key = get_django_user_lookup_attribute(UserModel)
88+
89+
# Lookup value
7390
if getattr(settings, 'SAML_USE_NAME_ID_AS_USERNAME', False):
7491
if 'name_id' in session_info:
7592
logger.debug('name_id: %s', session_info['name_id'])
76-
saml_user_identifier = session_info['name_id'].text
93+
user_lookup_value = session_info['name_id'].text
7794
else:
7895
logger.error('The nameid is not available. Cannot find user without a nameid.')
79-
saml_user_identifier = None
96+
user_lookup_value = None
8097
else:
8198
# Obtain the value of the custom attribute to use
82-
user_lookup_attribute = get_django_user_lookup_attribute(self.UserModel)
83-
saml_user_identifier = self._get_attribute_value(user_lookup_attribute, attributes, attribute_mapping)
99+
user_lookup_value = self._get_attribute_value(user_lookup_key, attributes, attribute_mapping)
84100

85-
return self.clean_user_main_attribute(saml_user_identifier)
101+
return user_lookup_key, self.clean_user_main_attribute(user_lookup_value)
86102

87103
def _get_attribute_value(self, django_field, attributes, attribute_mapping):
88104
saml_attribute = None
@@ -96,31 +112,30 @@ def _get_attribute_value(self, django_field, attributes, attribute_mapping):
96112
'session is expired.')
97113
return saml_attribute
98114

99-
def get_or_create_user(self, user_identifier, create_unknown_user, **kwargs):
115+
def get_or_create_user(self, user_lookup_key, user_lookup_value, create_unknown_user, **kwargs) -> Tuple[Optional[settings.AUTH_USER_MODEL], bool]:
100116
""" Look up the user to authenticate. If he doesn't exist, this method creates him (if so desired).
101117
The default implementation looks only at the user_identifier. Override this method in order to do more complex behaviour,
102118
e.g. customize this per IdP. The kwargs contain these additional params: session_info, attribute_mapping, attributes, request.
103119
The identity provider id can be found in kwargs['session_info']['issuer]
104120
"""
105-
# Construct query parameters to query the userModel with.
106-
user_lookup_attribute = get_django_user_lookup_attribute(self.UserModel)
121+
# Construct query parameters to query the userModel with. An additional lookup modifier could be specified in the settings.
107122
user_query_args = {
108-
user_lookup_attribute + getattr(settings, 'SAML_DJANGO_USER_MAIN_ATTRIBUTE_LOOKUP', ''): user_identifier
123+
user_lookup_key + getattr(settings, 'SAML_DJANGO_USER_MAIN_ATTRIBUTE_LOOKUP', ''): user_lookup_value
109124
}
110125

111126
# Lookup existing user
112127
user, created = None, False
113128
try:
114-
user = self.UserModel.objects.get(**user_query_args)
129+
user = UserModel.objects.get(**user_query_args)
115130
except MultipleObjectsReturned:
116131
logger.error("Multiple users match, lookup: %s", user_query_args)
117-
except self.UserModel.DoesNotExist:
132+
except UserModel.DoesNotExist:
118133
logger.error('The user does not exist, lookup: %s', user_query_args)
119134

120135
# Create new one if desired by settings
121136
if create_unknown_user:
122137
try:
123-
user, created = self.UserModel.objects.get_or_create(**user_query_args, defaults={user_lookup_attribute: user_identifier})
138+
user, created = UserModel.objects.get_or_create(**user_query_args, defaults={user_lookup_key: user_lookup_value})
124139
except Exception as e:
125140
logger.error('Could not create new user: %s', e)
126141

@@ -149,48 +164,28 @@ def authenticate(self, request, session_info=None, attribute_mapping=None, creat
149164
logger.error('Request not authorized')
150165
return None
151166

152-
user_identifier = self._extract_user_identifier_value(session_info, attributes, attribute_mapping)
153-
if not user_identifier:
167+
user_lookup_key, user_lookup_value = self._extract_user_identifier_params(session_info, attributes, attribute_mapping)
168+
if not user_lookup_value:
154169
logger.error('Could not determine user identifier')
155170
return None
156171

157172
user, created = self.get_or_create_user(
158-
user_identifier, create_unknown_user,
173+
user_lookup_key, user_lookup_value, create_unknown_user,
159174
request=request, session_info=session_info, attributes=attributes, attribute_mapping=attribute_mapping
160175
)
161176

162177
# Update user with new attributes from incoming request
163178
if user is not None:
164179
user = self.update_user(user, attributes, attribute_mapping, force_save=created)
165-
logger.debug('User updated with incoming attributes')
166180

167181
return user
168182

169-
def is_authorized(self, attributes, attribute_mapping):
170-
"""Hook to allow custom authorization policies based on
171-
SAML attributes.
172-
"""
173-
return True
174-
175-
def clean_attributes(self, attributes):
176-
"""Hook to clean attributes from the SAML response."""
177-
return attributes
178-
179-
def clean_user_main_attribute(self, main_attribute):
180-
"""Performs any cleaning on the user main attribute (which
181-
usually is "username") prior to using it to get or
182-
create the user object. Returns the cleaned attribute.
183-
184-
By default, returns the attribute unchanged.
185-
"""
186-
return main_attribute
187-
188183
def update_user(self, user, attributes, attribute_mapping, force_save=False):
189-
"""Update a user with a set of attributes and returns the updated user.
184+
""" Update a user with a set of attributes and returns the updated user.
190185
191-
By default it uses a mapping defined in the settings constant
192-
SAML_ATTRIBUTE_MAPPING. For each attribute, if the user object has
193-
that field defined it will be set.
186+
By default it uses a mapping defined in the settings constant
187+
SAML_ATTRIBUTE_MAPPING. For each attribute, if the user object has
188+
that field defined it will be set.
194189
"""
195190
if not attribute_mapping:
196191
return user
@@ -214,8 +209,7 @@ def update_user(self, user, attributes, attribute_mapping, force_save=False):
214209

215210
user_modified = user_modified or modified
216211
else:
217-
logger.debug(
218-
'Could not find attribute "%s" on user "%s"', attr, user)
212+
logger.debug('Could not find attribute "%s" on user "%s"', attr, user)
219213

220214
logger.debug('Sending the pre_save signal')
221215
signal_modified = any(
@@ -228,5 +222,6 @@ def update_user(self, user, attributes, attribute_mapping, force_save=False):
228222

229223
if user_modified or signal_modified or force_save:
230224
user.save()
225+
logger.debug('User updated with incoming attributes')
231226

232227
return user

0 commit comments

Comments
 (0)