@@ -166,7 +166,7 @@ def verify(self):
166
166
def set_type (self , typ ):
167
167
try :
168
168
del self .extension_attributes [XSI_NIL ]
169
- except KeyError :
169
+ except ( AttributeError , KeyError ) :
170
170
pass
171
171
172
172
try :
@@ -199,66 +199,90 @@ def clear_type(self):
199
199
except KeyError :
200
200
pass
201
201
202
- def set_text (self , val , base64encode = False ):
203
- typ = self .get_type ()
204
- if base64encode :
205
- val = _b64_encode_fn (val )
206
- self .set_type ("xs:base64Binary" )
207
- else :
208
- if isinstance (val , six .binary_type ):
209
- val = val .decode ()
210
- if isinstance (val , six .string_types ):
211
- if not typ :
212
- self .set_type ("xs:string" )
213
- else :
214
- try :
215
- assert typ == "xs:string"
216
- except AssertionError :
217
- if typ == "xs:int" :
218
- _ = int (val )
219
- elif typ == "xs:boolean" :
220
- if val .lower () not in ["true" , "false" ]:
221
- raise ValueError ("Not a boolean" )
222
- elif typ == "xs:float" :
223
- _ = float (val )
224
- elif typ == "xs:base64Binary" :
225
- pass
226
- else :
227
- raise ValueError ("Type and value doesn't match" )
228
- elif isinstance (val , bool ):
229
- if val :
230
- val = "true"
231
- else :
232
- val = "false"
233
- if not typ :
234
- self .set_type ("xs:boolean" )
235
- else :
236
- assert typ == "xs:boolean"
237
- elif isinstance (val , int ):
238
- val = str (val )
239
- if not typ :
240
- self .set_type ("xs:integer" )
241
- else :
242
- assert typ == "xs:integer"
243
- elif isinstance (val , float ):
244
- val = str (val )
245
- if not typ :
246
- self .set_type ("xs:float" )
247
- else :
248
- assert typ == "xs:float"
249
- elif not val :
250
- try :
251
- self .extension_attributes [XSI_TYPE ] = typ
252
- except AttributeError :
253
- self ._extatt [XSI_TYPE ] = typ
254
- val = ""
255
- else :
256
- if typ == "xs:anyType" :
257
- pass
258
- else :
259
- raise ValueError
260
-
261
- SamlBase .__setattr__ (self , "text" , val )
202
+ def set_text (self , value , base64encode = False ):
203
+ _xs_type_from_type = {
204
+ str : 'xs:string' ,
205
+ int : 'xs:integer' ,
206
+ float : 'xs:float' ,
207
+ bool : 'xs:boolean' ,
208
+ type (None ): '' ,
209
+ }
210
+
211
+ _type_from_xs_type = {
212
+ 'xs:anyType' : str ,
213
+ 'xs:string' : str ,
214
+ 'xs:integer' : int ,
215
+ 'xs:short' : int ,
216
+ 'xs:int' : int ,
217
+ 'xs:long' : int ,
218
+ 'xs:float' : float ,
219
+ 'xs:double' : float ,
220
+ 'xs:boolean' : bool ,
221
+ 'xs:base64Binary' : str ,
222
+ '' : type (None ),
223
+ }
224
+
225
+ _typed_value_constructor_from_xs_type = {
226
+ 'xs:anyType' : str ,
227
+ 'xs:string' : str ,
228
+ 'xs:integer' : int ,
229
+ 'xs:short' : int ,
230
+ 'xs:int' : int ,
231
+ 'xs:long' : int ,
232
+ 'xs:float' : float ,
233
+ 'xs:double' : float ,
234
+ 'xs:boolean' : lambda x :
235
+ True if str (x ).lower () == 'true'
236
+ else False if str (x ).lower () == 'false'
237
+ else None ,
238
+ 'xs:base64Binary' : str ,
239
+ '' : lambda x : None ,
240
+ }
241
+
242
+ _text_constructor_from_xs_type = {
243
+ 'xs:anyType' : str ,
244
+ 'xs:string' : str ,
245
+ 'xs:integer' : str ,
246
+ 'xs:short' : str ,
247
+ 'xs:int' : str ,
248
+ 'xs:long' : str ,
249
+ 'xs:float' : str ,
250
+ 'xs:double' : str ,
251
+ 'xs:boolean' : lambda x : str (x ).lower (),
252
+ 'xs:base64Binary' : lambda x :
253
+ _b64_encode_fn (x .encode ())
254
+ if base64encode
255
+ else x ,
256
+ '' : lambda x : '' ,
257
+ }
258
+
259
+ if isinstance (value , six .binary_type ):
260
+ value = value .decode ()
261
+
262
+ xs_type = \
263
+ 'xs:base64Binary' \
264
+ if base64encode \
265
+ else self .get_type () \
266
+ or _xs_type_from_type .get (type (value ))
267
+
268
+ if xs_type is None :
269
+ msg_tpl = 'No corresponding xs-type for {type}:{value}'
270
+ msg = msg_tpl .format (type = type (value ), value = value )
271
+ raise ValueError (msg )
272
+
273
+ valid_type = _type_from_xs_type .get (xs_type , type (None ))
274
+ to_typed = _typed_value_constructor_from_xs_type .get (xs_type , str )
275
+ to_text = _text_constructor_from_xs_type .get (xs_type , str )
276
+
277
+ value = to_typed (value )
278
+ if type (value ) is not valid_type :
279
+ msg_tpl = 'Type and value do not match: {type}:{value}'
280
+ msg = msg_tpl .format (type = xs_type , value = value )
281
+ raise ValueError (msg )
282
+
283
+ text = to_text (value )
284
+ self .set_type (xs_type )
285
+ SamlBase .__setattr__ (self , 'text' , text )
262
286
return self
263
287
264
288
def harvest_element_tree (self , tree ):
0 commit comments