@@ -40,113 +40,53 @@ class OpenIDConnectFrontend(FrontendModule):
40
40
"""
41
41
42
42
def __init__ (self , auth_req_callback_func , internal_attributes , conf , base_url , name ):
43
- self . _validate_config (conf )
43
+ _validate_config (conf )
44
44
super ().__init__ (auth_req_callback_func , internal_attributes , base_url , name )
45
45
46
46
self .config = conf
47
- self .signing_key = RSAKey (key = rsa_load (conf ["signing_key_path" ]), use = "sig" , alg = "RS256" ,
48
- kid = conf .get ("signing_key_id" , "" ))
49
-
50
- def _create_provider (self , endpoint_baseurl ):
51
- response_types_supported = self .config ["provider" ].get ("response_types_supported" , ["id_token" ])
52
- subject_types_supported = self .config ["provider" ].get ("subject_types_supported" , ["pairwise" ])
53
- scopes_supported = self .config ["provider" ].get ("scopes_supported" , ["openid" ])
54
- extra_scopes = self .config ["provider" ].get ("extra_scopes" )
55
- capabilities = {
56
- "issuer" : self .base_url ,
57
- "authorization_endpoint" : "{}/{}" .format (endpoint_baseurl , AuthorizationEndpoint .url ),
58
- "jwks_uri" : "{}/jwks" .format (endpoint_baseurl ),
59
- "response_types_supported" : response_types_supported ,
60
- "id_token_signing_alg_values_supported" : [self .signing_key .alg ],
61
- "response_modes_supported" : ["fragment" , "query" ],
62
- "subject_types_supported" : subject_types_supported ,
63
- "claim_types_supported" : ["normal" ],
64
- "claims_parameter_supported" : True ,
65
- "claims_supported" : [attribute_map ["openid" ][0 ]
66
- for attribute_map in self .internal_attributes ["attributes" ].values ()
67
- if "openid" in attribute_map ],
68
- "request_parameter_supported" : False ,
69
- "request_uri_parameter_supported" : False ,
70
- "scopes_supported" : scopes_supported
71
- }
72
-
73
- if 'code' in response_types_supported :
74
- capabilities ["token_endpoint" ] = "{}/{}" .format (endpoint_baseurl , TokenEndpoint .url )
75
-
76
- if self .config ["provider" ].get ("client_registration_supported" , False ):
77
- capabilities ["registration_endpoint" ] = "{}/{}" .format (endpoint_baseurl , RegistrationEndpoint .url )
78
-
79
- authz_state = self ._init_authorization_state ()
47
+ provider_config = self .config ["provider" ]
48
+ provider_config ["issuer" ] = base_url
49
+
50
+ self .signing_key = RSAKey (
51
+ key = rsa_load (self .config ["signing_key_path" ]),
52
+ use = "sig" ,
53
+ alg = "RS256" ,
54
+ kid = self .config .get ("signing_key_id" , "" ),
55
+ )
56
+
80
57
db_uri = self .config .get ("db_uri" )
58
+ self .user_db = (
59
+ StorageBase .from_uri (db_uri , db_name = "satosa" , collection = "authz_codes" )
60
+ if db_uri
61
+ else {}
62
+ )
63
+
64
+ sub_hash_salt = self .config .get ("sub_hash_salt" , rndstr (16 ))
65
+ authz_state = _init_authorization_state (provider_config , db_uri , sub_hash_salt )
66
+
81
67
client_db_uri = self .config .get ("client_db_uri" )
82
68
cdb_file = self .config .get ("client_db_path" )
83
69
if client_db_uri :
84
70
cdb = StorageBase .from_uri (
85
- client_db_uri , db_name = "satosa" , collection = "clients"
71
+ client_db_uri , db_name = "satosa" , collection = "clients" , ttl = None
86
72
)
87
73
elif cdb_file :
88
74
with open (cdb_file ) as f :
89
75
cdb = json .loads (f .read ())
90
76
else :
91
77
cdb = {}
92
78
93
- #XXX What is the correct ttl for user_db? Is it the same as authz_code_db?
94
- self .user_db = (
95
- StorageBase .from_uri (db_uri , db_name = "satosa" , collection = "authz_codes" )
96
- if db_uri
97
- else {}
98
- )
99
-
100
- self .provider = Provider (
79
+ self .endpoint_baseurl = "{}/{}" .format (self .base_url , self .name )
80
+ self .provider = _create_provider (
81
+ provider_config ,
82
+ self .endpoint_baseurl ,
83
+ self .internal_attributes ,
101
84
self .signing_key ,
102
- capabilities ,
103
85
authz_state ,
86
+ self .user_db ,
104
87
cdb ,
105
- Userinfo (self .user_db ),
106
- extra_scopes = extra_scopes ,
107
- id_token_lifetime = self .config ["provider" ].get ("id_token_lifetime" , 3600 ),
108
88
)
109
89
110
- def _init_authorization_state (self ):
111
- sub_hash_salt = self .config .get ("sub_hash_salt" , rndstr (16 ))
112
- db_uri = self .config .get ("db_uri" )
113
- if db_uri :
114
- authz_code_db = StorageBase .from_uri (
115
- db_uri ,
116
- db_name = "satosa" ,
117
- collection = "authz_codes" ,
118
- ttl = self .config ["provider" ].get ("authorization_code_lifetime" , 600 ),
119
- )
120
- access_token_db = StorageBase .from_uri (
121
- db_uri ,
122
- db_name = "satosa" ,
123
- collection = "access_tokens" ,
124
- ttl = self .config ["provider" ].get ("access_token_lifetime" , 3600 ),
125
- )
126
- refresh_token_db = StorageBase .from_uri (
127
- db_uri ,
128
- db_name = "satosa" ,
129
- collection = "refresh_tokens" ,
130
- ttl = self .config ["provider" ].get ("refresh_token_lifetime" , None ),
131
- )
132
- #XXX what is the correct TTL for sub_db?
133
- sub_db = StorageBase .from_uri (
134
- db_uri , db_name = "satosa" , collection = "subject_identifiers"
135
- )
136
- else :
137
- authz_code_db = None
138
- access_token_db = None
139
- refresh_token_db = None
140
- sub_db = None
141
-
142
- token_lifetimes = {k : self .config ["provider" ][k ] for k in ["authorization_code_lifetime" ,
143
- "access_token_lifetime" ,
144
- "refresh_token_lifetime" ,
145
- "refresh_token_threshold" ]
146
- if k in self .config ["provider" ]}
147
- return AuthorizationState (HashBasedSubjectIdentifierFactory (sub_hash_salt ), authz_code_db , access_token_db ,
148
- refresh_token_db , sub_db , ** token_lifetimes )
149
-
150
90
def _get_extra_id_token_claims (self , user_id , client_id ):
151
91
if "extra_id_token_claims" in self .config ["provider" ]:
152
92
config = self .config ["provider" ]["extra_id_token_claims" ].get (client_id , [])
@@ -223,9 +163,6 @@ def register_endpoints(self, backend_names):
223
163
else :
224
164
backend_name = backend_names [0 ]
225
165
226
- endpoint_baseurl = "{}/{}" .format (self .base_url , self .name )
227
- self ._create_provider (endpoint_baseurl )
228
-
229
166
provider_config = ("^.well-known/openid-configuration$" , self .provider_config )
230
167
jwks_uri = ("^{}/jwks$" .format (self .name ), self .jwks )
231
168
@@ -236,42 +173,36 @@ def register_endpoints(self, backend_names):
236
173
auth_path = urlparse (auth_endpoint ).path .lstrip ("/" )
237
174
else :
238
175
auth_path = "{}/{}" .format (self .name , AuthorizationEndpoint .url )
176
+
239
177
authentication = ("^{}$" .format (auth_path ), self .handle_authn_request )
240
178
url_map = [provider_config , jwks_uri , authentication ]
241
179
242
180
if any ("code" in v for v in self .provider .configuration_information ["response_types_supported" ]):
243
- self .provider .configuration_information ["token_endpoint" ] = "{}/{}" .format (endpoint_baseurl ,
244
- TokenEndpoint .url )
245
- token_endpoint = ("^{}/{}" .format (self .name , TokenEndpoint .url ), self .token_endpoint )
181
+ self .provider .configuration_information ["token_endpoint" ] = "{}/{}" .format (
182
+ self .endpoint_baseurl , TokenEndpoint .url
183
+ )
184
+ token_endpoint = (
185
+ "^{}/{}" .format (self .name , TokenEndpoint .url ), self .token_endpoint
186
+ )
246
187
url_map .append (token_endpoint )
247
188
248
- self .provider .configuration_information ["userinfo_endpoint" ] = "{}/{}" .format (endpoint_baseurl ,
249
- UserinfoEndpoint .url )
250
- userinfo_endpoint = ("^{}/{}" .format (self .name , UserinfoEndpoint .url ), self .userinfo_endpoint )
189
+ self .provider .configuration_information ["userinfo_endpoint" ] = (
190
+ "{}/{}" .format (self .endpoint_baseurl , UserinfoEndpoint .url )
191
+ )
192
+ userinfo_endpoint = (
193
+ "^{}/{}" .format (self .name , UserinfoEndpoint .url ), self .userinfo_endpoint
194
+ )
251
195
url_map .append (userinfo_endpoint )
196
+
252
197
if "registration_endpoint" in self .provider .configuration_information :
253
- client_registration = ("^{}/{}" .format (self .name , RegistrationEndpoint .url ), self .client_registration )
198
+ client_registration = (
199
+ "^{}/{}" .format (self .name , RegistrationEndpoint .url ),
200
+ self .client_registration ,
201
+ )
254
202
url_map .append (client_registration )
255
203
256
204
return url_map
257
205
258
- def _validate_config (self , config ):
259
- """
260
- Validates that all necessary config parameters are specified.
261
- :type config: dict[str, dict[str, Any] | str]
262
- :param config: the module config
263
- """
264
- if config is None :
265
- raise ValueError ("OIDCFrontend conf can't be 'None'." )
266
-
267
- for k in {"signing_key_path" , "provider" }:
268
- if k not in config :
269
- raise ValueError ("Missing configuration parameter '{}' for OpenID Connect frontend." .format (k ))
270
-
271
- if "signing_key_id" in config and type (config ["signing_key_id" ]) is not str :
272
- raise ValueError (
273
- "The configuration parameter 'signing_key_id' is not defined as a string for OpenID Connect frontend." )
274
-
275
206
def _get_authn_request_from_state (self , state ):
276
207
"""
277
208
Extract the clietns request stoed in the SATOSA state.
@@ -438,6 +369,128 @@ def userinfo_endpoint(self, context):
438
369
return response
439
370
440
371
372
+ def _validate_config (config ):
373
+ """
374
+ Validates that all necessary config parameters are specified.
375
+ :type config: dict[str, dict[str, Any] | str]
376
+ :param config: the module config
377
+ """
378
+ if config is None :
379
+ raise ValueError ("OIDCFrontend configuration can't be 'None'." )
380
+
381
+ for k in {"signing_key_path" , "provider" }:
382
+ if k not in config :
383
+ raise ValueError ("Missing configuration parameter '{}' for OpenID Connect frontend." .format (k ))
384
+
385
+ if "signing_key_id" in config and type (config ["signing_key_id" ]) is not str :
386
+ raise ValueError (
387
+ "The configuration parameter 'signing_key_id' is not defined as a string for OpenID Connect frontend." )
388
+
389
+
390
+ def _create_provider (
391
+ provider_config ,
392
+ endpoint_baseurl ,
393
+ internal_attributes ,
394
+ signing_key ,
395
+ authz_state ,
396
+ user_db ,
397
+ cdb ,
398
+ ):
399
+ response_types_supported = provider_config .get ("response_types_supported" , ["id_token" ])
400
+ subject_types_supported = provider_config .get ("subject_types_supported" , ["pairwise" ])
401
+ scopes_supported = provider_config .get ("scopes_supported" , ["openid" ])
402
+ extra_scopes = provider_config .get ("extra_scopes" )
403
+ capabilities = {
404
+ "issuer" : provider_config ["issuer" ],
405
+ "authorization_endpoint" : "{}/{}" .format (endpoint_baseurl , AuthorizationEndpoint .url ),
406
+ "jwks_uri" : "{}/jwks" .format (endpoint_baseurl ),
407
+ "response_types_supported" : response_types_supported ,
408
+ "id_token_signing_alg_values_supported" : [signing_key .alg ],
409
+ "response_modes_supported" : ["fragment" , "query" ],
410
+ "subject_types_supported" : subject_types_supported ,
411
+ "claim_types_supported" : ["normal" ],
412
+ "claims_parameter_supported" : True ,
413
+ "claims_supported" : [
414
+ attribute_map ["openid" ][0 ]
415
+ for attribute_map in internal_attributes ["attributes" ].values ()
416
+ if "openid" in attribute_map
417
+ ],
418
+ "request_parameter_supported" : False ,
419
+ "request_uri_parameter_supported" : False ,
420
+ "scopes_supported" : scopes_supported
421
+ }
422
+
423
+ if 'code' in response_types_supported :
424
+ capabilities ["token_endpoint" ] = "{}/{}" .format (
425
+ endpoint_baseurl , TokenEndpoint .url
426
+ )
427
+
428
+ if provider_config .get ("client_registration_supported" , False ):
429
+ capabilities ["registration_endpoint" ] = "{}/{}" .format (
430
+ endpoint_baseurl , RegistrationEndpoint .url
431
+ )
432
+
433
+ provider = Provider (
434
+ signing_key ,
435
+ capabilities ,
436
+ authz_state ,
437
+ cdb ,
438
+ Userinfo (user_db ),
439
+ extra_scopes = extra_scopes ,
440
+ id_token_lifetime = provider_config .get ("id_token_lifetime" , 3600 ),
441
+ )
442
+ return provider
443
+
444
+
445
+ def _init_authorization_state (provider_config , db_uri , sub_hash_salt ):
446
+ if db_uri :
447
+ authz_code_db = StorageBase .from_uri (
448
+ db_uri ,
449
+ db_name = "satosa" ,
450
+ collection = "authz_codes" ,
451
+ ttl = provider_config .get ("authorization_code_lifetime" , 600 ),
452
+ )
453
+ access_token_db = StorageBase .from_uri (
454
+ db_uri ,
455
+ db_name = "satosa" ,
456
+ collection = "access_tokens" ,
457
+ ttl = provider_config .get ("access_token_lifetime" , 3600 ),
458
+ )
459
+ refresh_token_db = StorageBase .from_uri (
460
+ db_uri ,
461
+ db_name = "satosa" ,
462
+ collection = "refresh_tokens" ,
463
+ ttl = provider_config .get ("refresh_token_lifetime" , None ),
464
+ )
465
+ sub_db = StorageBase .from_uri (
466
+ db_uri , db_name = "satosa" , collection = "subject_identifiers" , ttl = None
467
+ )
468
+ else :
469
+ authz_code_db = None
470
+ access_token_db = None
471
+ refresh_token_db = None
472
+ sub_db = None
473
+
474
+ token_lifetimes = {
475
+ k : provider_config [k ]
476
+ for k in [
477
+ "authorization_code_lifetime" ,
478
+ "access_token_lifetime" ,
479
+ "refresh_token_lifetime" ,
480
+ "refresh_token_threshold" ,
481
+ ]
482
+ if k in provider_config
483
+ }
484
+ return AuthorizationState (
485
+ HashBasedSubjectIdentifierFactory (sub_hash_salt ),
486
+ authz_code_db ,
487
+ access_token_db ,
488
+ refresh_token_db ,
489
+ sub_db ,
490
+ ** token_lifetimes ,
491
+ )
492
+
493
+
441
494
def combine_return_input (values ):
442
495
return values
443
496
0 commit comments