Skip to content
This repository was archived by the owner on Jun 12, 2021. It is now read-only.

Commit 363c921

Browse files
committed
Add claims per client in access token
1 parent 78baa6b commit 363c921

File tree

3 files changed

+35
-5
lines changed

3 files changed

+35
-5
lines changed

src/oidcendpoint/jwt_token.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,14 @@ def __init__(
3232

3333
self.key_jar = keyjar or ec.keyjar
3434
self.issuer = issuer or ec.issuer
35+
self.cdb = ec.cdb
3536

3637
self.def_aud = aud or []
3738
self.alg = alg
3839
self.scope_claims_map = kwargs.get("scope_claims_map", ec.scope2claims)
40+
self.enable_claims_per_client = kwargs.get(
41+
"enable_claims_per_client", False
42+
)
3943

4044
def add_claims(self, payload, uinfo, claims):
4145
for attr in claims:
@@ -47,7 +51,8 @@ def add_claims(self, payload, uinfo, claims):
4751
pass
4852

4953
def __call__(
50-
self, sid: str, uinfo: Dict, sinfo: Dict, *args, aud: Optional[Any], **kwargs
54+
self, sid: str, uinfo: Dict, sinfo: Dict, *args, aud: Optional[Any],
55+
client_id: Optional[str], **kwargs
5156
):
5257
"""
5358
Return a token.
@@ -70,6 +75,12 @@ def __call__(
7075
sinfo["authn_req"]["scope"], map=self.scope_claims_map
7176
).keys(),
7277
)
78+
# Add claims if is access token
79+
if self.type == 'T' and self.enable_claims_per_client:
80+
client = self.cdb.get(client_id, {})
81+
client_claims = client.get("access_token_claims")
82+
if client_claims:
83+
self.add_claims(payload, uinfo, client_claims)
7384

7485
payload.update(kwargs)
7586
signer = JWT(

src/oidcendpoint/session.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -302,14 +302,15 @@ def replace_refresh_token(self, sid, sinfo):
302302

303303
def _make_at(self, sid, session_info, aud=None, client_id_aud=True):
304304
uid = self.sso_db.get_uid_by_sid(sid)
305-
306-
uinfo = self.userinfo(uid, session_info["client_id"]) or {}
305+
client_id = session_info["client_id"]
306+
uinfo = self.userinfo(uid, client_id) or {}
307307
at_aud = aud or []
308308

309309
if client_id_aud:
310-
at_aud.append(session_info["client_id"])
310+
at_aud.append(client_id)
311311
return self.handler["access_token"](
312-
sid=sid, sinfo=session_info, uinfo=uinfo, aud=at_aud
312+
sid=sid, sinfo=session_info, uinfo=uinfo, aud=at_aud,
313+
client_id=client_id
313314
)
314315

315316
def upgrade_to_token(

tests/test_27_jwt_token.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,24 @@ def test_info(self):
189189
assert _info["type"] == "T"
190190
assert _info["sid"] == session_id
191191

192+
@pytest.mark.parametrize("enable_claims_per_client", [True, False])
193+
def test_client_claims(self, enable_claims_per_client):
194+
ec = self.endpoint.endpoint_context
195+
handler = ec.sdb.handler.handler["access_token"]
196+
session_id = setup_session(
197+
ec, AUTH_REQ, uid="diana"
198+
)
199+
ec.cdb["client_1"]['access_token_claims'] = {
200+
"address": None
201+
}
202+
handler.enable_claims_per_client = enable_claims_per_client
203+
_dic = ec.sdb.upgrade_to_token(key=session_id)
204+
205+
token = _dic["access_token"]
206+
_jwt = JWT(key_jar=KEYJAR, iss="client_1")
207+
res = _jwt.unpack(token)
208+
assert enable_claims_per_client is ("address" in res)
209+
192210
def test_is_expired(self):
193211
session_id = setup_session(
194212
self.endpoint.endpoint_context, AUTH_REQ, uid="diana"

0 commit comments

Comments
 (0)