Skip to content

Commit 32df0d1

Browse files
author
Roland Hedberg
committed
Merge branch 'master' of github.com:rohe/pysaml2
2 parents 8f09d2f + 1384687 commit 32df0d1

File tree

4 files changed

+258
-206
lines changed

4 files changed

+258
-206
lines changed

src/saml2/assertion.py

Lines changed: 79 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from saml2.s_utils import factory
1717
from saml2.s_utils import assertion_factory
1818

19-
2019
logger = logging.getLogger(__name__)
2120

2221

@@ -78,55 +77,54 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None,
7877
are missing fail or fail not depending on this parameter.
7978
:return: The modified attribute value assertion
8079
"""
80+
81+
def _match_attr_name(attr, ava):
82+
try:
83+
friendly_name = attr["friendly_name"]
84+
except KeyError:
85+
friendly_name = get_local_name(acs, attr["name"], attr["name_format"])
86+
87+
_fn = _match(friendly_name, ava)
88+
if not _fn: # In the unlikely case that someone has provided us with URIs as attribute names
89+
_fn = _match(attr["name"], ava)
90+
91+
return _fn
92+
93+
def _apply_attr_value_restrictions(attr, res, must=False):
94+
try:
95+
values = [av["text"] for av in attr["attribute_value"]]
96+
except KeyError:
97+
values = []
98+
99+
try:
100+
res[_fn].extend(_filter_values(ava[_fn], values))
101+
except KeyError:
102+
res[_fn] = _filter_values(ava[_fn], values)
103+
104+
return _filter_values(ava[_fn], values, must)
105+
81106
res = {}
82107

83108
if required is None:
84109
required = []
85110

86-
nform = "friendly_name"
87111
for attr in required:
88-
try:
89-
_name = attr[nform]
90-
except KeyError:
91-
if nform == "friendly_name":
92-
_name = get_local_name(acs, attr["name"],
93-
attr["name_format"])
94-
else:
95-
continue
96-
97-
_fn = _match(_name, ava)
98-
if not _fn: # In the unlikely case that someone has provided us
99-
# with URIs as attribute names
100-
_fn = _match(attr["name"], ava)
112+
_fn = _match_attr_name(attr, ava)
101113

102114
if _fn:
103-
try:
104-
values = [av["text"] for av in attr["attribute_value"]]
105-
except KeyError:
106-
values = []
107-
res[_fn] = _filter_values(ava[_fn], values, True)
108-
continue
115+
_apply_attr_value_restrictions(attr, res, True)
109116
elif fail_on_unfulfilled_requirements:
110117
desc = "Required attribute missing: '%s' (%s)" % (attr["name"],
111-
_name)
118+
_fn)
112119
raise MissingValue(desc)
113120

114121
if optional is None:
115122
optional = []
116123

117124
for attr in optional:
118-
for nform in ["friendly_name", "name"]:
119-
if nform in attr:
120-
_fn = _match(attr[nform], ava)
121-
if _fn:
122-
try:
123-
values = [av["text"] for av in attr["attribute_value"]]
124-
except KeyError:
125-
values = []
126-
try:
127-
res[_fn].extend(_filter_values(ava[_fn], values))
128-
except KeyError:
129-
res[_fn] = _filter_values(ava[_fn], values)
125+
_fn = _match_attr_name(attr, ava)
126+
if _fn:
127+
_apply_attr_value_restrictions(attr, res, False)
130128

131129
return res
132130

@@ -154,8 +152,8 @@ def filter_on_demands(ava, required=None, optional=None):
154152
for val in vals:
155153
if val not in ava[lava[attr]]:
156154
raise MissingValue(
157-
"Required attribute value missing: %s,%s" % (attr,
158-
val))
155+
"Required attribute value missing: %s,%s" % (attr,
156+
val))
159157
else:
160158
raise MissingValue("Required attribute missing: %s" % (attr,))
161159

@@ -334,7 +332,7 @@ def compile(self, restrictions):
334332
ecs = []
335333
for cat in items:
336334
_mod = importlib.import_module(
337-
"saml2.entity_category.%s" % cat)
335+
"saml2.entity_category.%s" % cat)
338336
_ec = {}
339337
for key, items in _mod.RELEASE.items():
340338
_ec[key] = [k.lower() for k in items]
@@ -488,8 +486,8 @@ def filter(self, ava, sp_entity_id, mdstore, required=None, optional=None):
488486
if required or optional:
489487
logger.debug("required: %s, optional: %s", required, optional)
490488
_ava = filter_on_attributes(
491-
ava.copy(), required, optional, self.acs,
492-
self.get_fail_on_missing_requested(sp_entity_id))
489+
ava.copy(), required, optional, self.acs,
490+
self.get_fail_on_missing_requested(sp_entity_id))
493491

494492
_rest = self.get_entity_categories(sp_entity_id, mdstore)
495493
if _rest:
@@ -539,9 +537,9 @@ def conditions(self, sp_entity_id):
539537
# How long might depend on who's getting it
540538
not_on_or_after=self.not_on_or_after(sp_entity_id),
541539
audience_restriction=[factory(
542-
saml.AudienceRestriction,
543-
audience=[factory(saml.Audience,
544-
text=sp_entity_id)])])
540+
saml.AudienceRestriction,
541+
audience=[factory(saml.Audience,
542+
text=sp_entity_id)])])
545543

546544
def get_sign(self, sp_entity_id):
547545
"""
@@ -571,7 +569,7 @@ def _authn_context_class_ref(authn_class, authn_auth=None):
571569
return factory(saml.AuthnContext,
572570
authn_context_class_ref=cntx_class,
573571
authenticating_authority=factory(
574-
saml.AuthenticatingAuthority, text=authn_auth))
572+
saml.AuthenticatingAuthority, text=authn_auth))
575573
else:
576574
return factory(saml.AuthnContext,
577575
authn_context_class_ref=cntx_class)
@@ -587,7 +585,7 @@ def _authn_context_decl(decl, authn_auth=None):
587585
return factory(saml.AuthnContext,
588586
authn_context_decl=decl,
589587
authenticating_authority=factory(
590-
saml.AuthenticatingAuthority, text=authn_auth))
588+
saml.AuthenticatingAuthority, text=authn_auth))
591589

592590

593591
def _authn_context_decl_ref(decl_ref, authn_auth=None):
@@ -600,7 +598,7 @@ def _authn_context_decl_ref(decl_ref, authn_auth=None):
600598
return factory(saml.AuthnContext,
601599
authn_context_decl_ref=decl_ref,
602600
authenticating_authority=factory(
603-
saml.AuthenticatingAuthority, text=authn_auth))
601+
saml.AuthenticatingAuthority, text=authn_auth))
604602

605603

606604
def authn_statement(authn_class=None, authn_auth=None,
@@ -626,29 +624,29 @@ def authn_statement(authn_class=None, authn_auth=None,
626624

627625
if authn_class:
628626
res = factory(
629-
saml.AuthnStatement,
630-
authn_instant=_instant,
631-
session_index=sid(),
632-
authn_context=_authn_context_class_ref(
633-
authn_class, authn_auth))
627+
saml.AuthnStatement,
628+
authn_instant=_instant,
629+
session_index=sid(),
630+
authn_context=_authn_context_class_ref(
631+
authn_class, authn_auth))
634632
elif authn_decl:
635633
res = factory(
636-
saml.AuthnStatement,
637-
authn_instant=_instant,
638-
session_index=sid(),
639-
authn_context=_authn_context_decl(authn_decl, authn_auth))
634+
saml.AuthnStatement,
635+
authn_instant=_instant,
636+
session_index=sid(),
637+
authn_context=_authn_context_decl(authn_decl, authn_auth))
640638
elif authn_decl_ref:
641639
res = factory(
642-
saml.AuthnStatement,
643-
authn_instant=_instant,
644-
session_index=sid(),
645-
authn_context=_authn_context_decl_ref(authn_decl_ref,
646-
authn_auth))
640+
saml.AuthnStatement,
641+
authn_instant=_instant,
642+
session_index=sid(),
643+
authn_context=_authn_context_decl_ref(authn_decl_ref,
644+
authn_auth))
647645
else:
648646
res = factory(
649-
saml.AuthnStatement,
650-
authn_instant=_instant,
651-
session_index=sid())
647+
saml.AuthnStatement,
648+
authn_instant=_instant,
649+
session_index=sid())
652650

653651
if subject_locality:
654652
res.subject_locality = saml.SubjectLocality(text=subject_locality)
@@ -698,7 +696,7 @@ def construct(self, sp_entity_id, in_response_to, consumer_url,
698696
_name_format = NAME_FORMAT_URI
699697

700698
attr_statement = saml.AttributeStatement(attribute=from_local(
701-
attrconvs, self, _name_format))
699+
attrconvs, self, _name_format))
702700

703701
if encrypt == "attributes":
704702
for attr in attr_statement.attribute:
@@ -725,33 +723,33 @@ def construct(self, sp_entity_id, in_response_to, consumer_url,
725723

726724
if not add_subject:
727725
_ass = assertion_factory(
728-
issuer=issuer,
729-
conditions=conds,
730-
subject=None
726+
issuer=issuer,
727+
conditions=conds,
728+
subject=None
731729
)
732730
else:
733731
_ass = assertion_factory(
734-
issuer=issuer,
735-
conditions=conds,
736-
subject=factory(
737-
saml.Subject,
738-
name_id=name_id,
739-
subject_confirmation=[factory(
740-
saml.SubjectConfirmation,
741-
method=saml.SCM_BEARER,
742-
subject_confirmation_data=factory(
743-
saml.SubjectConfirmationData,
744-
in_response_to=in_response_to,
745-
recipient=consumer_url,
746-
not_on_or_after=policy.not_on_or_after(sp_entity_id)))]
747-
),
732+
issuer=issuer,
733+
conditions=conds,
734+
subject=factory(
735+
saml.Subject,
736+
name_id=name_id,
737+
subject_confirmation=[factory(
738+
saml.SubjectConfirmation,
739+
method=saml.SCM_BEARER,
740+
subject_confirmation_data=factory(
741+
saml.SubjectConfirmationData,
742+
in_response_to=in_response_to,
743+
recipient=consumer_url,
744+
not_on_or_after=policy.not_on_or_after(sp_entity_id)))]
745+
),
748746
)
749747

750748
if _authn_statement:
751749
_ass.authn_statement = [_authn_statement]
752750

753751
if not attr_statement.empty():
754-
_ass.attribute_statement=[attr_statement]
752+
_ass.attribute_statement = [attr_statement]
755753

756754
return _ass
757755

src/saml2/attribute_converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def adjust(self):
284284
if self._fro is None and self._to is not None:
285285
self._fro = dict(
286286
[(value.lower(), key) for key, value in self._to.items()])
287-
if self._to is None and self.fro is not None:
287+
if self._to is None and self._fro is not None:
288288
self._to = dict(
289289
[(value.lower(), key) for key, value in self._fro.items()])
290290

0 commit comments

Comments
 (0)