Skip to content

Commit 2aba2aa

Browse files
authored
Merge pull request #51 from ctriant/userinfo-policy
Introduce userinfo policy
2 parents 4eb8214 + 981dc40 commit 2aba2aa

File tree

3 files changed

+119
-8
lines changed

3 files changed

+119
-8
lines changed

doc/server/contents/conf.rst

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,11 @@ An example::
408408
"normal",
409409
"aggregated",
410410
"distributed"
411-
]
411+
],
412+
"policy": {
413+
"function": "/path/to/callable",
414+
"kwargs": {}
415+
}
412416
}
413417
},
414418
"revocation": {
@@ -747,6 +751,10 @@ the following::
747751
"userinfo": {
748752
"class": "oidc_provider.users.UserInfo",
749753
"kwargs": {
754+
"policy": {
755+
"function": "/path/to/callable",
756+
"kwargs": {}
757+
},
750758
"claims_map": {
751759
"phone_number": "telephone",
752760
"family_name": "last_name",
@@ -760,6 +768,17 @@ the following::
760768
}
761769
}
762770

771+
The policy for userinfo endpoint is optional and can also be configured in a client's metadata, for example::
772+
773+
"userinfo": {
774+
"kwargs": {
775+
"policy": {
776+
"function": "/path/to/callable",
777+
"kwargs": {}
778+
}
779+
}
780+
}
781+
763782
================================
764783
Special Configuration directives
765784
================================

src/idpyoidc/server/oidc/userinfo.py

Lines changed: 48 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
from cryptojwt.jwt import utc_time_sans_frac
1111

1212
from idpyoidc import claims
13+
from idpyoidc.util import importer
1314
from idpyoidc.message import Message
1415
from idpyoidc.message import oidc
1516
from idpyoidc.message.oauth2 import ResponseMessage
1617
from idpyoidc.server.endpoint import Endpoint
1718
from idpyoidc.server.exception import ClientAuthenticationError
19+
from idpyoidc.exception import ImproperlyConfigured
1820
from idpyoidc.server.util import OAUTH2_NOCACHE_HEADERS
1921

2022
logger = logging.getLogger(__name__)
@@ -46,18 +48,28 @@ def __init__(self, upstream_get: Callable, add_claims_by_scope: Optional[bool] =
4648
# Add the issuer ID as an allowed JWT target
4749
self.allowed_targets.append("")
4850

49-
def get_client_id_from_token(self, context, token, request=None):
50-
_info = context.session_manager.get_session_info_by_token(
51+
if kwargs is None:
52+
self.config = {
53+
"policy": {
54+
"function": "/path/to/callable",
55+
"kwargs": {}
56+
},
57+
}
58+
else:
59+
self.config = kwargs
60+
61+
def get_client_id_from_token(self, endpoint_context, token, request=None):
62+
_info = endpoint_context.session_manager.get_session_info_by_token(
5163
token, handler_key="access_token"
5264
)
5365
return _info["client_id"]
5466

5567
def do_response(
56-
self,
57-
response_args: Optional[Union[Message, dict]] = None,
58-
request: Optional[Union[Message, dict]] = None,
59-
client_id: Optional[str] = "",
60-
**kwargs
68+
self,
69+
response_args: Optional[Union[Message, dict]] = None,
70+
request: Optional[Union[Message, dict]] = None,
71+
client_id: Optional[str] = "",
72+
**kwargs
6173
) -> dict:
6274

6375
if "error" in kwargs and kwargs["error"]:
@@ -157,6 +169,12 @@ def process_request(self, request=None, **kwargs):
157169
info["sub"] = _grant.sub
158170
if _grant.add_acr_value("userinfo"):
159171
info["acr"] = _grant.authentication_event["authn_info"]
172+
173+
if "userinfo" in _cntxt.cdb[request["client_id"]]:
174+
self.config["policy"] = _cntxt.cdb[request["client_id"]]["userinfo"]["policy"]
175+
176+
if "policy" in self.config:
177+
info = self._enforce_policy(request, info, token, self.config)
160178
else:
161179
info = {
162180
"error": "invalid_request",
@@ -190,3 +208,26 @@ def parse_request(self, request, http_info=None, **kwargs):
190208
request["access_token"] = auth_info["token"]
191209

192210
return request
211+
212+
def _enforce_policy(self, request, response_info, token, config):
213+
policy = config["policy"]
214+
callable = policy["function"]
215+
kwargs = policy.get("kwargs", {})
216+
217+
if isinstance(callable, str):
218+
try:
219+
fn = importer(callable)
220+
except Exception:
221+
raise ImproperlyConfigured(f"Error importing {callable} policy callable")
222+
else:
223+
fn = callable
224+
225+
try:
226+
return fn(request, token, response_info, **kwargs)
227+
except Exception as e:
228+
logger.error(f"Error while executing the {fn} policy callable: {e}")
229+
return self.error_cls(error="server_error", error_description="Internal server error")
230+
231+
232+
def validate_userinfo_policy(request, token, response_info, **kwargs):
233+
return response_info

tests/test_server_26_oidc_userinfo_endpoint.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
from idpyoidc.server.scopes import SCOPE2CLAIMS
2222
from idpyoidc.server.user_authn.authn_context import INTERNETPROTOCOLPASSWORD
2323
from idpyoidc.server.user_info import UserInfo
24+
from idpyoidc.server.oidc.userinfo import validate_userinfo_policy
2425
from idpyoidc.time_util import utc_time_sans_frac
2526
from tests import CRYPT_CONFIG
2627
from tests import SESSION_PARAMS
2728

29+
2830
KEYDEFS = [
2931
{"type": "RSA", "key": "", "use": ["sig"]},
3032
{"type": "EC", "crv": "P-256", "use": ["sig"]},
@@ -637,3 +639,52 @@ def test_process_request_absent_userinfo_conf(self):
637639

638640
with pytest.raises(ImproperlyConfigured):
639641
code = self._mint_code(grant, session_id)
642+
643+
def test_userinfo_policy(self):
644+
_auth_req = AUTH_REQ.copy()
645+
646+
session_id = self._create_session(_auth_req)
647+
grant = self.session_manager[session_id]
648+
access_token = self._mint_token("access_token", grant, session_id)
649+
650+
http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}}
651+
652+
def _custom_validate_userinfo_policy(request, token, response_info, **kwargs):
653+
return {"custom": "policy"}
654+
655+
self.endpoint.config["policy"] = {}
656+
self.endpoint.config["policy"]["function"] = _custom_validate_userinfo_policy
657+
658+
_req = self.endpoint.parse_request({}, http_info=http_info)
659+
args = self.endpoint.process_request(_req)
660+
assert args
661+
res = self.endpoint.do_response(request=_req, **args)
662+
_response = json.loads(res["response"])
663+
assert "custom" in _response
664+
665+
def test_userinfo_policy_per_client(self):
666+
_auth_req = AUTH_REQ.copy()
667+
668+
session_id = self._create_session(_auth_req)
669+
grant = self.session_manager[session_id]
670+
access_token = self._mint_token("access_token", grant, session_id)
671+
672+
http_info = {"headers": {"authorization": "Bearer {}".format(access_token.value)}}
673+
674+
def _custom_validate_userinfo_policy(request, token, response_info, **kwargs):
675+
return {"custom": "policy"}
676+
677+
self.context.cdb["client_1"]["userinfo"] = {
678+
"policy": {
679+
"function": _custom_validate_userinfo_policy,
680+
"kwargs": {}
681+
}
682+
}
683+
684+
_req = self.endpoint.parse_request({}, http_info=http_info)
685+
args = self.endpoint.process_request(_req)
686+
assert args
687+
res = self.endpoint.do_response(request=_req, **args)
688+
_response = json.loads(res["response"])
689+
assert "custom" in _response
690+

0 commit comments

Comments
 (0)