Skip to content

Commit a46e175

Browse files
author
Roland Hedberg
committed
Added logging to module
1 parent 9b22450 commit a46e175

File tree

1 file changed

+59
-26
lines changed

1 file changed

+59
-26
lines changed

src/saml2/attribute_converter.py

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
from saml2 import saml, extension_elements_to_elements, SAMLError
2424
from saml2.saml import NAME_FORMAT_URI, NAME_FORMAT_UNSPECIFIED
2525

26+
import logging
27+
logger = logging.getLogger(__name__)
28+
2629

2730
class UnknownNameFormat(SAMLError):
2831
pass
@@ -98,26 +101,30 @@ def ac_factory_II(path):
98101
return ac_factory(path)
99102

100103

101-
def ava_fro(acs, statement):
102-
""" Translates attributes according to their name_formats into the local
103-
names.
104-
105-
:param acs: AttributeConverter instances
106-
:param statement: A SAML statement
107-
:return: A dictionary with attribute names replaced with local names.
108-
"""
109-
if not statement:
110-
return {}
111-
112-
acsdic = dict([(ac.name_format, ac) for ac in acs])
113-
acsdic[None] = acsdic[NAME_FORMAT_URI]
114-
return dict([acsdic[a.name_format].ava_from(a) for a in statement])
104+
# def ava_fro(acs, statement):
105+
# """ Translates attributes according to their name_formats into the local
106+
# names.
107+
#
108+
# :param acs: AttributeConverter instances
109+
# :param statement: A SAML statement
110+
# :return: A dictionary with attribute names replaced with local names.
111+
# """
112+
# if not statement:
113+
# return {}
114+
#
115+
# acsdic = dict([(ac.name_format, ac) for ac in acs])
116+
# acsdic[None] = acsdic[NAME_FORMAT_URI]
117+
# return dict([acsdic[a.name_format].ava_from(a) for a in statement])
115118

116119

117-
def to_local(acs, statement):
120+
def to_local(acs, statement, allow_unknown_attributes=False):
118121
""" Replaces the attribute names in a attribute value assertion with the
119122
equivalent name from a local name format.
120123
124+
:param acs: List of Attribute Converters
125+
:param statement: The Attribute Statement
126+
:param allow_unknown_attributes: If unknown attributes are allowed
127+
:return: A key,values dictionary
121128
"""
122129
if not acs:
123130
acs = [AttributeConverter()]
@@ -128,9 +135,26 @@ def to_local(acs, statement):
128135
ava = {}
129136
for attr in statement.attribute:
130137
try:
131-
key, val = acsd[attr.name_format].ava_from(attr)
138+
_func = acsd[attr.name_format].ava_from
132139
except KeyError:
133-
key, val = acs[0].lcd_ava_from(attr)
140+
if attr.name_format == NAME_FORMAT_UNSPECIFIED or \
141+
allow_unknown_attributes:
142+
_func = acs[0].lcd_ava_from
143+
else:
144+
logger.info("Unsupported attribute name format: %s" % (
145+
attr.name_format,))
146+
continue
147+
148+
try:
149+
key, val = _func(attr)
150+
except KeyError:
151+
if allow_unknown_attributes:
152+
key, val = acs[0].lcd_ava_from(attr)
153+
else:
154+
logger.info("Unknown attribute name: %s" % (attr,))
155+
continue
156+
except AttributeError:
157+
continue
134158

135159
try:
136160
ava[key].extend(val)
@@ -245,7 +269,7 @@ def lcd_ava_from(self, attribute):
245269
"""
246270
In nothing else works, this should
247271
248-
:param attribute:
272+
:param attribute: An Attribute Instance
249273
:return:
250274
"""
251275
try:
@@ -287,14 +311,19 @@ def fail_safe_fro(self, statement):
287311
result[name].append(value.text.strip())
288312
return result
289313

290-
def ava_from(self, attribute):
314+
def ava_from(self, attribute, allow_unknown=False):
291315
try:
292316
attr = self._fro[attribute.name.strip().lower()]
293-
except (AttributeError, KeyError):
294-
try:
295-
attr = attribute.friendly_name.strip().lower()
296-
except AttributeError:
297-
attr = attribute.name.strip().lower()
317+
except AttributeError:
318+
attr = attribute.friendly_name.strip().lower()
319+
except KeyError:
320+
if allow_unknown:
321+
try:
322+
attr = attribute.name.strip().lower()
323+
except AttributeError:
324+
attr = attribute.friendly_name.strip().lower()
325+
else:
326+
raise
298327

299328
val = []
300329
for value in attribute.attribute_value:
@@ -333,8 +362,12 @@ def fro(self, statement):
333362
attribute.name_format != self.name_format:
334363
continue
335364

336-
(key, val) = self.ava_from(attribute)
337-
result[key] = val
365+
try:
366+
(key, val) = self.ava_from(attribute)
367+
except (KeyError, AttributeError):
368+
pass
369+
else:
370+
result[key] = val
338371

339372
return result
340373

0 commit comments

Comments
 (0)