|
8 | 8 |
|
9 | 9 | from saml2 import saml
|
10 | 10 | from saml2 import xmlenc
|
11 |
| -from saml2.attribute_converter import from_local, get_local_name |
| 11 | +from saml2.attribute_converter import from_local, ac_factory |
| 12 | +from saml2.attribute_converter import get_local_name |
12 | 13 | from saml2.s_utils import assertion_factory
|
13 | 14 | from saml2.s_utils import factory
|
14 |
| -from saml2.s_utils import sid, MissingValue |
| 15 | +from saml2.s_utils import sid |
| 16 | +from saml2.s_utils import MissingValue |
15 | 17 | from saml2.saml import NAME_FORMAT_URI
|
16 |
| -from saml2.time_util import instant, in_a_while |
| 18 | +from saml2.time_util import instant |
| 19 | +from saml2.time_util import in_a_while |
17 | 20 |
|
18 | 21 | logger = logging.getLogger(__name__)
|
19 | 22 |
|
@@ -78,15 +81,22 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None,
|
78 | 81 | """
|
79 | 82 |
|
80 | 83 | def _match_attr_name(attr, ava):
|
81 |
| - |
82 |
| - local_name = get_local_name(acs, attr["name"], attr["name_format"]) |
83 |
| - if not local_name: |
84 |
| - try: |
85 |
| - local_name = attr["friendly_name"] |
86 |
| - except KeyError: |
87 |
| - pass |
| 84 | + local_name = None |
| 85 | + |
| 86 | + for a in ['name_format', 'friendly_name']: |
| 87 | + _val = attr.get(a) |
| 88 | + if _val: |
| 89 | + if a == 'name_format': |
| 90 | + local_name = get_local_name(acs, attr['name'], _val) |
| 91 | + else: |
| 92 | + local_name = _val |
| 93 | + break |
| 94 | + |
| 95 | + if local_name: |
| 96 | + _fn = _match(local_name, ava) |
| 97 | + else: |
| 98 | + _fn = None |
88 | 99 |
|
89 |
| - _fn = _match(local_name, ava) |
90 | 100 | if not _fn: # In the unlikely case that someone has provided us with
|
91 | 101 | # URIs as attribute names
|
92 | 102 | _fn = _match(attr["name"], ava)
|
@@ -117,8 +127,7 @@ def _apply_attr_value_restrictions(attr, res, must=False):
|
117 | 127 | if _fn:
|
118 | 128 | _apply_attr_value_restrictions(attr, res, True)
|
119 | 129 | elif fail_on_unfulfilled_requirements:
|
120 |
| - desc = "Required attribute missing: '%s' (%s)" % (attr["name"], |
121 |
| - _fn) |
| 130 | + desc = "Required attribute missing: '%s'" % (attr["name"]) |
122 | 131 | raise MissingValue(desc)
|
123 | 132 |
|
124 | 133 | if optional is None:
|
@@ -502,6 +511,9 @@ def filter(self, ava, sp_entity_id, mdstore, required=None, optional=None):
|
502 | 511 |
|
503 | 512 | _ava = None
|
504 | 513 |
|
| 514 | + if not self.acs: # acs MUST have a value, fall back to default. |
| 515 | + self.acs = ac_factory() |
| 516 | + |
505 | 517 | _rest = self.get_entity_categories(sp_entity_id, mdstore, required)
|
506 | 518 | if _rest:
|
507 | 519 | _ava = filter_attribute_value_assertions(ava.copy(), _rest)
|
|
0 commit comments