Skip to content

Commit f20b465

Browse files
authored
Merge pull request #92 from IdentityPython/entity_id2base_url
In case base_url is not defined use entity_id.
2 parents 195e0c2 + 023f058 commit f20b465

File tree

5 files changed

+36
-17
lines changed

5 files changed

+36
-17
lines changed

src/idpyoidc/claims.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def _keyjar(self, keyjar=None, conf=None, entity_id=""):
122122

123123
return keyjar, _uri_path
124124

125-
def get_base_url(self, configuration: dict):
125+
def get_base_url(self, configuration: dict, entity_id: Optional[str]=""):
126126
raise NotImplementedError()
127127

128128
def get_id(self, configuration: dict):
@@ -134,7 +134,10 @@ def add_extra_keys(self, keyjar, id):
134134
def get_jwks(self, keyjar):
135135
return keyjar.export_jwks()
136136

137-
def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None):
137+
def handle_keys(self,
138+
configuration: dict,
139+
keyjar: Optional[KeyJar] = None,
140+
entity_id: Optional[str] = ""):
138141
_jwks = _jwks_uri = None
139142
_id = self.get_id(configuration)
140143
keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id)
@@ -147,15 +150,19 @@ def handle_keys(self, configuration: dict, keyjar: Optional[KeyJar] = None):
147150
if "jwks_uri" in configuration: # simple
148151
_jwks_uri = configuration.get("jwks_uri")
149152
elif uri_path:
150-
_base_url = self.get_base_url(configuration)
153+
_base_url = self.get_base_url(configuration, entity_id=entity_id)
151154
_jwks_uri = add_path(_base_url, uri_path)
152155
else: # jwks or nothing
153156
_jwks = self.get_jwks(keyjar)
154157

155158
return {"keyjar": keyjar, "jwks": _jwks, "jwks_uri": _jwks_uri}
156159

157160
def load_conf(
158-
self, configuration: dict, supports: dict, keyjar: Optional[KeyJar] = None
161+
self,
162+
configuration: dict,
163+
supports: dict,
164+
keyjar: Optional[KeyJar] = None,
165+
entity_id: Optional[str] = ""
159166
) -> KeyJar:
160167
for attr, val in configuration.items():
161168
if attr in ["preference", "capabilities"]:
@@ -167,7 +174,7 @@ def load_conf(
167174

168175
self.locals(configuration)
169176

170-
for key, val in self.handle_keys(configuration, keyjar=keyjar).items():
177+
for key, val in self.handle_keys(configuration, keyjar=keyjar, entity_id=entity_id).items():
171178
if key == "keyjar":
172179
keyjar = val
173180
elif val:

src/idpyoidc/client/claims/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Optional
2+
13
from cryptojwt import KeyJar
24
from cryptojwt.exception import IssuerNotFound
35
from cryptojwt.jwk.hmac import SYMKey
@@ -11,10 +13,13 @@ def get_client_authn_methods():
1113

1214

1315
class Claims(claims.Claims):
14-
def get_base_url(self, configuration: dict):
16+
def get_base_url(self, configuration: dict, entity_id: Optional[str] = ""):
1517
_base = configuration.get("base_url")
1618
if not _base:
17-
_base = configuration.get("client_id")
19+
if entity_id:
20+
_base = entity_id
21+
else:
22+
_base = configuration.get("client_id")
1823

1924
return _base
2025

src/idpyoidc/client/service_context.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ def __init__(
173173
for key, val in kwargs.items():
174174
setattr(self, key, val)
175175

176-
self.keyjar = self.claims.load_conf(config.conf, supports=self.supports(), keyjar=keyjar)
176+
self.keyjar = self.claims.load_conf(config.conf, supports=self.supports(), keyjar=keyjar,
177+
entity_id=self.entity_id)
177178

178179
_jwks_uri = self.provider_info.get("jwks_uri")
179180
if _jwks_uri:

src/idpyoidc/metadata.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import logging
21
from functools import cmp_to_key
2+
import logging
33
from typing import Callable
44
from typing import Optional
55

@@ -128,7 +128,7 @@ def _keyjar(self, keyjar=None, conf=None, entity_id=""):
128128
_uri_path = conf["key_conf"].get("uri_path")
129129
return keyjar, _uri_path
130130

131-
def get_base_url(self, configuration: dict):
131+
def get_base_url(self, configuration: dict, entity_id: Optional[str] = ""):
132132
raise NotImplementedError()
133133

134134
def get_id(self, configuration: dict):
@@ -140,9 +140,11 @@ def add_extra_keys(self, keyjar, id):
140140
def get_jwks(self, keyjar):
141141
return None
142142

143-
def handle_keys(
144-
self, configuration: dict, keyjar: Optional[KeyJar] = None, base_url: Optional[str] = ""
145-
):
143+
def handle_keys(self,
144+
configuration: dict,
145+
keyjar: Optional[KeyJar] = None,
146+
base_url: Optional[str] = "",
147+
entity_id: Optional[str] = ""):
146148
_jwks = _jwks_uri = None
147149
_id = self.get_id(configuration)
148150
keyjar, uri_path = self._keyjar(keyjar, configuration, entity_id=_id)
@@ -154,15 +156,16 @@ def handle_keys(
154156
_jwks_uri = configuration.get("jwks_uri")
155157
elif uri_path:
156158
if not base_url:
157-
base_url = self.get_base_url(configuration)
159+
base_url = self.get_base_url(configuration, entity_id=entity_id)
158160
_jwks_uri = add_path(base_url, uri_path)
159161
else: # jwks or nothing
160162
_jwks = self.get_jwks(keyjar)
161163

162164
return {"keyjar": keyjar, "jwks": _jwks, "jwks_uri": _jwks_uri}
163165

164166
def load_conf(
165-
self, configuration, supports, keyjar: Optional[KeyJar] = None, base_url: Optional[str] = ""
167+
self, configuration, supports, keyjar: Optional[KeyJar] = None,
168+
base_url: Optional[str] = ""
166169
):
167170
for attr, val in configuration.items():
168171
if attr == "preference":

src/idpyoidc/server/claims/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44

55

66
class Claims(claims.Claims):
7-
def get_base_url(self, configuration: dict):
7+
def get_base_url(self, configuration: dict, entity_id: Optional[str] = ""):
88
_base = configuration.get("base_url")
99
if not _base:
10-
_base = configuration.get("issuer")
10+
if entity_id:
11+
_base = entity_id
12+
else:
13+
_base = configuration.get("issuer")
1114

1215
return _base
1316

0 commit comments

Comments
 (0)