Skip to content

Commit 3c8d95f

Browse files
committed
Restructure initialization
Signed-off-by: Ivan Kanakarakis <[email protected]>
1 parent 250a6e7 commit 3c8d95f

File tree

1 file changed

+167
-114
lines changed

1 file changed

+167
-114
lines changed

src/satosa/frontends/openid_connect.py

Lines changed: 167 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -40,113 +40,53 @@ class OpenIDConnectFrontend(FrontendModule):
4040
"""
4141

4242
def __init__(self, auth_req_callback_func, internal_attributes, conf, base_url, name):
43-
self._validate_config(conf)
43+
_validate_config(conf)
4444
super().__init__(auth_req_callback_func, internal_attributes, base_url, name)
4545

4646
self.config = conf
47-
self.signing_key = RSAKey(key=rsa_load(conf["signing_key_path"]), use="sig", alg="RS256",
48-
kid=conf.get("signing_key_id", ""))
49-
50-
def _create_provider(self, endpoint_baseurl):
51-
response_types_supported = self.config["provider"].get("response_types_supported", ["id_token"])
52-
subject_types_supported = self.config["provider"].get("subject_types_supported", ["pairwise"])
53-
scopes_supported = self.config["provider"].get("scopes_supported", ["openid"])
54-
extra_scopes = self.config["provider"].get("extra_scopes")
55-
capabilities = {
56-
"issuer": self.base_url,
57-
"authorization_endpoint": "{}/{}".format(endpoint_baseurl, AuthorizationEndpoint.url),
58-
"jwks_uri": "{}/jwks".format(endpoint_baseurl),
59-
"response_types_supported": response_types_supported,
60-
"id_token_signing_alg_values_supported": [self.signing_key.alg],
61-
"response_modes_supported": ["fragment", "query"],
62-
"subject_types_supported": subject_types_supported,
63-
"claim_types_supported": ["normal"],
64-
"claims_parameter_supported": True,
65-
"claims_supported": [attribute_map["openid"][0]
66-
for attribute_map in self.internal_attributes["attributes"].values()
67-
if "openid" in attribute_map],
68-
"request_parameter_supported": False,
69-
"request_uri_parameter_supported": False,
70-
"scopes_supported": scopes_supported
71-
}
72-
73-
if 'code' in response_types_supported:
74-
capabilities["token_endpoint"] = "{}/{}".format(endpoint_baseurl, TokenEndpoint.url)
75-
76-
if self.config["provider"].get("client_registration_supported", False):
77-
capabilities["registration_endpoint"] = "{}/{}".format(endpoint_baseurl, RegistrationEndpoint.url)
78-
79-
authz_state = self._init_authorization_state()
47+
provider_config = self.config["provider"]
48+
provider_config["issuer"] = base_url
49+
50+
self.signing_key = RSAKey(
51+
key=rsa_load(self.config["signing_key_path"]),
52+
use="sig",
53+
alg="RS256",
54+
kid=self.config.get("signing_key_id", ""),
55+
)
56+
8057
db_uri = self.config.get("db_uri")
58+
self.user_db = (
59+
StorageBase.from_uri(db_uri, db_name="satosa", collection="authz_codes")
60+
if db_uri
61+
else {}
62+
)
63+
64+
sub_hash_salt = self.config.get("sub_hash_salt", rndstr(16))
65+
authz_state = _init_authorization_state(provider_config, db_uri, sub_hash_salt)
66+
8167
client_db_uri = self.config.get("client_db_uri")
8268
cdb_file = self.config.get("client_db_path")
8369
if client_db_uri:
8470
cdb = StorageBase.from_uri(
85-
client_db_uri, db_name="satosa", collection="clients"
71+
client_db_uri, db_name="satosa", collection="clients", ttl=None
8672
)
8773
elif cdb_file:
8874
with open(cdb_file) as f:
8975
cdb = json.loads(f.read())
9076
else:
9177
cdb = {}
9278

93-
#XXX What is the correct ttl for user_db? Is it the same as authz_code_db?
94-
self.user_db = (
95-
StorageBase.from_uri(db_uri, db_name="satosa", collection="authz_codes")
96-
if db_uri
97-
else {}
98-
)
99-
100-
self.provider = Provider(
79+
self.endpoint_baseurl = "{}/{}".format(self.base_url, self.name)
80+
self.provider = _create_provider(
81+
provider_config,
82+
self.endpoint_baseurl,
83+
self.internal_attributes,
10184
self.signing_key,
102-
capabilities,
10385
authz_state,
86+
self.user_db,
10487
cdb,
105-
Userinfo(self.user_db),
106-
extra_scopes=extra_scopes,
107-
id_token_lifetime=self.config["provider"].get("id_token_lifetime", 3600),
10888
)
10989

110-
def _init_authorization_state(self):
111-
sub_hash_salt = self.config.get("sub_hash_salt", rndstr(16))
112-
db_uri = self.config.get("db_uri")
113-
if db_uri:
114-
authz_code_db = StorageBase.from_uri(
115-
db_uri,
116-
db_name="satosa",
117-
collection="authz_codes",
118-
ttl=self.config["provider"].get("authorization_code_lifetime", 600),
119-
)
120-
access_token_db = StorageBase.from_uri(
121-
db_uri,
122-
db_name="satosa",
123-
collection="access_tokens",
124-
ttl=self.config["provider"].get("access_token_lifetime", 3600),
125-
)
126-
refresh_token_db = StorageBase.from_uri(
127-
db_uri,
128-
db_name="satosa",
129-
collection="refresh_tokens",
130-
ttl=self.config["provider"].get("refresh_token_lifetime", None),
131-
)
132-
#XXX what is the correct TTL for sub_db?
133-
sub_db = StorageBase.from_uri(
134-
db_uri, db_name="satosa", collection="subject_identifiers"
135-
)
136-
else:
137-
authz_code_db = None
138-
access_token_db = None
139-
refresh_token_db = None
140-
sub_db = None
141-
142-
token_lifetimes = {k: self.config["provider"][k] for k in ["authorization_code_lifetime",
143-
"access_token_lifetime",
144-
"refresh_token_lifetime",
145-
"refresh_token_threshold"]
146-
if k in self.config["provider"]}
147-
return AuthorizationState(HashBasedSubjectIdentifierFactory(sub_hash_salt), authz_code_db, access_token_db,
148-
refresh_token_db, sub_db, **token_lifetimes)
149-
15090
def _get_extra_id_token_claims(self, user_id, client_id):
15191
if "extra_id_token_claims" in self.config["provider"]:
15292
config = self.config["provider"]["extra_id_token_claims"].get(client_id, [])
@@ -223,9 +163,6 @@ def register_endpoints(self, backend_names):
223163
else:
224164
backend_name = backend_names[0]
225165

226-
endpoint_baseurl = "{}/{}".format(self.base_url, self.name)
227-
self._create_provider(endpoint_baseurl)
228-
229166
provider_config = ("^.well-known/openid-configuration$", self.provider_config)
230167
jwks_uri = ("^{}/jwks$".format(self.name), self.jwks)
231168

@@ -236,42 +173,36 @@ def register_endpoints(self, backend_names):
236173
auth_path = urlparse(auth_endpoint).path.lstrip("/")
237174
else:
238175
auth_path = "{}/{}".format(self.name, AuthorizationEndpoint.url)
176+
239177
authentication = ("^{}$".format(auth_path), self.handle_authn_request)
240178
url_map = [provider_config, jwks_uri, authentication]
241179

242180
if any("code" in v for v in self.provider.configuration_information["response_types_supported"]):
243-
self.provider.configuration_information["token_endpoint"] = "{}/{}".format(endpoint_baseurl,
244-
TokenEndpoint.url)
245-
token_endpoint = ("^{}/{}".format(self.name, TokenEndpoint.url), self.token_endpoint)
181+
self.provider.configuration_information["token_endpoint"] = "{}/{}".format(
182+
self.endpoint_baseurl, TokenEndpoint.url
183+
)
184+
token_endpoint = (
185+
"^{}/{}".format(self.name, TokenEndpoint.url), self.token_endpoint
186+
)
246187
url_map.append(token_endpoint)
247188

248-
self.provider.configuration_information["userinfo_endpoint"] = "{}/{}".format(endpoint_baseurl,
249-
UserinfoEndpoint.url)
250-
userinfo_endpoint = ("^{}/{}".format(self.name, UserinfoEndpoint.url), self.userinfo_endpoint)
189+
self.provider.configuration_information["userinfo_endpoint"] = (
190+
"{}/{}".format(self.endpoint_baseurl, UserinfoEndpoint.url)
191+
)
192+
userinfo_endpoint = (
193+
"^{}/{}".format(self.name, UserinfoEndpoint.url), self.userinfo_endpoint
194+
)
251195
url_map.append(userinfo_endpoint)
196+
252197
if "registration_endpoint" in self.provider.configuration_information:
253-
client_registration = ("^{}/{}".format(self.name, RegistrationEndpoint.url), self.client_registration)
198+
client_registration = (
199+
"^{}/{}".format(self.name, RegistrationEndpoint.url),
200+
self.client_registration,
201+
)
254202
url_map.append(client_registration)
255203

256204
return url_map
257205

258-
def _validate_config(self, config):
259-
"""
260-
Validates that all necessary config parameters are specified.
261-
:type config: dict[str, dict[str, Any] | str]
262-
:param config: the module config
263-
"""
264-
if config is None:
265-
raise ValueError("OIDCFrontend conf can't be 'None'.")
266-
267-
for k in {"signing_key_path", "provider"}:
268-
if k not in config:
269-
raise ValueError("Missing configuration parameter '{}' for OpenID Connect frontend.".format(k))
270-
271-
if "signing_key_id" in config and type(config["signing_key_id"]) is not str:
272-
raise ValueError(
273-
"The configuration parameter 'signing_key_id' is not defined as a string for OpenID Connect frontend.")
274-
275206
def _get_authn_request_from_state(self, state):
276207
"""
277208
Extract the clietns request stoed in the SATOSA state.
@@ -438,6 +369,128 @@ def userinfo_endpoint(self, context):
438369
return response
439370

440371

372+
def _validate_config(config):
373+
"""
374+
Validates that all necessary config parameters are specified.
375+
:type config: dict[str, dict[str, Any] | str]
376+
:param config: the module config
377+
"""
378+
if config is None:
379+
raise ValueError("OIDCFrontend configuration can't be 'None'.")
380+
381+
for k in {"signing_key_path", "provider"}:
382+
if k not in config:
383+
raise ValueError("Missing configuration parameter '{}' for OpenID Connect frontend.".format(k))
384+
385+
if "signing_key_id" in config and type(config["signing_key_id"]) is not str:
386+
raise ValueError(
387+
"The configuration parameter 'signing_key_id' is not defined as a string for OpenID Connect frontend.")
388+
389+
390+
def _create_provider(
391+
provider_config,
392+
endpoint_baseurl,
393+
internal_attributes,
394+
signing_key,
395+
authz_state,
396+
user_db,
397+
cdb,
398+
):
399+
response_types_supported = provider_config.get("response_types_supported", ["id_token"])
400+
subject_types_supported = provider_config.get("subject_types_supported", ["pairwise"])
401+
scopes_supported = provider_config.get("scopes_supported", ["openid"])
402+
extra_scopes = provider_config.get("extra_scopes")
403+
capabilities = {
404+
"issuer": provider_config["issuer"],
405+
"authorization_endpoint": "{}/{}".format(endpoint_baseurl, AuthorizationEndpoint.url),
406+
"jwks_uri": "{}/jwks".format(endpoint_baseurl),
407+
"response_types_supported": response_types_supported,
408+
"id_token_signing_alg_values_supported": [signing_key.alg],
409+
"response_modes_supported": ["fragment", "query"],
410+
"subject_types_supported": subject_types_supported,
411+
"claim_types_supported": ["normal"],
412+
"claims_parameter_supported": True,
413+
"claims_supported": [
414+
attribute_map["openid"][0]
415+
for attribute_map in internal_attributes["attributes"].values()
416+
if "openid" in attribute_map
417+
],
418+
"request_parameter_supported": False,
419+
"request_uri_parameter_supported": False,
420+
"scopes_supported": scopes_supported
421+
}
422+
423+
if 'code' in response_types_supported:
424+
capabilities["token_endpoint"] = "{}/{}".format(
425+
endpoint_baseurl, TokenEndpoint.url
426+
)
427+
428+
if provider_config.get("client_registration_supported", False):
429+
capabilities["registration_endpoint"] = "{}/{}".format(
430+
endpoint_baseurl, RegistrationEndpoint.url
431+
)
432+
433+
provider = Provider(
434+
signing_key,
435+
capabilities,
436+
authz_state,
437+
cdb,
438+
Userinfo(user_db),
439+
extra_scopes=extra_scopes,
440+
id_token_lifetime=provider_config.get("id_token_lifetime", 3600),
441+
)
442+
return provider
443+
444+
445+
def _init_authorization_state(provider_config, db_uri, sub_hash_salt):
446+
if db_uri:
447+
authz_code_db = StorageBase.from_uri(
448+
db_uri,
449+
db_name="satosa",
450+
collection="authz_codes",
451+
ttl=provider_config.get("authorization_code_lifetime", 600),
452+
)
453+
access_token_db = StorageBase.from_uri(
454+
db_uri,
455+
db_name="satosa",
456+
collection="access_tokens",
457+
ttl=provider_config.get("access_token_lifetime", 3600),
458+
)
459+
refresh_token_db = StorageBase.from_uri(
460+
db_uri,
461+
db_name="satosa",
462+
collection="refresh_tokens",
463+
ttl=provider_config.get("refresh_token_lifetime", None),
464+
)
465+
sub_db = StorageBase.from_uri(
466+
db_uri, db_name="satosa", collection="subject_identifiers", ttl=None
467+
)
468+
else:
469+
authz_code_db = None
470+
access_token_db = None
471+
refresh_token_db = None
472+
sub_db = None
473+
474+
token_lifetimes = {
475+
k: provider_config[k]
476+
for k in [
477+
"authorization_code_lifetime",
478+
"access_token_lifetime",
479+
"refresh_token_lifetime",
480+
"refresh_token_threshold",
481+
]
482+
if k in provider_config
483+
}
484+
return AuthorizationState(
485+
HashBasedSubjectIdentifierFactory(sub_hash_salt),
486+
authz_code_db,
487+
access_token_db,
488+
refresh_token_db,
489+
sub_db,
490+
**token_lifetimes,
491+
)
492+
493+
441494
def combine_return_input(values):
442495
return values
443496

0 commit comments

Comments
 (0)