Skip to content

Commit f5622e1

Browse files
committed
enable defaults for custom attribute release using '' or 'default' key
1 parent e7e6f5a commit f5622e1

File tree

3 files changed

+18
-2
lines changed

3 files changed

+18
-2
lines changed

doc/README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,18 @@ config:
242242
idp-entity-id1
243243
sp-entity-id1:
244244
exclude: ["givenName"]
245+
246+
247+
The custom_attribute_release mechanism supports defaults based on idp and sp entity Id by specifying "" or "default"
248+
as the key in the dict. For instance in order to exclude givenName for any sp or idp do this:
249+
250+
```yaml
251+
config:
252+
config: [...]
253+
custom_attribute_release:
254+
"default":
255+
"":
256+
exclude: ["givenName"]
245257

246258

247259
#### Backend

src/satosa/frontends/saml2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ..response import Response
2222
from ..response import ServiceError
2323
from ..saml_util import make_saml_response
24+
from ..util import get_dict_defaults
2425

2526
logger = logging.getLogger(__name__)
2627

@@ -272,8 +273,7 @@ def _handle_authn_response(self, context, internal_response, idp):
272273
auth_info["class_ref"] = internal_response.auth_info.auth_class_ref
273274

274275
if self.custom_attribute_release:
275-
custom_release_per_idp = self.custom_attribute_release.get(internal_response.auth_info.issuer, {})
276-
custom_release = custom_release_per_idp.get(resp_args["sp_entity_id"], {})
276+
custom_release = get_dict_defaults(self.custom_attribute_release, internal_response.auth_info.issuer, resp_args["sp_entity_id"])
277277
attributes_to_remove = custom_release.get("exclude", [])
278278
for k in attributes_to_remove:
279279
ava.pop(k, None)

src/satosa/util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77

88
logger = logging.getLogger(__name__)
99

10+
def get_dict_defaults(d, *keys):
11+
for key in keys:
12+
d = d.get(key, d.get("", d.get("default", {})))
13+
return d
1014

1115
def rndstr(size=16, alphabet=""):
1216
"""

0 commit comments

Comments
 (0)