16
16
from saml2 .s_utils import factory
17
17
from saml2 .s_utils import assertion_factory
18
18
19
-
20
19
logger = logging .getLogger (__name__ )
21
20
22
21
@@ -78,55 +77,54 @@ def filter_on_attributes(ava, required=None, optional=None, acs=None,
78
77
are missing fail or fail not depending on this parameter.
79
78
:return: The modified attribute value assertion
80
79
"""
80
+
81
+ def _match_attr_name (attr , ava ):
82
+ try :
83
+ friendly_name = attr ["friendly_name" ]
84
+ except KeyError :
85
+ friendly_name = get_local_name (acs , attr ["name" ], attr ["name_format" ])
86
+
87
+ _fn = _match (friendly_name , ava )
88
+ if not _fn : # In the unlikely case that someone has provided us with URIs as attribute names
89
+ _fn = _match (attr ["name" ], ava )
90
+
91
+ return _fn
92
+
93
+ def _apply_attr_value_restrictions (attr , res , must = False ):
94
+ try :
95
+ values = [av ["text" ] for av in attr ["attribute_value" ]]
96
+ except KeyError :
97
+ values = []
98
+
99
+ try :
100
+ res [_fn ].extend (_filter_values (ava [_fn ], values ))
101
+ except KeyError :
102
+ res [_fn ] = _filter_values (ava [_fn ], values )
103
+
104
+ return _filter_values (ava [_fn ], values , must )
105
+
81
106
res = {}
82
107
83
108
if required is None :
84
109
required = []
85
110
86
- nform = "friendly_name"
87
111
for attr in required :
88
- try :
89
- _name = attr [nform ]
90
- except KeyError :
91
- if nform == "friendly_name" :
92
- _name = get_local_name (acs , attr ["name" ],
93
- attr ["name_format" ])
94
- else :
95
- continue
96
-
97
- _fn = _match (_name , ava )
98
- if not _fn : # In the unlikely case that someone has provided us
99
- # with URIs as attribute names
100
- _fn = _match (attr ["name" ], ava )
112
+ _fn = _match_attr_name (attr , ava )
101
113
102
114
if _fn :
103
- try :
104
- values = [av ["text" ] for av in attr ["attribute_value" ]]
105
- except KeyError :
106
- values = []
107
- res [_fn ] = _filter_values (ava [_fn ], values , True )
108
- continue
115
+ _apply_attr_value_restrictions (attr , res , True )
109
116
elif fail_on_unfulfilled_requirements :
110
117
desc = "Required attribute missing: '%s' (%s)" % (attr ["name" ],
111
- _name )
118
+ _fn )
112
119
raise MissingValue (desc )
113
120
114
121
if optional is None :
115
122
optional = []
116
123
117
124
for attr in optional :
118
- for nform in ["friendly_name" , "name" ]:
119
- if nform in attr :
120
- _fn = _match (attr [nform ], ava )
121
- if _fn :
122
- try :
123
- values = [av ["text" ] for av in attr ["attribute_value" ]]
124
- except KeyError :
125
- values = []
126
- try :
127
- res [_fn ].extend (_filter_values (ava [_fn ], values ))
128
- except KeyError :
129
- res [_fn ] = _filter_values (ava [_fn ], values )
125
+ _fn = _match_attr_name (attr , ava )
126
+ if _fn :
127
+ _apply_attr_value_restrictions (attr , res , False )
130
128
131
129
return res
132
130
@@ -154,8 +152,8 @@ def filter_on_demands(ava, required=None, optional=None):
154
152
for val in vals :
155
153
if val not in ava [lava [attr ]]:
156
154
raise MissingValue (
157
- "Required attribute value missing: %s,%s" % (attr ,
158
- val ))
155
+ "Required attribute value missing: %s,%s" % (attr ,
156
+ val ))
159
157
else :
160
158
raise MissingValue ("Required attribute missing: %s" % (attr ,))
161
159
@@ -334,7 +332,7 @@ def compile(self, restrictions):
334
332
ecs = []
335
333
for cat in items :
336
334
_mod = importlib .import_module (
337
- "saml2.entity_category.%s" % cat )
335
+ "saml2.entity_category.%s" % cat )
338
336
_ec = {}
339
337
for key , items in _mod .RELEASE .items ():
340
338
_ec [key ] = [k .lower () for k in items ]
@@ -488,8 +486,8 @@ def filter(self, ava, sp_entity_id, mdstore, required=None, optional=None):
488
486
if required or optional :
489
487
logger .debug ("required: %s, optional: %s" , required , optional )
490
488
_ava = filter_on_attributes (
491
- ava .copy (), required , optional , self .acs ,
492
- self .get_fail_on_missing_requested (sp_entity_id ))
489
+ ava .copy (), required , optional , self .acs ,
490
+ self .get_fail_on_missing_requested (sp_entity_id ))
493
491
494
492
_rest = self .get_entity_categories (sp_entity_id , mdstore )
495
493
if _rest :
@@ -539,9 +537,9 @@ def conditions(self, sp_entity_id):
539
537
# How long might depend on who's getting it
540
538
not_on_or_after = self .not_on_or_after (sp_entity_id ),
541
539
audience_restriction = [factory (
542
- saml .AudienceRestriction ,
543
- audience = [factory (saml .Audience ,
544
- text = sp_entity_id )])])
540
+ saml .AudienceRestriction ,
541
+ audience = [factory (saml .Audience ,
542
+ text = sp_entity_id )])])
545
543
546
544
def get_sign (self , sp_entity_id ):
547
545
"""
@@ -571,7 +569,7 @@ def _authn_context_class_ref(authn_class, authn_auth=None):
571
569
return factory (saml .AuthnContext ,
572
570
authn_context_class_ref = cntx_class ,
573
571
authenticating_authority = factory (
574
- saml .AuthenticatingAuthority , text = authn_auth ))
572
+ saml .AuthenticatingAuthority , text = authn_auth ))
575
573
else :
576
574
return factory (saml .AuthnContext ,
577
575
authn_context_class_ref = cntx_class )
@@ -587,7 +585,7 @@ def _authn_context_decl(decl, authn_auth=None):
587
585
return factory (saml .AuthnContext ,
588
586
authn_context_decl = decl ,
589
587
authenticating_authority = factory (
590
- saml .AuthenticatingAuthority , text = authn_auth ))
588
+ saml .AuthenticatingAuthority , text = authn_auth ))
591
589
592
590
593
591
def _authn_context_decl_ref (decl_ref , authn_auth = None ):
@@ -600,7 +598,7 @@ def _authn_context_decl_ref(decl_ref, authn_auth=None):
600
598
return factory (saml .AuthnContext ,
601
599
authn_context_decl_ref = decl_ref ,
602
600
authenticating_authority = factory (
603
- saml .AuthenticatingAuthority , text = authn_auth ))
601
+ saml .AuthenticatingAuthority , text = authn_auth ))
604
602
605
603
606
604
def authn_statement (authn_class = None , authn_auth = None ,
@@ -626,29 +624,29 @@ def authn_statement(authn_class=None, authn_auth=None,
626
624
627
625
if authn_class :
628
626
res = factory (
629
- saml .AuthnStatement ,
630
- authn_instant = _instant ,
631
- session_index = sid (),
632
- authn_context = _authn_context_class_ref (
633
- authn_class , authn_auth ))
627
+ saml .AuthnStatement ,
628
+ authn_instant = _instant ,
629
+ session_index = sid (),
630
+ authn_context = _authn_context_class_ref (
631
+ authn_class , authn_auth ))
634
632
elif authn_decl :
635
633
res = factory (
636
- saml .AuthnStatement ,
637
- authn_instant = _instant ,
638
- session_index = sid (),
639
- authn_context = _authn_context_decl (authn_decl , authn_auth ))
634
+ saml .AuthnStatement ,
635
+ authn_instant = _instant ,
636
+ session_index = sid (),
637
+ authn_context = _authn_context_decl (authn_decl , authn_auth ))
640
638
elif authn_decl_ref :
641
639
res = factory (
642
- saml .AuthnStatement ,
643
- authn_instant = _instant ,
644
- session_index = sid (),
645
- authn_context = _authn_context_decl_ref (authn_decl_ref ,
646
- authn_auth ))
640
+ saml .AuthnStatement ,
641
+ authn_instant = _instant ,
642
+ session_index = sid (),
643
+ authn_context = _authn_context_decl_ref (authn_decl_ref ,
644
+ authn_auth ))
647
645
else :
648
646
res = factory (
649
- saml .AuthnStatement ,
650
- authn_instant = _instant ,
651
- session_index = sid ())
647
+ saml .AuthnStatement ,
648
+ authn_instant = _instant ,
649
+ session_index = sid ())
652
650
653
651
if subject_locality :
654
652
res .subject_locality = saml .SubjectLocality (text = subject_locality )
@@ -698,7 +696,7 @@ def construct(self, sp_entity_id, in_response_to, consumer_url,
698
696
_name_format = NAME_FORMAT_URI
699
697
700
698
attr_statement = saml .AttributeStatement (attribute = from_local (
701
- attrconvs , self , _name_format ))
699
+ attrconvs , self , _name_format ))
702
700
703
701
if encrypt == "attributes" :
704
702
for attr in attr_statement .attribute :
@@ -725,33 +723,33 @@ def construct(self, sp_entity_id, in_response_to, consumer_url,
725
723
726
724
if not add_subject :
727
725
_ass = assertion_factory (
728
- issuer = issuer ,
729
- conditions = conds ,
730
- subject = None
726
+ issuer = issuer ,
727
+ conditions = conds ,
728
+ subject = None
731
729
)
732
730
else :
733
731
_ass = assertion_factory (
734
- issuer = issuer ,
735
- conditions = conds ,
736
- subject = factory (
737
- saml .Subject ,
738
- name_id = name_id ,
739
- subject_confirmation = [factory (
740
- saml .SubjectConfirmation ,
741
- method = saml .SCM_BEARER ,
742
- subject_confirmation_data = factory (
743
- saml .SubjectConfirmationData ,
744
- in_response_to = in_response_to ,
745
- recipient = consumer_url ,
746
- not_on_or_after = policy .not_on_or_after (sp_entity_id )))]
747
- ),
732
+ issuer = issuer ,
733
+ conditions = conds ,
734
+ subject = factory (
735
+ saml .Subject ,
736
+ name_id = name_id ,
737
+ subject_confirmation = [factory (
738
+ saml .SubjectConfirmation ,
739
+ method = saml .SCM_BEARER ,
740
+ subject_confirmation_data = factory (
741
+ saml .SubjectConfirmationData ,
742
+ in_response_to = in_response_to ,
743
+ recipient = consumer_url ,
744
+ not_on_or_after = policy .not_on_or_after (sp_entity_id )))]
745
+ ),
748
746
)
749
747
750
748
if _authn_statement :
751
749
_ass .authn_statement = [_authn_statement ]
752
750
753
751
if not attr_statement .empty ():
754
- _ass .attribute_statement = [attr_statement ]
752
+ _ass .attribute_statement = [attr_statement ]
755
753
756
754
return _ass
757
755
0 commit comments