Skip to content

Commit 7b41b46

Browse files
author
Roland Hedberg
committed
Refactored class methods
1 parent 6b1bc50 commit 7b41b46

File tree

3 files changed

+97
-109
lines changed

3 files changed

+97
-109
lines changed

example/sp-wsgi/sp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def do(self, response, binding, relay_state="", mtype="response"):
362362

363363
def verify_attributes(self, ava):
364364
logger.info("SP: %s" % self.sp.config.entityid)
365-
rest = POLICY.get_entity_categories_restriction(
365+
rest = POLICY.get_entity_categories(
366366
self.sp.config.entityid, self.sp.metadata)
367367

368368
akeys = [k.lower() for k in ava.keys()]

src/saml2/assertion.py

Lines changed: 94 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,39 @@ def restriction_from_attribute_spec(attributes):
270270
return restr
271271

272272

273+
def post_entity_categories(maps, **kwargs):
274+
restrictions = {}
275+
if kwargs["mds"]:
276+
try:
277+
ecs = kwargs["mds"].entity_categories(kwargs["sp_entity_id"])
278+
except KeyError:
279+
for ec_map in maps:
280+
for attr in ec_map[""]:
281+
restrictions[attr] = None
282+
else:
283+
for ec_map in maps:
284+
for key, val in ec_map.items():
285+
if key == "": # always released
286+
attrs = val
287+
elif isinstance(key, tuple):
288+
attrs = val
289+
for _key in key:
290+
try:
291+
assert _key in ecs
292+
except AssertionError:
293+
attrs = []
294+
break
295+
elif key in ecs:
296+
attrs = val
297+
else:
298+
attrs = []
299+
300+
for attr in attrs:
301+
restrictions[attr] = None
302+
303+
return restrictions
304+
305+
273306
class Policy(object):
274307
""" handles restrictions on assertions """
275308

@@ -329,85 +362,70 @@ def compile(self, restrictions):
329362
logger.debug("policy restrictions: %s" % self._restrictions)
330363

331364
return self._restrictions
332-
365+
366+
def get(self, attribute, sp_entity_id, default=None, post_func=None,
367+
**kwargs):
368+
"""
369+
370+
:param attribute:
371+
:param sp_entity_id:
372+
:param default:
373+
:param post_func:
374+
:return:
375+
"""
376+
if not self._restrictions:
377+
return default
378+
379+
try:
380+
try:
381+
val = self._restrictions[sp_entity_id][attribute]
382+
except KeyError:
383+
try:
384+
val = self._restrictions["default"][attribute]
385+
except KeyError:
386+
val = None
387+
except KeyError:
388+
val = None
389+
390+
if val is None:
391+
return default
392+
elif post_func:
393+
return post_func(val, sp_entity_id=sp_entity_id, **kwargs)
394+
else:
395+
return val
396+
333397
def get_nameid_format(self, sp_entity_id):
334398
""" Get the NameIDFormat to used for the entity id
335399
:param: The SP entity ID
336400
:retur: The format
337401
"""
338-
try:
339-
form = self._restrictions[sp_entity_id]["nameid_format"]
340-
except KeyError:
341-
try:
342-
form = self._restrictions["default"]["nameid_format"]
343-
except KeyError:
344-
form = saml.NAMEID_FORMAT_TRANSIENT
345-
346-
return form
347-
402+
return self.get("nameid_format", sp_entity_id,
403+
saml.NAMEID_FORMAT_TRANSIENT)
404+
348405
def get_name_form(self, sp_entity_id):
349406
""" Get the NameFormat to used for the entity id
350407
:param: The SP entity ID
351408
:retur: The format
352409
"""
353-
form = NAME_FORMAT_URI
354-
355-
try:
356-
form = self._restrictions[sp_entity_id]["name_form"]
357-
except TypeError:
358-
pass
359-
except KeyError:
360-
try:
361-
form = self._restrictions["default"]["name_form"]
362-
except KeyError:
363-
pass
364-
365-
return form
366-
410+
411+
return self.get("name_format", sp_entity_id, NAME_FORMAT_URI)
412+
367413
def get_lifetime(self, sp_entity_id):
368414
""" The lifetime of the assertion
369415
:param sp_entity_id: The SP entity ID
370416
:param: lifetime as a dictionary
371417
"""
372418
# default is a hour
373-
spec = {"hours": 1}
374-
if not self._restrictions:
375-
return spec
376-
377-
try:
378-
spec = self._restrictions[sp_entity_id]["lifetime"]
379-
except KeyError:
380-
try:
381-
spec = self._restrictions["default"]["lifetime"]
382-
except KeyError:
383-
pass
384-
385-
return spec
386-
387-
def get_attribute_restriction(self, sp_entity_id):
419+
return self.get("lifetime", sp_entity_id, {"hours": 1})
420+
421+
def get_attribute_restrictions(self, sp_entity_id):
388422
""" Return the attribute restriction for SP that want the information
389423
390424
:param sp_entity_id: The SP entity ID
391425
:return: The restrictions
392426
"""
393-
394-
if not self._restrictions:
395-
return None
396-
397-
try:
398-
try:
399-
restrictions = self._restrictions[sp_entity_id][
400-
"attribute_restrictions"]
401-
except KeyError:
402-
try:
403-
restrictions = self._restrictions["default"][
404-
"attribute_restrictions"]
405-
except KeyError:
406-
restrictions = None
407-
except KeyError:
408-
restrictions = None
409-
410-
return restrictions
427+
428+
return self.get("attribute_restrictions", sp_entity_id)
411429

412430
def entity_category_attributes(self, ec):
413431
if not self._restrictions:
@@ -421,59 +439,18 @@ def entity_category_attributes(self, ec):
421439
pass
422440
return []
423441

424-
def get_entity_categories_restriction(self, sp_entity_id, mds):
442+
def get_entity_categories(self, sp_entity_id, mds):
425443
"""
426444
427445
:param sp_entity_id:
428446
:param mds: MetadataStore instance
429-
:return: A dictionary with restrictionsmetat
447+
:return: A dictionary with restrictions
430448
"""
431-
if not self._restrictions:
432-
return None
433-
434-
restrictions = {}
435-
ec_maps = []
436-
try:
437-
try:
438-
ec_maps = self._restrictions[sp_entity_id]["entity_categories"]
439-
except KeyError:
440-
try:
441-
ec_maps = self._restrictions["default"]["entity_categories"]
442-
except KeyError:
443-
pass
444-
except KeyError:
445-
pass
446449

447-
if ec_maps:
448-
if mds:
449-
try:
450-
ecs = mds.entity_categories(sp_entity_id)
451-
except KeyError:
452-
for ec_map in ec_maps:
453-
for attr in ec_map[""]:
454-
restrictions[attr] = None
455-
else:
456-
for ec_map in ec_maps:
457-
for key, val in ec_map.items():
458-
if key == "": # always released
459-
attrs = val
460-
elif isinstance(key, tuple):
461-
attrs = val
462-
for _key in key:
463-
try:
464-
assert _key in ecs
465-
except AssertionError:
466-
attrs = []
467-
break
468-
elif key in ecs:
469-
attrs = val
470-
else:
471-
attrs = []
450+
kwargs = {"mds": mds}
472451

473-
for attr in attrs:
474-
restrictions[attr] = None
475-
476-
return restrictions
452+
return self.get("entity_categories", sp_entity_id, default={},
453+
post_func=post_entity_categories, **kwargs)
477454

478455
def not_on_or_after(self, sp_entity_id):
479456
""" When the assertion stops being valid, should not be
@@ -500,10 +477,9 @@ def filter(self, ava, sp_entity_id, mdstore, required=None, optional=None):
500477
:return: A possibly modified AVA
501478
"""
502479

503-
_rest = self.get_attribute_restriction(sp_entity_id)
480+
_rest = self.get_attribute_restrictions(sp_entity_id)
504481
if _rest is None:
505-
_rest = self.get_entity_categories_restriction(sp_entity_id,
506-
mdstore)
482+
_rest = self.get_entity_categories(sp_entity_id, mdstore)
507483
logger.debug("filter based on: %s" % _rest)
508484
ava = filter_attribute_value_assertions(ava, _rest)
509485

@@ -543,6 +519,17 @@ def conditions(self, sp_entity_id):
543519
audience=[factory(saml.Audience,
544520
text=sp_entity_id)])])
545521

522+
def get_sign(self, sp_entity_id):
523+
"""
524+
Possible choices
525+
"sign": ["response", "assertion", "on_demand"]
526+
527+
:param sp_entity_id:
528+
:return:
529+
"""
530+
531+
return self.get("sign", sp_entity_id, [])
532+
546533

547534
class EntityCategories(object):
548535
pass

src/saml2/extension/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
# metadata extensions mainly
22
__author__ = 'rolandh'
3-
__all__ = ["dri", "mdrpi", "mdui", "shibmd", "idpdisc"]
3+
__all__ = ["dri", "mdrpi", "mdui", "shibmd", "idpdisc", 'algsupport',
4+
'mdattr', 'ui']

0 commit comments

Comments
 (0)