Skip to content

Commit f6a4702

Browse files
committed
Ordered way to find a local name of an attribute.
1 parent 3360ee2 commit f6a4702

File tree

2 files changed

+42
-21
lines changed

2 files changed

+42
-21
lines changed

src/saml2/assertion.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@
88

99
from saml2 import saml
1010
from saml2 import xmlenc
11-
from saml2.attribute_converter import from_local, get_local_name
11+
from saml2.attribute_converter import from_local, ac_factory
12+
from saml2.attribute_converter import get_local_name
1213
from saml2.s_utils import assertion_factory
1314
from saml2.s_utils import factory
14-
from saml2.s_utils import sid, MissingValue
15+
from saml2.s_utils import sid
16+
from saml2.s_utils import MissingValue
1517
from saml2.saml import NAME_FORMAT_URI
16-
from saml2.time_util import instant, in_a_while
18+
from saml2.time_util import instant
19+
from saml2.time_util import in_a_while
1720

1821
logger = logging.getLogger(__name__)
1922

@@ -78,15 +81,22 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None,
7881
"""
7982

8083
def _match_attr_name(attr, ava):
81-
82-
local_name = get_local_name(acs, attr["name"], attr["name_format"])
83-
if not local_name:
84-
try:
85-
local_name = attr["friendly_name"]
86-
except KeyError:
87-
pass
84+
local_name = None
85+
86+
for a in ['name_format', 'friendly_name']:
87+
_val = attr.get(a)
88+
if _val:
89+
if a == 'name_format':
90+
local_name = get_local_name(acs, attr['name'], _val)
91+
else:
92+
local_name = _val
93+
break
94+
95+
if local_name:
96+
_fn = _match(local_name, ava)
97+
else:
98+
_fn = None
8899

89-
_fn = _match(local_name, ava)
90100
if not _fn: # In the unlikely case that someone has provided us with
91101
# URIs as attribute names
92102
_fn = _match(attr["name"], ava)
@@ -117,8 +127,7 @@ def _apply_attr_value_restrictions(attr, res, must=False):
117127
if _fn:
118128
_apply_attr_value_restrictions(attr, res, True)
119129
elif fail_on_unfulfilled_requirements:
120-
desc = "Required attribute missing: '%s' (%s)" % (attr["name"],
121-
_fn)
130+
desc = "Required attribute missing: '%s'" % (attr["name"])
122131
raise MissingValue(desc)
123132

124133
if optional is None:
@@ -502,6 +511,9 @@ def filter(self, ava, sp_entity_id, mdstore, required=None, optional=None):
502511

503512
_ava = None
504513

514+
if not self.acs: # acs MUST have a value, fall back to default.
515+
self.acs = ac_factory()
516+
505517
_rest = self.get_entity_categories(sp_entity_id, mdstore, required)
506518
if _rest:
507519
_ava = filter_attribute_value_assertions(ava.copy(), _rest)

tests/test_20_assertion.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,17 @@ def test_filter_on_attributes_with_missing_optional_attribute():
130130
assert filter_on_attributes(ava, optional=[eptid], acs=ac_factory()) == {}
131131

132132

133+
def test_filter_on_attributes_with_missing_name_format():
134+
ava = {"eduPersonTargetedID": "[email protected]",
135+
"eduPersonAffiliation": "test",
136+
"extra": "foo"}
137+
eptid = to_dict(Attribute(friendly_name="eduPersonTargetedID",
138+
name="urn:myown:eptid",
139+
name_format=''), ONTS)
140+
ava = filter_on_attributes(ava, optional=[eptid], acs=ac_factory())
141+
assert ava['eduPersonTargetedID'] == "[email protected]"
142+
143+
133144
# ----------------------------------------------------------------------
134145

135146
def test_lifetime_1():
@@ -148,6 +159,7 @@ def test_lifetime_1():
148159
}}
149160

150161
r = Policy(conf)
162+
151163
assert r is not None
152164

153165
assert r.get_lifetime("urn:mace:umu.se:saml:roland:sp") == {"minutes": 5}
@@ -215,25 +227,22 @@ def test_ava_filter_2():
215227
"lifetime": {"minutes": 5},
216228
"attribute_restrictions": {
217229
"givenName": None,
218-
"surName": None,
230+
"sn": None,
219231
"mail": [".*@.*\.umu\.se"],
220232
}
221233
}}
222234

223235
policy = Policy(conf)
224236

225-
ava = {"givenName": "Derek",
226-
"surName": "Jeter",
227-
"mail": "[email protected]"}
237+
ava = {"givenName": "Derek", "sn": "Jeter", "mail": "[email protected]"}
228238

229239
# mail removed because it doesn't match the regular expression
230240
_ava = policy.filter(ava, 'urn:mace:umu.se:saml:roland:sp', None, [mail],
231241
[gn, sn])
232242

233-
assert _eq(sorted(list(_ava.keys())), ["givenName", "surName"])
243+
assert _eq(sorted(list(_ava.keys())), ["givenName", 'sn'])
234244

235-
ava = {"givenName": "Derek",
236-
"surName": "Jeter"}
245+
ava = {"givenName": "Derek", "sn": "Jeter"}
237246

238247
# it wasn't there to begin with
239248
try:
@@ -746,7 +755,7 @@ def test_req_opt():
746755
is_required="false"), ONTS)]
747756

748757
policy = Policy()
749-
ava = {'givenname': 'Roland', 'surname': 'Hedberg',
758+
ava = {'givenname': 'Roland', 'sn': 'Hedberg',
750759
'uid': 'rohe0002', 'edupersonaffiliation': 'staff'}
751760

752761
sp_entity_id = "urn:mace:example.com:saml:curt:sp"

0 commit comments

Comments
 (0)