@@ -320,37 +320,54 @@ def post_entity_categories(maps, **kwargs):
320
320
class Policy (object ):
321
321
""" handles restrictions on assertions """
322
322
323
- def __init__ (self , restrictions = None ):
324
- if restrictions :
325
- self .compile (restrictions )
326
- else :
327
- self ._restrictions = None
323
+ def __init__ (self , restrictions = None , config = None ):
324
+ self ._config = config
325
+ self ._restrictions = self .setup_restrictions (restrictions )
326
+ logger .debug ("policy restrictions: %s" , self ._restrictions )
328
327
self .acs = []
329
328
330
- def compile (self , restrictions ):
329
+ def setup_restrictions (self , restrictions = None ):
330
+ if restrictions is None :
331
+ return None
332
+
333
+ restrictions = copy .deepcopy (restrictions )
334
+ # TODO: Split policy config in service_providers and registration_authorities
335
+ # "policy": {
336
+ # "service_providers": {
337
+ # "default": ...,
338
+ # "urn:mace:example.com:saml:roland:sp": ...,
339
+ # },
340
+ # "registration_authorities": {
341
+ # "default": ...,
342
+ # "http://www.swamid.se": ...,
343
+ # },
344
+ # },
345
+ registration_authorities = restrictions .pop ('registration_authorities' , None )
346
+ restrictions = self .compile (restrictions )
347
+ if registration_authorities :
348
+ restrictions ['registration_authorities' ] = self .compile (registration_authorities )
349
+ return restrictions
350
+
351
+ @staticmethod
352
+ def compile (restrictions ):
331
353
""" This is only for IdPs or AAs, and it's about limiting what
332
354
is returned to the SP.
333
355
In the configuration file, restrictions on which values that
334
356
can be returned are specified with the help of regular expressions.
335
357
This function goes through and pre-compiles the regular expressions.
336
358
337
- :param restrictions:
359
+ :param restrictions: policy configuration
338
360
:return: The assertion with the string specification replaced with
339
361
a compiled regular expression.
340
362
"""
341
-
342
- self ._restrictions = copy .deepcopy (restrictions )
343
-
344
- for who , spec in self ._restrictions .items ():
363
+ for who , spec in restrictions .items ():
345
364
if spec is None :
346
365
continue
347
- try :
348
- items = spec ["entity_categories" ]
349
- except KeyError :
350
- pass
351
- else :
366
+
367
+ entity_categories = spec .get ("entity_categories" )
368
+ if entity_categories is not None :
352
369
ecs = []
353
- for cat in items :
370
+ for cat in entity_categories :
354
371
try :
355
372
_mod = importlib .import_module (cat )
356
373
except ImportError :
@@ -366,25 +383,27 @@ def compile(self, restrictions):
366
383
_ec [key ] = (alist , _only_required )
367
384
ecs .append (_ec )
368
385
spec ["entity_categories" ] = ecs
369
- try :
370
- restr = spec ["attribute_restrictions" ]
371
- except KeyError :
372
- continue
373
386
374
- if restr is None :
387
+ attribute_restrictions = spec .get ("attribute_restrictions" )
388
+ if attribute_restrictions is None :
375
389
continue
376
390
377
- _are = {}
378
- for key , values in restr .items ():
391
+ _attribute_restrictions = {}
392
+ for key , values in attribute_restrictions .items ():
379
393
if not values :
380
- _are [key .lower ()] = None
394
+ _attribute_restrictions [key .lower ()] = None
381
395
continue
396
+ _attribute_restrictions [key .lower ()] = [re .compile (value ) for value in values ]
382
397
383
- _are [key .lower ()] = [re .compile (value ) for value in values ]
384
- spec ["attribute_restrictions" ] = _are
385
- logger .debug ("policy restrictions: %s" , self ._restrictions )
398
+ spec ["attribute_restrictions" ] = _attribute_restrictions
386
399
387
- return self ._restrictions
400
+ return restrictions
401
+
402
+ def _lookup_registry_authority (self , sp_entity_id ):
403
+ if self ._config and self ._config .metadata :
404
+ registration_info = self ._config .metadata .registration_info (sp_entity_id )
405
+ return registration_info .get ('registration_authority' )
406
+ return None
388
407
389
408
def get (self , attribute , sp_entity_id , default = None , post_func = None ,
390
409
** kwargs ):
@@ -399,16 +418,22 @@ def get(self, attribute, sp_entity_id, default=None, post_func=None,
399
418
if not self ._restrictions :
400
419
return default
401
420
402
- try :
403
- try :
404
- val = self ._restrictions [sp_entity_id ][attribute ]
405
- except KeyError :
406
- try :
407
- val = self ._restrictions ["default" ][attribute ]
408
- except KeyError :
409
- val = None
410
- except KeyError :
411
- val = None
421
+ registration_authority_name = self ._lookup_registry_authority (sp_entity_id )
422
+ registration_authorities = self ._restrictions .get ("registration_authorities" )
423
+
424
+ val = None
425
+ # Specific SP takes precedence
426
+ if sp_entity_id in self ._restrictions :
427
+ val = self ._restrictions [sp_entity_id ].get (attribute )
428
+ # Second choice is if the SP is part of a configured registration authority
429
+ elif registration_authorities and registration_authority_name in registration_authorities :
430
+ val = registration_authorities [registration_authority_name ].get (attribute )
431
+ # Third is to try default for registration authorities
432
+ elif registration_authorities and 'default' in registration_authorities :
433
+ val = registration_authorities ['default' ].get (attribute )
434
+ # Lastly we try default for SPs
435
+ elif 'default' in self ._restrictions :
436
+ val = self ._restrictions .get ('default' ).get (attribute )
412
437
413
438
if val is None :
414
439
return default
@@ -422,16 +447,15 @@ def get_nameid_format(self, sp_entity_id):
422
447
:param: The SP entity ID
423
448
:retur: The format
424
449
"""
425
- return self .get ("nameid_format" , sp_entity_id ,
426
- saml .NAMEID_FORMAT_TRANSIENT )
450
+ return self .get ("nameid_format" , sp_entity_id , saml .NAMEID_FORMAT_TRANSIENT )
427
451
428
452
def get_name_form (self , sp_entity_id ):
429
453
""" Get the NameFormat to used for the entity id
430
454
:param: The SP entity ID
431
455
:retur: The format
432
456
"""
433
457
434
- return self .get ("name_form" , sp_entity_id , NAME_FORMAT_URI )
458
+ return self .get ("name_form" , sp_entity_id , default = NAME_FORMAT_URI )
435
459
436
460
def get_lifetime (self , sp_entity_id ):
437
461
""" The lifetime of the assertion
@@ -458,32 +482,20 @@ def get_fail_on_missing_requested(self, sp_entity_id):
458
482
:return: The restrictions
459
483
"""
460
484
461
- return self .get ("fail_on_missing_requested" , sp_entity_id , True )
462
-
463
- def entity_category_attributes (self , ec ):
464
- if not self ._restrictions :
465
- return None
466
-
467
- ec_maps = self ._restrictions ["default" ]["entity_categories" ]
468
- for ec_map in ec_maps :
469
- try :
470
- return ec_map [ec ]
471
- except KeyError :
472
- pass
473
- return []
485
+ return self .get ("fail_on_missing_requested" , sp_entity_id , default = True )
474
486
475
487
def get_entity_categories (self , sp_entity_id , mds , required ):
476
488
"""
477
489
478
490
:param sp_entity_id:
479
491
:param mds: MetadataStore instance
492
+ :param required: required attributes
480
493
:return: A dictionary with restrictions
481
494
"""
482
495
483
496
kwargs = {"mds" : mds , 'required' : required }
484
497
485
- return self .get ("entity_categories" , sp_entity_id , default = {},
486
- post_func = post_entity_categories , ** kwargs )
498
+ return self .get ("entity_categories" , sp_entity_id , default = {}, post_func = post_entity_categories , ** kwargs )
487
499
488
500
def not_on_or_after (self , sp_entity_id ):
489
501
""" When the assertion stops being valid, should not be
@@ -495,6 +507,17 @@ def not_on_or_after(self, sp_entity_id):
495
507
496
508
return in_a_while (** self .get_lifetime (sp_entity_id ))
497
509
510
+ def get_sign (self , sp_entity_id ):
511
+ """
512
+ Possible choices
513
+ "sign": ["response", "assertion", "on_demand"]
514
+
515
+ :param sp_entity_id:
516
+ :return:
517
+ """
518
+
519
+ return self .get ("sign" , sp_entity_id , default = [])
520
+
498
521
def filter (self , ava , sp_entity_id , mdstore , required = None , optional = None ):
499
522
""" What attribute and attribute values returns depends on what
500
523
the SP has said it wants in the request or in the metadata file and
@@ -568,16 +591,18 @@ def conditions(self, sp_entity_id):
568
591
audience = [factory (saml .Audience ,
569
592
text = sp_entity_id )])])
570
593
571
- def get_sign (self , sp_entity_id ):
572
- """
573
- Possible choices
574
- "sign": ["response", "assertion", "on_demand"]
575
-
576
- :param sp_entity_id:
577
- :return:
578
- """
594
+ def entity_category_attributes (self , ec ):
595
+ # TODO: Not used. Remove?
596
+ if not self ._restrictions :
597
+ return None
579
598
580
- return self .get ("sign" , sp_entity_id , [])
599
+ ec_maps = self ._restrictions ["default" ]["entity_categories" ]
600
+ for ec_map in ec_maps :
601
+ try :
602
+ return ec_map [ec ]
603
+ except KeyError :
604
+ pass
605
+ return []
581
606
582
607
583
608
class EntityCategories (object ):
0 commit comments