Skip to content

Commit 7a7b02d

Browse files
author
Rebecka Gulliksson
committed
Match the attribute name of optional attributes in the same way as for required attributes.
1 parent 344ba4a commit 7a7b02d

File tree

2 files changed

+187
-174
lines changed

2 files changed

+187
-174
lines changed

src/saml2/assertion.py

Lines changed: 68 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from saml2.s_utils import factory
1717
from saml2.s_utils import assertion_factory
1818

19-
2019
logger = logging.getLogger(__name__)
2120

2221

@@ -78,25 +77,24 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None,
7877
are missing fail or fail not depending on this parameter.
7978
:return: The modified attribute value assertion
8079
"""
80+
81+
def _attr_name(attr):
82+
"""Get the friendly name of an attribute name"""
83+
try:
84+
return attr["friendly_name"]
85+
except KeyError:
86+
return get_local_name(acs, attr["name"], attr["name_format"])
87+
8188
res = {}
8289

8390
if required is None:
8491
required = []
8592

86-
nform = "friendly_name"
8793
for attr in required:
88-
try:
89-
_name = attr[nform]
90-
except KeyError:
91-
if nform == "friendly_name":
92-
_name = get_local_name(acs, attr["name"],
93-
attr["name_format"])
94-
else:
95-
continue
96-
94+
_name = _attr_name(attr)
9795
_fn = _match(_name, ava)
9896
if not _fn: # In the unlikely case that someone has provided us
99-
# with URIs as attribute names
97+
# with URIs as attribute names
10098
_fn = _match(attr["name"], ava)
10199

102100
if _fn:
@@ -115,18 +113,17 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None,
115113
optional = []
116114

117115
for attr in optional:
118-
for nform in ["friendly_name", "name"]:
119-
if nform in attr:
120-
_fn = _match(attr[nform], ava)
121-
if _fn:
122-
try:
123-
values = [av["text"] for av in attr["attribute_value"]]
124-
except KeyError:
125-
values = []
126-
try:
127-
res[_fn].extend(_filter_values(ava[_fn], values))
128-
except KeyError:
129-
res[_fn] = _filter_values(ava[_fn], values)
116+
_name = _attr_name(attr)
117+
_fn = _match(_name, ava)
118+
if _fn:
119+
try:
120+
values = [av["text"] for av in attr["attribute_value"]]
121+
except KeyError:
122+
values = []
123+
try:
124+
res[_fn].extend(_filter_values(ava[_fn], values))
125+
except KeyError:
126+
res[_fn] = _filter_values(ava[_fn], values)
130127

131128
return res
132129

@@ -154,8 +151,8 @@ def filter_on_demands(ava, required=None, optional=None):
154151
for val in vals:
155152
if val not in ava[lava[attr]]:
156153
raise MissingValue(
157-
"Required attribute value missing: %s,%s" % (attr,
158-
val))
154+
"Required attribute value missing: %s,%s" % (attr,
155+
val))
159156
else:
160157
raise MissingValue("Required attribute missing: %s" % (attr,))
161158

@@ -334,7 +331,7 @@ def compile(self, restrictions):
334331
ecs = []
335332
for cat in items:
336333
_mod = importlib.import_module(
337-
"saml2.entity_category.%s" % cat)
334+
"saml2.entity_category.%s" % cat)
338335
_ec = {}
339336
for key, items in _mod.RELEASE.items():
340337
_ec[key] = [k.lower() for k in items]
@@ -488,8 +485,8 @@ def filter(self, ava, sp_entity_id, mdstore, required=None, optional=None):
488485
if required or optional:
489486
logger.debug("required: %s, optional: %s", required, optional)
490487
_ava = filter_on_attributes(
491-
ava.copy(), required, optional, self.acs,
492-
self.get_fail_on_missing_requested(sp_entity_id))
488+
ava.copy(), required, optional, self.acs,
489+
self.get_fail_on_missing_requested(sp_entity_id))
493490

494491
_rest = self.get_entity_categories(sp_entity_id, mdstore)
495492
if _rest:
@@ -539,9 +536,9 @@ def conditions(self, sp_entity_id):
539536
# How long might depend on who's getting it
540537
not_on_or_after=self.not_on_or_after(sp_entity_id),
541538
audience_restriction=[factory(
542-
saml.AudienceRestriction,
543-
audience=[factory(saml.Audience,
544-
text=sp_entity_id)])])
539+
saml.AudienceRestriction,
540+
audience=[factory(saml.Audience,
541+
text=sp_entity_id)])])
545542

546543
def get_sign(self, sp_entity_id):
547544
"""
@@ -571,7 +568,7 @@ def _authn_context_class_ref(authn_class, authn_auth=None):
571568
return factory(saml.AuthnContext,
572569
authn_context_class_ref=cntx_class,
573570
authenticating_authority=factory(
574-
saml.AuthenticatingAuthority, text=authn_auth))
571+
saml.AuthenticatingAuthority, text=authn_auth))
575572
else:
576573
return factory(saml.AuthnContext,
577574
authn_context_class_ref=cntx_class)
@@ -587,7 +584,7 @@ def _authn_context_decl(decl, authn_auth=None):
587584
return factory(saml.AuthnContext,
588585
authn_context_decl=decl,
589586
authenticating_authority=factory(
590-
saml.AuthenticatingAuthority, text=authn_auth))
587+
saml.AuthenticatingAuthority, text=authn_auth))
591588

592589

593590
def _authn_context_decl_ref(decl_ref, authn_auth=None):
@@ -600,7 +597,7 @@ def _authn_context_decl_ref(decl_ref, authn_auth=None):
600597
return factory(saml.AuthnContext,
601598
authn_context_decl_ref=decl_ref,
602599
authenticating_authority=factory(
603-
saml.AuthenticatingAuthority, text=authn_auth))
600+
saml.AuthenticatingAuthority, text=authn_auth))
604601

605602

606603
def authn_statement(authn_class=None, authn_auth=None,
@@ -626,29 +623,29 @@ def authn_statement(authn_class=None, authn_auth=None,
626623

627624
if authn_class:
628625
res = factory(
629-
saml.AuthnStatement,
630-
authn_instant=_instant,
631-
session_index=sid(),
632-
authn_context=_authn_context_class_ref(
633-
authn_class, authn_auth))
626+
saml.AuthnStatement,
627+
authn_instant=_instant,
628+
session_index=sid(),
629+
authn_context=_authn_context_class_ref(
630+
authn_class, authn_auth))
634631
elif authn_decl:
635632
res = factory(
636-
saml.AuthnStatement,
637-
authn_instant=_instant,
638-
session_index=sid(),
639-
authn_context=_authn_context_decl(authn_decl, authn_auth))
633+
saml.AuthnStatement,
634+
authn_instant=_instant,
635+
session_index=sid(),
636+
authn_context=_authn_context_decl(authn_decl, authn_auth))
640637
elif authn_decl_ref:
641638
res = factory(
642-
saml.AuthnStatement,
643-
authn_instant=_instant,
644-
session_index=sid(),
645-
authn_context=_authn_context_decl_ref(authn_decl_ref,
646-
authn_auth))
639+
saml.AuthnStatement,
640+
authn_instant=_instant,
641+
session_index=sid(),
642+
authn_context=_authn_context_decl_ref(authn_decl_ref,
643+
authn_auth))
647644
else:
648645
res = factory(
649-
saml.AuthnStatement,
650-
authn_instant=_instant,
651-
session_index=sid())
646+
saml.AuthnStatement,
647+
authn_instant=_instant,
648+
session_index=sid())
652649

653650
if subject_locality:
654651
res.subject_locality = saml.SubjectLocality(text=subject_locality)
@@ -698,7 +695,7 @@ def construct(self, sp_entity_id, in_response_to, consumer_url,
698695
_name_format = NAME_FORMAT_URI
699696

700697
attr_statement = saml.AttributeStatement(attribute=from_local(
701-
attrconvs, self, _name_format))
698+
attrconvs, self, _name_format))
702699

703700
if encrypt == "attributes":
704701
for attr in attr_statement.attribute:
@@ -725,33 +722,33 @@ def construct(self, sp_entity_id, in_response_to, consumer_url,
725722

726723
if not add_subject:
727724
_ass = assertion_factory(
728-
issuer=issuer,
729-
conditions=conds,
730-
subject=None
725+
issuer=issuer,
726+
conditions=conds,
727+
subject=None
731728
)
732729
else:
733730
_ass = assertion_factory(
734-
issuer=issuer,
735-
conditions=conds,
736-
subject=factory(
737-
saml.Subject,
738-
name_id=name_id,
739-
subject_confirmation=[factory(
740-
saml.SubjectConfirmation,
741-
method=saml.SCM_BEARER,
742-
subject_confirmation_data=factory(
743-
saml.SubjectConfirmationData,
744-
in_response_to=in_response_to,
745-
recipient=consumer_url,
746-
not_on_or_after=policy.not_on_or_after(sp_entity_id)))]
747-
),
731+
issuer=issuer,
732+
conditions=conds,
733+
subject=factory(
734+
saml.Subject,
735+
name_id=name_id,
736+
subject_confirmation=[factory(
737+
saml.SubjectConfirmation,
738+
method=saml.SCM_BEARER,
739+
subject_confirmation_data=factory(
740+
saml.SubjectConfirmationData,
741+
in_response_to=in_response_to,
742+
recipient=consumer_url,
743+
not_on_or_after=policy.not_on_or_after(sp_entity_id)))]
744+
),
748745
)
749746

750747
if _authn_statement:
751748
_ass.authn_statement = [_authn_statement]
752749

753750
if not attr_statement.empty():
754-
_ass.attribute_statement=[attr_statement]
751+
_ass.attribute_statement = [attr_statement]
755752

756753
return _ass
757754

0 commit comments

Comments
 (0)