Skip to content
This repository was archived by the owner on Jun 12, 2021. It is now read-only.

Commit 133be36

Browse files
authored
Merge pull request #48 from nsklikas/feature-claims-per-client
Add claims per client in id token
2 parents 8edb984 + 60c3287 commit 133be36

File tree

2 files changed

+134
-4
lines changed

2 files changed

+134
-4
lines changed

src/oidcendpoint/id_token.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,9 @@ class IDToken(object):
112112
def __init__(self, endpoint_context, **kwargs):
113113
self.endpoint_context = endpoint_context
114114
self.kwargs = kwargs
115+
self.enable_claims_per_client = kwargs.get(
116+
'enable_claims_per_client', False
117+
)
115118
self.scope_to_claims = None
116119
self.provider_info = construct_endpoint_info(
117120
self.default_capabilities, **kwargs
@@ -242,7 +245,6 @@ def sign_encrypt(
242245

243246
def make(self, req, sess_info, authn_req=None, user_claims=False, **kwargs):
244247
_context = self.endpoint_context
245-
_sdb = _context.sdb
246248

247249
if authn_req:
248250
_client_id = authn_req["client_id"]
@@ -251,11 +253,13 @@ def make(self, req, sess_info, authn_req=None, user_claims=False, **kwargs):
251253

252254
_cinfo = _context.cdb[_client_id]
253255

254-
default_idtoken_claims = dict(self.kwargs.get("default_claims", {}))
256+
idtoken_claims = dict(self.kwargs.get("default_claims", {}))
257+
if self.enable_claims_per_client:
258+
idtoken_claims.update(_cinfo.get("id_token_claims", {}))
255259
lifetime = self.kwargs.get("lifetime")
256260

257261
userinfo = userinfo_in_id_token_claims(
258-
_context, sess_info, default_idtoken_claims
262+
_context, sess_info, idtoken_claims
259263
)
260264

261265
if user_claims:

tests/test_03_id_token.py

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ def full_path(local_file):
3030
return os.path.join(BASEDIR, local_file)
3131

3232

33-
USERINFO = UserInfo(json.loads(open(full_path("users.json")).read()))
33+
USERS = json.loads(open(full_path("users.json")).read())
34+
USERINFO = UserInfo(USERS)
3435

3536
AREQN = AuthorizationRequest(
3637
response_type="code",
@@ -70,6 +71,10 @@ def full_path(local_file):
7071
"kwargs": {"user": "diana"},
7172
}
7273
},
74+
"userinfo": {
75+
"class": "oidcendpoint.user_info.UserInfo",
76+
"kwargs": {"db": USERS},
77+
},
7378
"client_authn": verify_client,
7479
"template_dir": "template",
7580
"id_token": {"class": IDToken, "kwargs": {"foo": "bar"}},
@@ -252,3 +257,124 @@ def test_get_sign_algorithm_4(self):
252257
)
253258
# default signing alg
254259
assert algs == {"sign": True, "encrypt": False, "sign_alg": "RS512"}
260+
261+
def test_default_claims(self):
262+
session_info = {
263+
"authn_req": AREQN,
264+
"sub": "sub",
265+
"authn_event": {
266+
"authn_info": "loa2",
267+
"authn_time": time.time(),
268+
"uid": "diana"
269+
},
270+
}
271+
self.endpoint_context.idtoken.kwargs['default_claims'] = {
272+
"nickname": {"essential": True}
273+
}
274+
req = {"client_id": "client_1"}
275+
_token = self.endpoint_context.idtoken.make(req, session_info)
276+
assert _token
277+
client_keyjar = KeyJar()
278+
_jwks = self.endpoint_context.keyjar.export_jwks()
279+
client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer)
280+
_jwt = JWT(key_jar=client_keyjar, iss="client_1")
281+
res = _jwt.unpack(_token)
282+
assert "nickname" in res
283+
284+
def test_no_default_claims(self):
285+
session_info = {
286+
"authn_req": AREQN,
287+
"sub": "sub",
288+
"authn_event": {
289+
"authn_info": "loa2",
290+
"authn_time": time.time(),
291+
"uid": "diana"
292+
},
293+
}
294+
req = {"client_id": "client_1"}
295+
_token = self.endpoint_context.idtoken.make(req, session_info)
296+
assert _token
297+
client_keyjar = KeyJar()
298+
_jwks = self.endpoint_context.keyjar.export_jwks()
299+
client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer)
300+
_jwt = JWT(key_jar=client_keyjar, iss="client_1")
301+
res = _jwt.unpack(_token)
302+
assert "nickname" not in res
303+
304+
def test_client_claims(self):
305+
session_info = {
306+
"authn_req": AREQN,
307+
"sub": "sub",
308+
"authn_event": {
309+
"authn_info": "loa2",
310+
"authn_time": time.time(),
311+
"uid": "diana"
312+
},
313+
}
314+
self.endpoint_context.idtoken.enable_claims_per_client = True
315+
self.endpoint_context.cdb["client_1"]['id_token_claims'] = {
316+
"address": None
317+
}
318+
req = {"client_id": "client_1"}
319+
_token = self.endpoint_context.idtoken.make(req, session_info)
320+
assert _token
321+
client_keyjar = KeyJar()
322+
_jwks = self.endpoint_context.keyjar.export_jwks()
323+
client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer)
324+
_jwt = JWT(key_jar=client_keyjar, iss="client_1")
325+
res = _jwt.unpack(_token)
326+
assert "address" in res
327+
assert "nickname" not in res
328+
329+
def test_client_claims_with_default(self):
330+
session_info = {
331+
"authn_req": AREQN,
332+
"sub": "sub",
333+
"authn_event": {
334+
"authn_info": "loa2",
335+
"authn_time": time.time(),
336+
"uid": "diana"
337+
},
338+
}
339+
self.endpoint_context.cdb["client_1"]['id_token_claims'] = {
340+
"address": None
341+
}
342+
self.endpoint_context.idtoken.kwargs['default_claims'] = {
343+
"nickname": {"essential": True}
344+
}
345+
self.endpoint_context.idtoken.enable_claims_per_client = True
346+
req = {"client_id": "client_1"}
347+
_token = self.endpoint_context.idtoken.make(req, session_info)
348+
assert _token
349+
client_keyjar = KeyJar()
350+
_jwks = self.endpoint_context.keyjar.export_jwks()
351+
client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer)
352+
_jwt = JWT(key_jar=client_keyjar, iss="client_1")
353+
res = _jwt.unpack(_token)
354+
assert "address" in res
355+
assert "nickname" in res
356+
357+
def test_client_claims_disabled(self):
358+
# enable_claims_per_client defaults to False
359+
session_info = {
360+
"authn_req": AREQN,
361+
"sub": "sub",
362+
"authn_event": {
363+
"authn_info": "loa2",
364+
"authn_time": time.time(),
365+
"uid": "diana"
366+
},
367+
}
368+
self.endpoint_context.cdb["client_1"]['id_token_claims'] = {
369+
"address": None
370+
}
371+
req = {"client_id": "client_1"}
372+
_token = self.endpoint_context.idtoken.make(req, session_info)
373+
assert _token
374+
client_keyjar = KeyJar()
375+
_jwks = self.endpoint_context.keyjar.export_jwks()
376+
client_keyjar.import_jwks(_jwks, self.endpoint_context.issuer)
377+
_jwt = JWT(key_jar=client_keyjar, iss="client_1")
378+
res = _jwt.unpack(_token)
379+
assert "address" not in res
380+
assert "nickname" not in res

0 commit comments

Comments
 (0)