14
14
# limitations under the License.
15
15
16
16
import logging
17
+ from typing import Any , Optional , Tuple
17
18
18
19
from django .conf import settings
19
20
from django .contrib import auth
@@ -37,18 +38,19 @@ def get_model(model_path):
37
38
38
39
39
40
def get_saml_user_model ():
41
+ ''' Returns the user model specified in the settings, or the default one from this Django installation '''
40
42
if hasattr (settings , 'SAML_USER_MODEL' ):
41
43
return get_model (settings .SAML_USER_MODEL )
42
44
return auth .get_user_model ()
43
45
44
46
45
- def get_django_user_lookup_attribute (userModel ):
47
+ def get_django_user_lookup_attribute (userModel ) -> str :
46
48
if hasattr (settings , 'SAML_DJANGO_USER_MAIN_ATTRIBUTE' ):
47
49
return settings .SAML_DJANGO_USER_MAIN_ATTRIBUTE
48
50
return getattr (userModel , 'USERNAME_FIELD' , 'username' )
49
51
50
52
51
- def set_attribute (obj , attr , value ):
53
+ def set_attribute (obj , attr , value ) -> bool :
52
54
""" Set an attribute of an object to a specific value, if it wasn't that already.
53
55
Return True if the attribute was changed and False otherwise.
54
56
"""
@@ -61,28 +63,42 @@ def set_attribute(obj, attr, value):
61
63
return False
62
64
63
65
66
+ UserModel = get_saml_user_model ()
67
+
68
+
64
69
class Saml2Backend (ModelBackend ):
65
- def __init__ (self ) :
66
- super (). __init__ ()
67
- self . UserModel = get_saml_user_model ()
70
+ def is_authorized (self , attributes , attribute_mapping ) -> bool :
71
+ """ Hook to allow custom authorization policies based on SAML attributes. """
72
+ return True
68
73
69
- def _extract_user_identifier_value (self , session_info , attributes , attribute_mapping ):
70
- """ Extract the user identifier value from the saml attributes.
71
- Returns None if no identifier could be extracted from the saml payload.
74
+ def clean_attributes (self , attributes : dict ) -> dict :
75
+ """Hook to clean attributes from the SAML response. """
76
+ return attributes
77
+
78
+ def clean_user_main_attribute (self , main_attribute ):
79
+ """ Clean the extracted user identifying value. No-op by default. """
80
+ return main_attribute
81
+
82
+ def _extract_user_identifier_params (self , session_info , attributes , attribute_mapping ) -> Tuple [str , Optionaly [Any ]]:
83
+ """ Returns the attribute to perform a user lookup on, and the value to use for it.
84
+ The value could be the name_id, or any other saml attribute from the request.
72
85
"""
86
+ # Lookup key
87
+ user_lookup_key = get_django_user_lookup_attribute (UserModel )
88
+
89
+ # Lookup value
73
90
if getattr (settings , 'SAML_USE_NAME_ID_AS_USERNAME' , False ):
74
91
if 'name_id' in session_info :
75
92
logger .debug ('name_id: %s' , session_info ['name_id' ])
76
- saml_user_identifier = session_info ['name_id' ].text
93
+ user_lookup_value = session_info ['name_id' ].text
77
94
else :
78
95
logger .error ('The nameid is not available. Cannot find user without a nameid.' )
79
- saml_user_identifier = None
96
+ user_lookup_value = None
80
97
else :
81
98
# Obtain the value of the custom attribute to use
82
- user_lookup_attribute = get_django_user_lookup_attribute (self .UserModel )
83
- saml_user_identifier = self ._get_attribute_value (user_lookup_attribute , attributes , attribute_mapping )
99
+ user_lookup_value = self ._get_attribute_value (user_lookup_key , attributes , attribute_mapping )
84
100
85
- return self .clean_user_main_attribute (saml_user_identifier )
101
+ return user_lookup_key , self .clean_user_main_attribute (user_lookup_value )
86
102
87
103
def _get_attribute_value (self , django_field , attributes , attribute_mapping ):
88
104
saml_attribute = None
@@ -96,31 +112,30 @@ def _get_attribute_value(self, django_field, attributes, attribute_mapping):
96
112
'session is expired.' )
97
113
return saml_attribute
98
114
99
- def get_or_create_user (self , user_identifier , create_unknown_user , ** kwargs ):
115
+ def get_or_create_user (self , user_lookup_key , user_lookup_value , create_unknown_user , ** kwargs ) -> Tuple [ Optional [ settings . AUTH_USER_MODEL ], bool ] :
100
116
""" Look up the user to authenticate. If he doesn't exist, this method creates him (if so desired).
101
117
The default implementation looks only at the user_identifier. Override this method in order to do more complex behaviour,
102
118
e.g. customize this per IdP. The kwargs contain these additional params: session_info, attribute_mapping, attributes, request.
103
119
The identity provider id can be found in kwargs['session_info']['issuer]
104
120
"""
105
- # Construct query parameters to query the userModel with.
106
- user_lookup_attribute = get_django_user_lookup_attribute (self .UserModel )
121
+ # Construct query parameters to query the userModel with. An additional lookup modifier could be specified in the settings.
107
122
user_query_args = {
108
- user_lookup_attribute + getattr (settings , 'SAML_DJANGO_USER_MAIN_ATTRIBUTE_LOOKUP' , '' ): user_identifier
123
+ user_lookup_key + getattr (settings , 'SAML_DJANGO_USER_MAIN_ATTRIBUTE_LOOKUP' , '' ): user_lookup_value
109
124
}
110
125
111
126
# Lookup existing user
112
127
user , created = None , False
113
128
try :
114
- user = self . UserModel .objects .get (** user_query_args )
129
+ user = UserModel .objects .get (** user_query_args )
115
130
except MultipleObjectsReturned :
116
131
logger .error ("Multiple users match, lookup: %s" , user_query_args )
117
- except self . UserModel .DoesNotExist :
132
+ except UserModel .DoesNotExist :
118
133
logger .error ('The user does not exist, lookup: %s' , user_query_args )
119
134
120
135
# Create new one if desired by settings
121
136
if create_unknown_user :
122
137
try :
123
- user , created = self . UserModel .objects .get_or_create (** user_query_args , defaults = {user_lookup_attribute : user_identifier })
138
+ user , created = UserModel .objects .get_or_create (** user_query_args , defaults = {user_lookup_key : user_lookup_value })
124
139
except Exception as e :
125
140
logger .error ('Could not create new user: %s' , e )
126
141
@@ -149,48 +164,28 @@ def authenticate(self, request, session_info=None, attribute_mapping=None, creat
149
164
logger .error ('Request not authorized' )
150
165
return None
151
166
152
- user_identifier = self ._extract_user_identifier_value (session_info , attributes , attribute_mapping )
153
- if not user_identifier :
167
+ user_lookup_key , user_lookup_value = self ._extract_user_identifier_params (session_info , attributes , attribute_mapping )
168
+ if not user_lookup_value :
154
169
logger .error ('Could not determine user identifier' )
155
170
return None
156
171
157
172
user , created = self .get_or_create_user (
158
- user_identifier , create_unknown_user ,
173
+ user_lookup_key , user_lookup_value , create_unknown_user ,
159
174
request = request , session_info = session_info , attributes = attributes , attribute_mapping = attribute_mapping
160
175
)
161
176
162
177
# Update user with new attributes from incoming request
163
178
if user is not None :
164
179
user = self .update_user (user , attributes , attribute_mapping , force_save = created )
165
- logger .debug ('User updated with incoming attributes' )
166
180
167
181
return user
168
182
169
- def is_authorized (self , attributes , attribute_mapping ):
170
- """Hook to allow custom authorization policies based on
171
- SAML attributes.
172
- """
173
- return True
174
-
175
- def clean_attributes (self , attributes ):
176
- """Hook to clean attributes from the SAML response."""
177
- return attributes
178
-
179
- def clean_user_main_attribute (self , main_attribute ):
180
- """Performs any cleaning on the user main attribute (which
181
- usually is "username") prior to using it to get or
182
- create the user object. Returns the cleaned attribute.
183
-
184
- By default, returns the attribute unchanged.
185
- """
186
- return main_attribute
187
-
188
183
def update_user (self , user , attributes , attribute_mapping , force_save = False ):
189
- """Update a user with a set of attributes and returns the updated user.
184
+ """ Update a user with a set of attributes and returns the updated user.
190
185
191
- By default it uses a mapping defined in the settings constant
192
- SAML_ATTRIBUTE_MAPPING. For each attribute, if the user object has
193
- that field defined it will be set.
186
+ By default it uses a mapping defined in the settings constant
187
+ SAML_ATTRIBUTE_MAPPING. For each attribute, if the user object has
188
+ that field defined it will be set.
194
189
"""
195
190
if not attribute_mapping :
196
191
return user
@@ -214,8 +209,7 @@ def update_user(self, user, attributes, attribute_mapping, force_save=False):
214
209
215
210
user_modified = user_modified or modified
216
211
else :
217
- logger .debug (
218
- 'Could not find attribute "%s" on user "%s"' , attr , user )
212
+ logger .debug ('Could not find attribute "%s" on user "%s"' , attr , user )
219
213
220
214
logger .debug ('Sending the pre_save signal' )
221
215
signal_modified = any (
@@ -228,5 +222,6 @@ def update_user(self, user, attributes, attribute_mapping, force_save=False):
228
222
229
223
if user_modified or signal_modified or force_save :
230
224
user .save ()
225
+ logger .debug ('User updated with incoming attributes' )
231
226
232
227
return user
0 commit comments