Skip to content

Commit 7044da3

Browse files
author
Rebecka Gulliksson
committed
Filter optional attributes in the exact same way as required attributes.
1 parent 7a7b02d commit 7044da3

File tree

2 files changed

+49
-28
lines changed

2 files changed

+49
-28
lines changed

src/saml2/assertion.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -78,52 +78,53 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None,
7878
:return: The modified attribute value assertion
7979
"""
8080

81-
def _attr_name(attr):
82-
"""Get the friendly name of an attribute name"""
81+
def _match_attr_name(attr, ava):
8382
try:
84-
return attr["friendly_name"]
83+
friendly_name = attr["friendly_name"]
8584
except KeyError:
86-
return get_local_name(acs, attr["name"], attr["name_format"])
85+
friendly_name = get_local_name(acs, attr["name"], attr["name_format"])
86+
87+
_fn = _match(friendly_name, ava)
88+
if not _fn: # In the unlikely case that someone has provided us with URIs as attribute names
89+
_fn = _match(attr["name"], ava)
90+
91+
return _fn
92+
93+
def _apply_attr_value_restrictions(attr, res, must=False):
94+
try:
95+
values = [av["text"] for av in attr["attribute_value"]]
96+
except KeyError:
97+
values = []
98+
99+
try:
100+
res[_fn].extend(_filter_values(ava[_fn], values))
101+
except KeyError:
102+
res[_fn] = _filter_values(ava[_fn], values)
103+
104+
return _filter_values(ava[_fn], values, must)
87105

88106
res = {}
89107

90108
if required is None:
91109
required = []
92110

93111
for attr in required:
94-
_name = _attr_name(attr)
95-
_fn = _match(_name, ava)
96-
if not _fn: # In the unlikely case that someone has provided us
97-
# with URIs as attribute names
98-
_fn = _match(attr["name"], ava)
112+
_fn = _match_attr_name(attr, ava)
99113

100114
if _fn:
101-
try:
102-
values = [av["text"] for av in attr["attribute_value"]]
103-
except KeyError:
104-
values = []
105-
res[_fn] = _filter_values(ava[_fn], values, True)
106-
continue
115+
_apply_attr_value_restrictions(attr, res, True)
107116
elif fail_on_unfulfilled_requirements:
108117
desc = "Required attribute missing: '%s' (%s)" % (attr["name"],
109-
_name)
118+
_fn)
110119
raise MissingValue(desc)
111120

112121
if optional is None:
113122
optional = []
114123

115124
for attr in optional:
116-
_name = _attr_name(attr)
117-
_fn = _match(_name, ava)
125+
_fn = _match_attr_name(attr, ava)
118126
if _fn:
119-
try:
120-
values = [av["text"] for av in attr["attribute_value"]]
121-
except KeyError:
122-
values = []
123-
try:
124-
res[_fn].extend(_filter_values(ava[_fn], values))
125-
except KeyError:
126-
res[_fn] = _filter_values(ava[_fn], values)
127+
_apply_attr_value_restrictions(attr, res, False)
127128

128129
return res
129130

tests/test_20_assertion.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# coding=utf-8
2+
import pytest
3+
24
from saml2.authn_context import pword
35
from saml2.mdie import to_dict
46
from saml2 import md, assertion
@@ -81,16 +83,34 @@ def test_filter_on_attributes_1():
8183
def test_filter_on_attributes_without_friendly_name():
8284
ava = {"eduPersonTargetedID": "[email protected]", "eduPersonAffiliation": "test",
8385
"extra": "foo"}
84-
eptid = to_dict(Attribute(name="urn:oid:1.3.6.1.4.1.5923.1.1.1.10", name_format=NAME_FORMAT_URI), ONTS)
86+
eptid = to_dict(
87+
Attribute(name="urn:oid:1.3.6.1.4.1.5923.1.1.1.10", name_format=NAME_FORMAT_URI), ONTS)
8588
ep_affiliation = to_dict(
86-
Attribute(name="urn:oid:1.3.6.1.4.1.5923.1.1.1.1", name_format=NAME_FORMAT_URI), ONTS)
89+
Attribute(name="urn:oid:1.3.6.1.4.1.5923.1.1.1.1", name_format=NAME_FORMAT_URI), ONTS)
8790

8891
restricted_ava = filter_on_attributes(ava, required=[eptid], optional=[ep_affiliation],
8992
acs=ac_factory())
9093
assert restricted_ava == {"eduPersonTargetedID": "[email protected]",
9194
"eduPersonAffiliation": "test"}
9295

9396

97+
def test_filter_on_attributes_with_missing_required_attribute():
98+
ava = {"extra": "foo"}
99+
eptid = to_dict(Attribute(
100+
friendly_name="eduPersonTargetedID", name="urn:oid:1.3.6.1.4.1.5923.1.1.1.10",
101+
name_format=NAME_FORMAT_URI), ONTS)
102+
with pytest.raises(MissingValue):
103+
filter_on_attributes(ava, required=[eptid])
104+
105+
106+
def test_filter_on_attributes_with_missing_optional_attribute():
107+
ava = {"extra": "foo"}
108+
eptid = to_dict(Attribute(
109+
friendly_name="eduPersonTargetedID", name="urn:oid:1.3.6.1.4.1.5923.1.1.1.10",
110+
name_format=NAME_FORMAT_URI), ONTS)
111+
assert filter_on_attributes(ava, optional=[eptid]) == {}
112+
113+
94114
# ----------------------------------------------------------------------
95115

96116
def test_lifetime_1():

0 commit comments

Comments
 (0)