Skip to content

Commit e7ad982

Browse files
committed
Fix DecideBackendByTargetIdP and introduce DecideBackendByDiscoIdP
Signed-off-by: Ivan Kanakarakis <[email protected]>
1 parent d1784a7 commit e7ad982

File tree

2 files changed

+130
-111
lines changed

2 files changed

+130
-111
lines changed

src/satosa/micro_services/custom_routing.py

Lines changed: 46 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,55 @@ class CustomRoutingError(SATOSAError):
2222

2323
class DecideBackendByTargetIdP(RequestMicroService):
2424
"""
25-
Select which backend should be used based on who is the SAML IDP
25+
Select target backend based on the target issuer.
2626
"""
2727

2828
def __init__(self, config:dict, *args, **kwargs):
2929
"""
3030
Constructor.
31+
3132
:param config: microservice configuration loaded from yaml file
3233
:type config: Dict[str, Dict[str, str]]
3334
"""
3435
super().__init__(*args, **kwargs)
36+
3537
self.target_mapping = config['target_mapping']
36-
self.endpoint_paths = config['endpoint_paths']
3738
self.default_backend = config['default_backend']
3839

39-
if not isinstance(self.endpoint_paths, list):
40-
raise SATOSAConfigurationError()
40+
def process(self, context:Context, data:InternalData):
41+
"""
42+
Set context.target_backend based on the target issuer (context.target_entity_id)
43+
44+
:param context: request context
45+
:param data: the internal request
46+
"""
47+
target_issuer = context.get_decoration(Context.KEY_TARGET_ENTITYID)
48+
if not target_issuer:
49+
return super().process(context, data)
50+
51+
target_backend = (
52+
self.target_mapping.get(target_issuer)
53+
or self.default_backend
54+
)
55+
56+
report = {
57+
'msg': 'decided target backend by target issuer',
58+
'target_issuer': target_issuer,
59+
'target_backend': target_backend,
60+
}
61+
logger.info(report)
62+
63+
context.target_backend = target_backend
64+
return super().process(context, data)
65+
66+
67+
class DecideBackendByDiscoIdP(DecideBackendByTargetIdP):
68+
def __init__(self, config:dict, *args, **kwargs):
69+
super().__init__(config, *args, **kwargs)
70+
71+
self.disco_endpoints = config['disco_endpoints']
72+
if not isinstance(self.disco_endpoints, list):
73+
raise CustomRoutingError('disco_endpoints must be a list of str')
4174

4275
def register_endpoints(self):
4376
"""
@@ -54,69 +87,20 @@ def register_endpoints(self):
5487
[(regexp, Callable[[satosa.context.Context], satosa.response.Response]), ...]
5588
"""
5689

57-
# this intercepts disco response
5890
return [
59-
(path , self.backend_by_entityid)
60-
for path in self.endpoint_paths
91+
(path , self._handle_disco_response)
92+
for path in self.disco_endpoints
6193
]
6294

63-
def _get_request_entity_id(self, context):
64-
return (
65-
context.get_decoration(Context.KEY_TARGET_ENTITYID) or
66-
context.request.get('entityID')
67-
)
68-
69-
def _get_backend(self, context:Context, entity_id:str) -> str:
70-
"""
71-
returns the Target Backend to use
72-
"""
73-
return (
74-
self.target_mapping.get(entity_id) or
75-
self.default_backend
76-
)
77-
78-
def process(self, context:Context, data:dict):
79-
"""
80-
Will modify the context.target_backend attribute based on the target entityid.
81-
:param context: request context
82-
:param data: the internal request
83-
"""
84-
entity_id = self._get_request_entity_id(context)
85-
if entity_id:
86-
self._rewrite_context(entity_id, context)
87-
return super().process(context, data)
88-
89-
def _rewrite_context(self, entity_id:str, context:Context) -> None:
90-
tr_backend = self._get_backend(context, entity_id)
91-
context.decorate(Context.KEY_TARGET_ENTITYID, entity_id)
92-
context.target_frontend = context.target_frontend or context.state.get('ROUTER')
93-
native_backend = context.target_backend
94-
msg = (f'Found DecideBackendByTarget ({self.name} microservice) '
95-
f'redirecting {entity_id} from {native_backend} '
96-
f'backend to {tr_backend}')
97-
logger.info(msg)
98-
context.target_backend = tr_backend
99-
100-
def backend_by_entityid(self, context:Context):
101-
entity_id = self._get_request_entity_id(context)
102-
103-
if entity_id:
104-
self._rewrite_context(entity_id, context)
105-
else:
106-
raise CustomRoutingError(
107-
f"{self.__class__.__name__} "
108-
"can't find any valid entity_id in the context."
109-
)
110-
111-
if not context.state.get('ROUTER'):
112-
raise SATOSAStateError(
113-
f"{self.__class__.__name__} "
114-
"can't find any valid state in the context."
115-
)
95+
def _handle_disco_response(self, context:Context):
96+
target_issuer_from_disco = context.request.get('entityID')
97+
if not target_issuer_from_disco:
98+
raise CustomRoutingError('no valid entity_id in the disco response')
11699

117-
data_serialized = context.state.get(self.name, {}).get("internal", {})
100+
context.decorate(Context.KEY_TARGET_ENTITYID, target_issuer_from_disco)
101+
data_serialized = context.state.get(self.name, {}).get('internal', {})
118102
data = InternalData.from_dict(data_serialized)
119-
return super().process(context, data)
103+
return self.process(context, data)
120104

121105

122106
class DecideBackendByRequester(RequestMicroService):

tests/satosa/micro_services/test_custom_routing.py

Lines changed: 84 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
from base64 import urlsafe_b64encode
2+
from unittest import TestCase
23

34
import pytest
45

56
from satosa.context import Context
7+
from satosa.state import State
68
from satosa.exception import SATOSAError, SATOSAConfigurationError, SATOSAStateError
79
from satosa.internal import InternalData
810
from satosa.micro_services.custom_routing import DecideIfRequesterIsAllowed
11+
from satosa.micro_services.custom_routing import DecideBackendByDiscoIdP
912
from satosa.micro_services.custom_routing import DecideBackendByTargetIdP
1013
from satosa.micro_services.custom_routing import CustomRoutingError
1114

15+
1216
TARGET_ENTITY = "entity1"
1317

1418

@@ -160,61 +164,92 @@ def test_missing_target_entity_id_from_context(self, context):
160164
decide_service.process(context, req)
161165

162166

163-
class TestDecideBackendByTargetIdP:
164-
rules = {
165-
'default_backend': 'Saml2',
166-
'endpoint_paths': ['.*/disco'],
167-
'target_mapping': {'http://idpspid.testunical.it:8088': 'spidSaml2'}
168-
}
169-
170-
def create_decide_service(self, rules):
171-
decide_service = DecideBackendByTargetIdP(
172-
config=rules,
173-
name="test_decide_service",
174-
base_url="https://satosa.example.com"
175-
)
176-
decide_service.next = lambda ctx, data: data
177-
return decide_service
167+
class TestDecideBackendByTargetIdP(TestCase):
168+
def setUp(self):
169+
context = Context()
170+
context.state = State()
178171

179-
180-
def test_missing_state(self, target_context):
181-
decide_service = self.create_decide_service(self.rules)
182-
target_context.request = {
183-
'entityID': 'http://idpspid.testunical.it:8088',
172+
config = {
173+
'default_backend': 'default_backend',
174+
'target_mapping': {
175+
'mapped_idp.example.org': 'mapped_backend',
176+
},
177+
'disco_endpoints': [
178+
'.*/disco',
179+
],
184180
}
185-
req = InternalData(requester="test_requester")
186-
req.requester = "somebody else"
187-
assert decide_service.process(target_context, req)
188-
189-
with pytest.raises(SATOSAStateError):
190-
decide_service.backend_by_entityid(target_context)
191181

192-
193-
def test_unmatching_target(self, target_context):
194-
"""
195-
It would rely on the default backend
196-
"""
197-
decide_service = self.create_decide_service(self.rules)
198-
target_context.request = {
199-
'entityID': 'unknow-entity-id',
182+
plugin = DecideBackendByTargetIdP(
183+
config=config,
184+
name='test_decide_service',
185+
base_url='https://satosa.example.org',
186+
)
187+
plugin.next = lambda ctx, data: (ctx, data)
188+
189+
self.config = config
190+
self.context = context
191+
self.plugin = plugin
192+
193+
def test_when_target_is_not_set_do_skip(self):
194+
data = InternalData(requester='test_requester')
195+
newctx, newdata = self.plugin.process(self.context, data)
196+
assert not newctx.target_backend
197+
198+
def test_when_target_is_not_mapped_choose_default_backend(self):
199+
self.context.decorate(Context.KEY_TARGET_ENTITYID, 'idp.example.org')
200+
data = InternalData(requester='test_requester')
201+
newctx, newdata = self.plugin.process(self.context, data)
202+
assert newctx.target_backend == 'default_backend'
203+
204+
def test_when_target_is_mapped_choose_mapping_backend(self):
205+
self.context.decorate(Context.KEY_TARGET_ENTITYID, 'mapped_idp.example.org')
206+
data = InternalData(requester='test_requester')
207+
data.requester = 'somebody else'
208+
newctx, newdata = self.plugin.process(self.context, data)
209+
assert newctx.target_backend == 'mapped_backend'
210+
211+
212+
class TestDecideBackendByDiscoIdP(TestCase):
213+
def setUp(self):
214+
context = Context()
215+
context.state = State()
216+
217+
config = {
218+
'default_backend': 'default_backend',
219+
'target_mapping': {
220+
'mapped_idp.example.org': 'mapped_backend',
221+
},
222+
'disco_endpoints': [
223+
'.*/disco',
224+
],
200225
}
201-
target_context.state['ROUTER'] = 'Saml2'
202226

203-
req = InternalData(requester="test_requester")
204-
assert decide_service.process(target_context, req)
227+
plugin = DecideBackendByDiscoIdP(
228+
config=config,
229+
name='test_decide_service',
230+
base_url='https://satosa.example.org',
231+
)
232+
plugin.next = lambda ctx, data: (ctx, data)
205233

206-
res = decide_service.backend_by_entityid(target_context)
207-
assert isinstance(res, InternalData)
234+
self.config = config
235+
self.context = context
236+
self.plugin = plugin
208237

209-
def test_matching_target(self, target_context):
210-
decide_service = self.create_decide_service(self.rules)
211-
target_context.request = {
212-
'entityID': 'http://idpspid.testunical.it:8088-entity-id'
238+
def test_when_target_is_not_set_raise_error(self):
239+
self.context.request = {}
240+
with pytest.raises(CustomRoutingError):
241+
self.plugin._handle_disco_response(self.context)
242+
243+
def test_when_target_is_not_mapped_choose_default_backend(self):
244+
self.context.request = {
245+
'entityID': 'idp.example.org',
213246
}
214-
target_context.state['ROUTER'] = 'Saml2'
247+
newctx, newdata = self.plugin._handle_disco_response(self.context)
248+
assert newctx.target_backend == 'default_backend'
215249

216-
req = InternalData(requester="test_requester")
217-
req.requester = "somebody else"
218-
assert decide_service.process(target_context, req)
219-
res = decide_service.backend_by_entityid(target_context)
220-
assert isinstance(res, InternalData)
250+
def test_when_target_is_mapped_choose_mapping_backend(self):
251+
self.context.request = {
252+
'entityID': 'mapped_idp.example.org',
253+
}
254+
newctx, newdata = self.plugin._handle_disco_response(self.context)
255+
assert newctx.target_backend == 'mapped_backend'

0 commit comments

Comments
 (0)