@@ -270,6 +270,39 @@ def restriction_from_attribute_spec(attributes):
270
270
return restr
271
271
272
272
273
+ def post_entity_categories (maps , ** kwargs ):
274
+ restrictions = {}
275
+ if kwargs ["mds" ]:
276
+ try :
277
+ ecs = kwargs ["mds" ].entity_categories (kwargs ["sp_entity_id" ])
278
+ except KeyError :
279
+ for ec_map in maps :
280
+ for attr in ec_map ["" ]:
281
+ restrictions [attr ] = None
282
+ else :
283
+ for ec_map in maps :
284
+ for key , val in ec_map .items ():
285
+ if key == "" : # always released
286
+ attrs = val
287
+ elif isinstance (key , tuple ):
288
+ attrs = val
289
+ for _key in key :
290
+ try :
291
+ assert _key in ecs
292
+ except AssertionError :
293
+ attrs = []
294
+ break
295
+ elif key in ecs :
296
+ attrs = val
297
+ else :
298
+ attrs = []
299
+
300
+ for attr in attrs :
301
+ restrictions [attr ] = None
302
+
303
+ return restrictions
304
+
305
+
273
306
class Policy (object ):
274
307
""" handles restrictions on assertions """
275
308
@@ -329,85 +362,70 @@ def compile(self, restrictions):
329
362
logger .debug ("policy restrictions: %s" % self ._restrictions )
330
363
331
364
return self ._restrictions
332
-
365
+
366
+ def get (self , attribute , sp_entity_id , default = None , post_func = None ,
367
+ ** kwargs ):
368
+ """
369
+
370
+ :param attribute:
371
+ :param sp_entity_id:
372
+ :param default:
373
+ :param post_func:
374
+ :return:
375
+ """
376
+ if not self ._restrictions :
377
+ return default
378
+
379
+ try :
380
+ try :
381
+ val = self ._restrictions [sp_entity_id ][attribute ]
382
+ except KeyError :
383
+ try :
384
+ val = self ._restrictions ["default" ][attribute ]
385
+ except KeyError :
386
+ val = None
387
+ except KeyError :
388
+ val = None
389
+
390
+ if val is None :
391
+ return default
392
+ elif post_func :
393
+ return post_func (val , sp_entity_id = sp_entity_id , ** kwargs )
394
+ else :
395
+ return val
396
+
333
397
def get_nameid_format (self , sp_entity_id ):
334
398
""" Get the NameIDFormat to used for the entity id
335
399
:param: The SP entity ID
336
400
:retur: The format
337
401
"""
338
- try :
339
- form = self ._restrictions [sp_entity_id ]["nameid_format" ]
340
- except KeyError :
341
- try :
342
- form = self ._restrictions ["default" ]["nameid_format" ]
343
- except KeyError :
344
- form = saml .NAMEID_FORMAT_TRANSIENT
345
-
346
- return form
347
-
402
+ return self .get ("nameid_format" , sp_entity_id ,
403
+ saml .NAMEID_FORMAT_TRANSIENT )
404
+
348
405
def get_name_form (self , sp_entity_id ):
349
406
""" Get the NameFormat to used for the entity id
350
407
:param: The SP entity ID
351
408
:retur: The format
352
409
"""
353
- form = NAME_FORMAT_URI
354
-
355
- try :
356
- form = self ._restrictions [sp_entity_id ]["name_form" ]
357
- except TypeError :
358
- pass
359
- except KeyError :
360
- try :
361
- form = self ._restrictions ["default" ]["name_form" ]
362
- except KeyError :
363
- pass
364
-
365
- return form
366
-
410
+
411
+ return self .get ("name_format" , sp_entity_id , NAME_FORMAT_URI )
412
+
367
413
def get_lifetime (self , sp_entity_id ):
368
414
""" The lifetime of the assertion
369
415
:param sp_entity_id: The SP entity ID
370
416
:param: lifetime as a dictionary
371
417
"""
372
418
# default is a hour
373
- spec = {"hours" : 1 }
374
- if not self ._restrictions :
375
- return spec
376
-
377
- try :
378
- spec = self ._restrictions [sp_entity_id ]["lifetime" ]
379
- except KeyError :
380
- try :
381
- spec = self ._restrictions ["default" ]["lifetime" ]
382
- except KeyError :
383
- pass
384
-
385
- return spec
386
-
387
- def get_attribute_restriction (self , sp_entity_id ):
419
+ return self .get ("lifetime" , sp_entity_id , {"hours" : 1 })
420
+
421
+ def get_attribute_restrictions (self , sp_entity_id ):
388
422
""" Return the attribute restriction for SP that want the information
389
423
390
424
:param sp_entity_id: The SP entity ID
391
425
:return: The restrictions
392
426
"""
393
-
394
- if not self ._restrictions :
395
- return None
396
-
397
- try :
398
- try :
399
- restrictions = self ._restrictions [sp_entity_id ][
400
- "attribute_restrictions" ]
401
- except KeyError :
402
- try :
403
- restrictions = self ._restrictions ["default" ][
404
- "attribute_restrictions" ]
405
- except KeyError :
406
- restrictions = None
407
- except KeyError :
408
- restrictions = None
409
-
410
- return restrictions
427
+
428
+ return self .get ("attribute_restrictions" , sp_entity_id )
411
429
412
430
def entity_category_attributes (self , ec ):
413
431
if not self ._restrictions :
@@ -421,59 +439,18 @@ def entity_category_attributes(self, ec):
421
439
pass
422
440
return []
423
441
424
- def get_entity_categories_restriction (self , sp_entity_id , mds ):
442
+ def get_entity_categories (self , sp_entity_id , mds ):
425
443
"""
426
444
427
445
:param sp_entity_id:
428
446
:param mds: MetadataStore instance
429
- :return: A dictionary with restrictionsmetat
447
+ :return: A dictionary with restrictions
430
448
"""
431
- if not self ._restrictions :
432
- return None
433
-
434
- restrictions = {}
435
- ec_maps = []
436
- try :
437
- try :
438
- ec_maps = self ._restrictions [sp_entity_id ]["entity_categories" ]
439
- except KeyError :
440
- try :
441
- ec_maps = self ._restrictions ["default" ]["entity_categories" ]
442
- except KeyError :
443
- pass
444
- except KeyError :
445
- pass
446
449
447
- if ec_maps :
448
- if mds :
449
- try :
450
- ecs = mds .entity_categories (sp_entity_id )
451
- except KeyError :
452
- for ec_map in ec_maps :
453
- for attr in ec_map ["" ]:
454
- restrictions [attr ] = None
455
- else :
456
- for ec_map in ec_maps :
457
- for key , val in ec_map .items ():
458
- if key == "" : # always released
459
- attrs = val
460
- elif isinstance (key , tuple ):
461
- attrs = val
462
- for _key in key :
463
- try :
464
- assert _key in ecs
465
- except AssertionError :
466
- attrs = []
467
- break
468
- elif key in ecs :
469
- attrs = val
470
- else :
471
- attrs = []
450
+ kwargs = {"mds" : mds }
472
451
473
- for attr in attrs :
474
- restrictions [attr ] = None
475
-
476
- return restrictions
452
+ return self .get ("entity_categories" , sp_entity_id , default = {},
453
+ post_func = post_entity_categories , ** kwargs )
477
454
478
455
def not_on_or_after (self , sp_entity_id ):
479
456
""" When the assertion stops being valid, should not be
@@ -500,10 +477,9 @@ def filter(self, ava, sp_entity_id, mdstore, required=None, optional=None):
500
477
:return: A possibly modified AVA
501
478
"""
502
479
503
- _rest = self .get_attribute_restriction (sp_entity_id )
480
+ _rest = self .get_attribute_restrictions (sp_entity_id )
504
481
if _rest is None :
505
- _rest = self .get_entity_categories_restriction (sp_entity_id ,
506
- mdstore )
482
+ _rest = self .get_entity_categories (sp_entity_id , mdstore )
507
483
logger .debug ("filter based on: %s" % _rest )
508
484
ava = filter_attribute_value_assertions (ava , _rest )
509
485
@@ -543,6 +519,17 @@ def conditions(self, sp_entity_id):
543
519
audience = [factory (saml .Audience ,
544
520
text = sp_entity_id )])])
545
521
522
+ def get_sign (self , sp_entity_id ):
523
+ """
524
+ Possible choices
525
+ "sign": ["response", "assertion", "on_demand"]
526
+
527
+ :param sp_entity_id:
528
+ :return:
529
+ """
530
+
531
+ return self .get ("sign" , sp_entity_id , [])
532
+
546
533
547
534
class EntityCategories (object ):
548
535
pass
0 commit comments