@@ -82,10 +82,12 @@ def _match_attr_name(attr, ava):
82
82
try :
83
83
friendly_name = attr ["friendly_name" ]
84
84
except KeyError :
85
- friendly_name = get_local_name (acs , attr ["name" ], attr ["name_format" ])
85
+ friendly_name = get_local_name (acs , attr ["name" ],
86
+ attr ["name_format" ])
86
87
87
88
_fn = _match (friendly_name , ava )
88
- if not _fn : # In the unlikely case that someone has provided us with URIs as attribute names
89
+ if not _fn : # In the unlikely case that someone has provided us with
90
+ # URIs as attribute names
89
91
_fn = _match (attr ["name" ], ava )
90
92
91
93
return _fn
@@ -152,8 +154,8 @@ def filter_on_demands(ava, required=None, optional=None):
152
154
for val in vals :
153
155
if val not in ava [lava [attr ]]:
154
156
raise MissingValue (
155
- "Required attribute value missing: %s,%s" % (attr ,
156
- val ))
157
+ "Required attribute value missing: %s,%s" % (attr ,
158
+ val ))
157
159
else :
158
160
raise MissingValue ("Required attribute missing: %s" % (attr ,))
159
161
@@ -266,6 +268,11 @@ def restriction_from_attribute_spec(attributes):
266
268
267
269
def post_entity_categories (maps , ** kwargs ):
268
270
restrictions = {}
271
+ try :
272
+ required = [d ['friendly_name' ].lower () for d in kwargs ['required' ]]
273
+ except KeyError :
274
+ required = []
275
+
269
276
if kwargs ["mds" ]:
270
277
try :
271
278
ecs = kwargs ["mds" ].entity_categories (kwargs ["sp_entity_id" ])
@@ -275,19 +282,25 @@ def post_entity_categories(maps, **kwargs):
275
282
restrictions [attr ] = None
276
283
else :
277
284
for ec_map in maps :
278
- for key , val in ec_map .items ():
285
+ for key , ( atlist , only_required ) in ec_map .items ():
279
286
if key == "" : # always released
280
- attrs = val
287
+ attrs = atlist
281
288
elif isinstance (key , tuple ):
282
- attrs = val
289
+ if only_required :
290
+ attrs = [a for a in atlist if a in required ]
291
+ else :
292
+ attrs = atlist
283
293
for _key in key :
284
294
try :
285
295
assert _key in ecs
286
296
except AssertionError :
287
297
attrs = []
288
298
break
289
299
elif key in ecs :
290
- attrs = val
300
+ if only_required :
301
+ attrs = [a for a in atlist if a in required ]
302
+ else :
303
+ attrs = atlist
291
304
else :
292
305
attrs = []
293
306
@@ -332,10 +345,15 @@ def compile(self, restrictions):
332
345
ecs = []
333
346
for cat in items :
334
347
_mod = importlib .import_module (
335
- "saml2.entity_category.%s" % cat )
348
+ "saml2.entity_category.%s" % cat )
336
349
_ec = {}
337
350
for key , items in _mod .RELEASE .items ():
338
- _ec [key ] = [k .lower () for k in items ]
351
+ alist = [k .lower () for k in items ]
352
+ try :
353
+ _only_required = _mod .ONLY_REQUIRED [key ]
354
+ except (AttributeError , KeyError ):
355
+ _only_required = False
356
+ _ec [key ] = (alist , _only_required )
339
357
ecs .append (_ec )
340
358
spec ["entity_categories" ] = ecs
341
359
try :
@@ -444,15 +462,15 @@ def entity_category_attributes(self, ec):
444
462
pass
445
463
return []
446
464
447
- def get_entity_categories (self , sp_entity_id , mds ):
465
+ def get_entity_categories (self , sp_entity_id , mds , required ):
448
466
"""
449
467
450
468
:param sp_entity_id:
451
469
:param mds: MetadataStore instance
452
470
:return: A dictionary with restrictions
453
471
"""
454
472
455
- kwargs = {"mds" : mds }
473
+ kwargs = {"mds" : mds , 'required' : required }
456
474
457
475
return self .get ("entity_categories" , sp_entity_id , default = {},
458
476
post_func = post_entity_categories , ** kwargs )
@@ -483,19 +501,15 @@ def filter(self, ava, sp_entity_id, mdstore, required=None, optional=None):
483
501
"""
484
502
485
503
_ava = None
486
- if required or optional :
487
- logger .debug ("required: %s, optional: %s" , required , optional )
488
- _ava = filter_on_attributes (
489
- ava .copy (), required , optional , self .acs ,
490
- self .get_fail_on_missing_requested (sp_entity_id ))
491
504
492
- _rest = self .get_entity_categories (sp_entity_id , mdstore )
505
+ _rest = self .get_entity_categories (sp_entity_id , mdstore , required )
493
506
if _rest :
494
- ava_ec = filter_attribute_value_assertions (ava .copy (), _rest )
495
- if _ava is None :
496
- _ava = ava_ec
497
- else :
498
- _ava .update (ava_ec )
507
+ _ava = filter_attribute_value_assertions (ava .copy (), _rest )
508
+ elif required or optional :
509
+ logger .debug ("required: %s, optional: %s" , required , optional )
510
+ _ava = filter_on_attributes (
511
+ ava .copy (), required , optional , self .acs ,
512
+ self .get_fail_on_missing_requested (sp_entity_id ))
499
513
500
514
_rest = self .get_attribute_restrictions (sp_entity_id )
501
515
if _rest :
@@ -537,9 +551,9 @@ def conditions(self, sp_entity_id):
537
551
# How long might depend on who's getting it
538
552
not_on_or_after = self .not_on_or_after (sp_entity_id ),
539
553
audience_restriction = [factory (
540
- saml .AudienceRestriction ,
541
- audience = [factory (saml .Audience ,
542
- text = sp_entity_id )])])
554
+ saml .AudienceRestriction ,
555
+ audience = [factory (saml .Audience ,
556
+ text = sp_entity_id )])])
543
557
544
558
def get_sign (self , sp_entity_id ):
545
559
"""
@@ -569,7 +583,7 @@ def _authn_context_class_ref(authn_class, authn_auth=None):
569
583
return factory (saml .AuthnContext ,
570
584
authn_context_class_ref = cntx_class ,
571
585
authenticating_authority = factory (
572
- saml .AuthenticatingAuthority , text = authn_auth ))
586
+ saml .AuthenticatingAuthority , text = authn_auth ))
573
587
else :
574
588
return factory (saml .AuthnContext ,
575
589
authn_context_class_ref = cntx_class )
@@ -585,7 +599,7 @@ def _authn_context_decl(decl, authn_auth=None):
585
599
return factory (saml .AuthnContext ,
586
600
authn_context_decl = decl ,
587
601
authenticating_authority = factory (
588
- saml .AuthenticatingAuthority , text = authn_auth ))
602
+ saml .AuthenticatingAuthority , text = authn_auth ))
589
603
590
604
591
605
def _authn_context_decl_ref (decl_ref , authn_auth = None ):
@@ -598,7 +612,7 @@ def _authn_context_decl_ref(decl_ref, authn_auth=None):
598
612
return factory (saml .AuthnContext ,
599
613
authn_context_decl_ref = decl_ref ,
600
614
authenticating_authority = factory (
601
- saml .AuthenticatingAuthority , text = authn_auth ))
615
+ saml .AuthenticatingAuthority , text = authn_auth ))
602
616
603
617
604
618
def authn_statement (authn_class = None , authn_auth = None ,
@@ -624,29 +638,29 @@ def authn_statement(authn_class=None, authn_auth=None,
624
638
625
639
if authn_class :
626
640
res = factory (
627
- saml .AuthnStatement ,
628
- authn_instant = _instant ,
629
- session_index = sid (),
630
- authn_context = _authn_context_class_ref (
631
- authn_class , authn_auth ))
641
+ saml .AuthnStatement ,
642
+ authn_instant = _instant ,
643
+ session_index = sid (),
644
+ authn_context = _authn_context_class_ref (
645
+ authn_class , authn_auth ))
632
646
elif authn_decl :
633
647
res = factory (
634
- saml .AuthnStatement ,
635
- authn_instant = _instant ,
636
- session_index = sid (),
637
- authn_context = _authn_context_decl (authn_decl , authn_auth ))
648
+ saml .AuthnStatement ,
649
+ authn_instant = _instant ,
650
+ session_index = sid (),
651
+ authn_context = _authn_context_decl (authn_decl , authn_auth ))
638
652
elif authn_decl_ref :
639
653
res = factory (
640
- saml .AuthnStatement ,
641
- authn_instant = _instant ,
642
- session_index = sid (),
643
- authn_context = _authn_context_decl_ref (authn_decl_ref ,
644
- authn_auth ))
654
+ saml .AuthnStatement ,
655
+ authn_instant = _instant ,
656
+ session_index = sid (),
657
+ authn_context = _authn_context_decl_ref (authn_decl_ref ,
658
+ authn_auth ))
645
659
else :
646
660
res = factory (
647
- saml .AuthnStatement ,
648
- authn_instant = _instant ,
649
- session_index = sid ())
661
+ saml .AuthnStatement ,
662
+ authn_instant = _instant ,
663
+ session_index = sid ())
650
664
651
665
if subject_locality :
652
666
res .subject_locality = saml .SubjectLocality (text = subject_locality )
@@ -688,7 +702,8 @@ def do_subject(policy, sp_entity_id, name_id, **farg):
688
702
specs = farg ['subject_confirmation' ]
689
703
690
704
if isinstance (specs , list ):
691
- res = [do_subject_confirmation (policy , sp_entity_id , ** s ) for s in specs ]
705
+ res = [do_subject_confirmation (policy , sp_entity_id , ** s ) for s in
706
+ specs ]
692
707
else :
693
708
res = [do_subject_confirmation (policy , sp_entity_id , ** specs )]
694
709
@@ -736,7 +751,7 @@ def construct(self, sp_entity_id, attrconvs, policy, issuer, farg,
736
751
_name_format = NAME_FORMAT_URI
737
752
738
753
attr_statement = saml .AttributeStatement (attribute = from_local (
739
- attrconvs , self , _name_format ))
754
+ attrconvs , self , _name_format ))
740
755
741
756
if encrypt == "attributes" :
742
757
for attr in attr_statement .attribute :
0 commit comments