Skip to content

Commit 8c0b6f1

Browse files
committed
Merge branch 'upstream' into attribute-generation
2 parents f772efd + 0e5afd4 commit 8c0b6f1

File tree

4 files changed

+56
-3
lines changed

4 files changed

+56
-3
lines changed

doc/README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,18 @@ config:
242242
idp-entity-id1
243243
sp-entity-id1:
244244
exclude: ["givenName"]
245+
246+
247+
The custom_attribute_release mechanism supports defaults based on idp and sp entity Id by specifying "" or "default"
248+
as the key in the dict. For instance in order to exclude givenName for any sp or idp do this:
249+
250+
```yaml
251+
config:
252+
config: [...]
253+
custom_attribute_release:
254+
"default":
255+
"":
256+
exclude: ["givenName"]
245257

246258

247259
#### Backend

src/satosa/backends/saml2.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,10 +214,16 @@ def _translate_response(self, response, state):
214214
issuer = response.response.issuer.text
215215

216216
auth_info = AuthenticationInformation(auth_class_ref, timestamp, issuer)
217-
internal_resp = InternalResponse(auth_info=auth_info)
217+
internal_resp = SAMLInternalResponse(auth_info=auth_info)
218218

219219
internal_resp.user_id = response.get_subject().text
220220
internal_resp.attributes = self.converter.to_internal(self.attribute_profile, response.ava)
221+
222+
# The SAML response may not include a NameID
223+
try:
224+
internal_resp.name_id = response.assertion.subject.name_id
225+
except AttributeError:
226+
pass
221227

222228
satosa_logging(logger, logging.DEBUG, "received attributes:\n%s" % json.dumps(response.ava, indent=4), state)
223229
return internal_resp
@@ -315,3 +321,32 @@ def get_metadata_desc(self):
315321

316322
entity_descriptions.append(description)
317323
return entity_descriptions
324+
325+
class SAMLInternalResponse(InternalResponse):
326+
"""
327+
Like the parent InternalResponse, holds internal representation of
328+
service related data, but includes additional details relevant to
329+
SAML interoperability.
330+
331+
:type name_id: instance of saml2.saml.NameID from pysaml2
332+
"""
333+
def __init__(self, auth_info=None):
334+
super().__init__(auth_info)
335+
336+
self.name_id = None
337+
338+
def to_dict(self):
339+
"""
340+
Converts a SAMLInternalResponse object to a dict
341+
:rtype: dict[str, dict[str, str] | str]
342+
:return: A dict representation of the object
343+
"""
344+
_dict = super().to_dict()
345+
346+
if self.name_id:
347+
_dict['name_id'] = {self.name_id.format : self.name_id.text}
348+
else:
349+
_dict['name_id'] = None
350+
351+
return _dict
352+

src/satosa/frontends/saml2.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ..response import Response
2222
from ..response import ServiceError
2323
from ..saml_util import make_saml_response
24+
from ..util import get_dict_defaults
2425

2526
logger = logging.getLogger(__name__)
2627

@@ -271,9 +272,10 @@ def _handle_authn_response(self, context, internal_response, idp):
271272
else:
272273
auth_info["class_ref"] = internal_response.auth_info.auth_class_ref
273274

275+
auth_info["authn_auth"] = internal_response.auth_info.issuer
276+
274277
if self.custom_attribute_release:
275-
custom_release_per_idp = self.custom_attribute_release.get(internal_response.auth_info.issuer, {})
276-
custom_release = custom_release_per_idp.get(resp_args["sp_entity_id"], {})
278+
custom_release = get_dict_defaults(self.custom_attribute_release, internal_response.auth_info.issuer, resp_args["sp_entity_id"])
277279
attributes_to_remove = custom_release.get("exclude", [])
278280
for k in attributes_to_remove:
279281
ava.pop(k, None)

src/satosa/util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77

88
logger = logging.getLogger(__name__)
99

10+
def get_dict_defaults(d, *keys):
11+
for key in keys:
12+
d = d.get(key, d.get("", d.get("default", {})))
13+
return d
1014

1115
def rndstr(size=16, alphabet=""):
1216
"""

0 commit comments

Comments
 (0)