Skip to content

Commit 8474950

Browse files
Merge pull request #545 from c00kiemon5ter/refactor-av-set-text
Refactor AttributeValueBase set_text method
2 parents e2675f7 + 23b3bc5 commit 8474950

File tree

2 files changed

+175
-119
lines changed

2 files changed

+175
-119
lines changed

src/saml2/saml.py

Lines changed: 125 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -83,51 +83,14 @@
8383
SCM_SENDER_VOUCHES = "urn:oasis:names:tc:SAML:2.0:cm:sender-vouches"
8484
SCM_BEARER = "urn:oasis:names:tc:SAML:2.0:cm:bearer"
8585

86-
XSD = "xs:"
86+
XSD = "xs"
8787
NS_SOAP_ENC = "http://schemas.xmlsoap.org/soap/encoding/"
8888

8989

9090
_b64_decode_fn = getattr(base64, 'decodebytes', base64.decodestring)
9191
_b64_encode_fn = getattr(base64, 'encodebytes', base64.encodestring)
9292

9393

94-
def _decode_attribute_value(typ, text):
95-
if typ == XSD + "string":
96-
return text or ""
97-
if typ == XSD + "integer" or typ == XSD + "int":
98-
return str(int(text))
99-
if typ == XSD + "float" or typ == XSD + "double":
100-
return str(float(text))
101-
if typ == XSD + "boolean":
102-
return str(text.lower() == "true")
103-
if typ == XSD + "base64Binary":
104-
return _b64_decode_fn(text)
105-
raise ValueError("type %s not supported" % type)
106-
107-
108-
def _verify_value_type(typ, val):
109-
# print("verify value type: %s, %s" % (typ, val))
110-
if typ == XSD + "string":
111-
try:
112-
return str(val)
113-
except UnicodeEncodeError:
114-
if six.PY2:
115-
return unicode(val)
116-
else:
117-
return val.decode('utf8')
118-
if typ == XSD + "integer" or typ == XSD + "int":
119-
return int(val)
120-
if typ == XSD + "float" or typ == XSD + "double":
121-
return float(val)
122-
if typ == XSD + "boolean":
123-
if val.lower() == "true" or val.lower() == "false":
124-
pass
125-
else:
126-
raise ValueError("Faulty boolean value")
127-
if typ == XSD + "base64Binary":
128-
return _b64_decode_fn(val.encode())
129-
130-
13194
class AttributeValueBase(SamlBase):
13295
def __init__(self,
13396
text=None,
@@ -166,7 +129,7 @@ def verify(self):
166129
def set_type(self, typ):
167130
try:
168131
del self.extension_attributes[XSI_NIL]
169-
except KeyError:
132+
except (AttributeError, KeyError):
170133
pass
171134

172135
try:
@@ -199,66 +162,129 @@ def clear_type(self):
199162
except KeyError:
200163
pass
201164

202-
def set_text(self, val, base64encode=False):
203-
typ = self.get_type()
204-
if base64encode:
205-
val = _b64_encode_fn(val)
206-
self.set_type("xs:base64Binary")
207-
else:
208-
if isinstance(val, six.binary_type):
209-
val = val.decode()
210-
if isinstance(val, six.string_types):
211-
if not typ:
212-
self.set_type("xs:string")
213-
else:
214-
try:
215-
assert typ == "xs:string"
216-
except AssertionError:
217-
if typ == "xs:int":
218-
_ = int(val)
219-
elif typ == "xs:boolean":
220-
if val.lower() not in ["true", "false"]:
221-
raise ValueError("Not a boolean")
222-
elif typ == "xs:float":
223-
_ = float(val)
224-
elif typ == "xs:base64Binary":
225-
pass
226-
else:
227-
raise ValueError("Type and value doesn't match")
228-
elif isinstance(val, bool):
229-
if val:
230-
val = "true"
231-
else:
232-
val = "false"
233-
if not typ:
234-
self.set_type("xs:boolean")
235-
else:
236-
assert typ == "xs:boolean"
237-
elif isinstance(val, int):
238-
val = str(val)
239-
if not typ:
240-
self.set_type("xs:integer")
241-
else:
242-
assert typ == "xs:integer"
243-
elif isinstance(val, float):
244-
val = str(val)
245-
if not typ:
246-
self.set_type("xs:float")
247-
else:
248-
assert typ == "xs:float"
249-
elif not val:
250-
try:
251-
self.extension_attributes[XSI_TYPE] = typ
252-
except AttributeError:
253-
self._extatt[XSI_TYPE] = typ
254-
val = ""
255-
else:
256-
if typ == "xs:anyType":
257-
pass
258-
else:
259-
raise ValueError
260-
261-
SamlBase.__setattr__(self, "text", val)
165+
def set_text(self, value, base64encode=False):
166+
def _wrong_type_value(xsd, value):
167+
msg = _str('Type and value do not match: {xsd}:{type}:{value}')
168+
msg = msg.format(xsd=xsd, type=type(value), value=value)
169+
raise ValueError(msg)
170+
171+
# only work with six.string_types
172+
_str = unicode if six.PY2 else str
173+
if isinstance(value, six.binary_type):
174+
value = value.decode()
175+
176+
type_to_xsd = {
177+
_str: 'string',
178+
int: 'integer',
179+
float: 'float',
180+
bool: 'boolean',
181+
type(None): '',
182+
}
183+
184+
# entries of xsd-types each declaring:
185+
# - a corresponding python type
186+
# - a function to turn a string into that type
187+
# - a function to turn that type into a text-value
188+
xsd_types_props = {
189+
'string': {
190+
'type': _str,
191+
'to_type': _str,
192+
'to_text': _str,
193+
},
194+
'integer': {
195+
'type': int,
196+
'to_type': int,
197+
'to_text': _str,
198+
},
199+
'short': {
200+
'type': int,
201+
'to_type': int,
202+
'to_text': _str,
203+
},
204+
'int': {
205+
'type': int,
206+
'to_type': int,
207+
'to_text': _str,
208+
},
209+
'long': {
210+
'type': int,
211+
'to_type': int,
212+
'to_text': _str,
213+
},
214+
'float': {
215+
'type': float,
216+
'to_type': float,
217+
'to_text': _str,
218+
},
219+
'double': {
220+
'type': float,
221+
'to_type': float,
222+
'to_text': _str,
223+
},
224+
'boolean': {
225+
'type': bool,
226+
'to_type': lambda x: {
227+
'true': True,
228+
'false': False,
229+
}[_str(x).lower()],
230+
'to_text': lambda x: _str(x).lower(),
231+
},
232+
'base64Binary': {
233+
'type': _str,
234+
'to_type': _str,
235+
'to_text': lambda x:
236+
_b64_encode_fn(x.encode())
237+
if base64encode
238+
else x,
239+
},
240+
'anyType': {
241+
'type': type(value),
242+
'to_type': lambda x: x,
243+
'to_text': lambda x: x,
244+
},
245+
'': {
246+
'type': type(None),
247+
'to_type': lambda x: None,
248+
'to_text': lambda x: '',
249+
},
250+
}
251+
252+
xsd_string = (
253+
'base64Binary' if base64encode
254+
else self.get_type()
255+
or type_to_xsd.get(type(value)))
256+
257+
xsd_ns, xsd_type = (
258+
['', type(None)] if xsd_string is None
259+
else ['', ''] if xsd_string is ''
260+
else [
261+
XSD if xsd_string in xsd_types_props else '',
262+
xsd_string
263+
] if ':' not in xsd_string
264+
else xsd_string.split(':', 1))
265+
266+
xsd_type_props = xsd_types_props.get(xsd_type, {})
267+
valid_type = xsd_type_props.get('type', type(None))
268+
to_type = xsd_type_props.get('to_type', str)
269+
to_text = xsd_type_props.get('to_text', str)
270+
271+
# cast to correct type before type-checking
272+
if type(value) is _str and valid_type is not _str:
273+
try:
274+
value = to_type(value)
275+
except (TypeError, ValueError, KeyError) as e:
276+
# the cast failed
277+
_wrong_type_value(xsd=xsd_type, value=value)
278+
279+
if type(value) is not valid_type:
280+
_wrong_type_value(xsd=xsd_type, value=value)
281+
282+
text = to_text(value)
283+
self.set_type(
284+
'{ns}:{type}'.format(ns=xsd_ns, type=xsd_type) if xsd_ns
285+
else xsd_type if xsd_type
286+
else '')
287+
SamlBase.__setattr__(self, 'text', text)
262288
return self
263289

264290
def harvest_element_tree(self, tree):
@@ -274,11 +300,6 @@ def harvest_element_tree(self, tree):
274300
self.set_text(tree.text)
275301
if XSI_NIL in self.extension_attributes:
276302
del self.extension_attributes[XSI_NIL]
277-
try:
278-
typ = self.extension_attributes[XSI_TYPE]
279-
_verify_value_type(typ, getattr(self, "text"))
280-
except KeyError:
281-
pass
282303

283304

284305
class BaseIDAbstractType_(SamlBase):

tests/test_02_saml.py

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,15 @@ def test_loadd(self):
4343
print(ava)
4444
ee = saml2.ExtensionElement("")
4545

46-
raises(KeyError, "ee.loadd(ava)")
46+
with raises(KeyError):
47+
ee.loadd(ava)
4748

4849
ava["tag"] = "foo"
4950
del ava["namespace"]
5051

5152
ee = saml2.ExtensionElement("")
52-
raises(KeyError, "ee.loadd(ava)")
53+
with raises(KeyError):
54+
ee.loadd(ava)
5355

5456
def test_find_children(self):
5557
ava = {
@@ -211,8 +213,8 @@ def test_make_vals_str(self):
211213
def test_make_vals_multi_dict(self):
212214
ava = ["foo", "bar", "lions", "saints"]
213215

214-
raises(Exception,
215-
"saml2.make_vals(ava, AttributeValue, Attribute(), part=True)")
216+
with raises(Exception):
217+
saml2.make_vals(ava, AttributeValue, Attribute(), part=True)
216218

217219
attr = Attribute()
218220
saml2.make_vals(ava, AttributeValue, attr, prop="attribute_value")
@@ -230,22 +232,55 @@ def test_to_string_nspair(self):
230232
assert "saml:AttributeValue" in nsstr
231233
assert "saml:AttributeValue" not in txt
232234

233-
def test_set_text(self):
235+
def test_set_text_empty(self):
236+
av = AttributeValue()
237+
av.set_text(None)
238+
assert av.get_type() == ''
239+
assert av.text == ''
240+
241+
def test_set_text_value(self):
242+
value = 123
243+
av = AttributeValue(value)
244+
assert av.get_type() == 'xs:integer'
245+
assert av.text == str(value)
246+
247+
def test_set_text_update_same_type(self):
234248
av = AttributeValue()
235249
av.set_text(True)
236-
assert av.text == "true"
250+
assert av.get_type() == 'xs:boolean'
251+
assert av.text == 'true'
237252
av.set_text(False)
238-
assert av.text == "false"
239-
# can't change value to another type
240-
raises(AssertionError, "av.set_text(491)")
253+
assert av.get_type() == 'xs:boolean'
254+
assert av.text == 'false'
241255

256+
def test_set_text_cannot_change_value_type(self):
242257
av = AttributeValue()
243-
av.set_text(None)
244-
assert av.text == ""
245-
258+
av.set_text(True)
259+
assert av.get_type() == 'xs:boolean'
260+
assert av.text == 'true'
261+
with raises(ValueError):
262+
av.set_text(123)
263+
assert av.get_type() == 'xs:boolean'
264+
assert av.text == 'true'
265+
266+
def test_set_xs_type_anytype_unchanged_value(self):
267+
av = AttributeValue()
268+
av.set_type('xs:anyType')
269+
for value in [
270+
[1, 2, 3],
271+
{'key': 'value'},
272+
True,
273+
123,
274+
]:
275+
av.set_text(value)
276+
# the value is unchanged
277+
assert av.text == value
278+
279+
def test_set_invalid_type_before_text(self):
246280
av = AttributeValue()
247-
av.set_type('invalid')
248-
raises(ValueError, "av.set_text('free text')")
281+
av.set_type('invalid-type')
282+
with raises(ValueError):
283+
av.set_text('foobar')
249284

250285
def test_make_vals_div(self):
251286
foo = saml2.make_vals(666, AttributeValue, part=True)
@@ -621,7 +656,7 @@ def testUsingTestData(self):
621656
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
622657
NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic"
623658
Name="FirstName">
624-
<AttributeValue
659+
<AttributeValue
625660
xsi:type="xs:base64Binary">VU5JTkVUVA==</AttributeValue>
626661
</Attribute>"""
627662

0 commit comments

Comments
 (0)