|
1 | 1 | import re |
| 2 | +import logging |
2 | 3 |
|
3 | 4 | from .base import ResponseMicroService |
| 5 | +from ..context import Context |
| 6 | +from ..exception import SATOSAError |
4 | 7 |
|
| 8 | +logger = logging.getLogger(__name__) |
5 | 9 |
|
6 | 10 | class AddStaticAttributes(ResponseMicroService): |
7 | 11 | """ |
@@ -29,28 +33,62 @@ def __init__(self, config, *args, **kwargs): |
29 | 33 | def process(self, context, data): |
30 | 34 | # apply default filters |
31 | 35 | provider_filters = self.attribute_filters.get("", {}) |
32 | | - self._apply_requester_filters(data.attributes, provider_filters, data.requester) |
| 36 | + target_provider = data.auth_info.issuer |
| 37 | + self._apply_requester_filters(data.attributes, provider_filters, data.requester, context, target_provider) |
33 | 38 |
|
34 | 39 | # apply target provider specific filters |
35 | | - target_provider = data.auth_info.issuer |
36 | 40 | provider_filters = self.attribute_filters.get(target_provider, {}) |
37 | | - self._apply_requester_filters(data.attributes, provider_filters, data.requester) |
| 41 | + self._apply_requester_filters(data.attributes, provider_filters, data.requester, context, target_provider) |
38 | 42 | return super().process(context, data) |
39 | 43 |
|
40 | | - def _apply_requester_filters(self, attributes, provider_filters, requester): |
| 44 | + def _apply_requester_filters(self, attributes, provider_filters, requester, context, target_provider): |
41 | 45 | # apply default requester filters |
42 | 46 | default_requester_filters = provider_filters.get("", {}) |
43 | | - self._apply_filter(attributes, default_requester_filters) |
| 47 | + self._apply_filters(attributes, default_requester_filters, context, target_provider) |
44 | 48 |
|
45 | 49 | # apply requester specific filters |
46 | 50 | requester_filters = provider_filters.get(requester, {}) |
47 | | - self._apply_filter(attributes, requester_filters) |
48 | | - |
49 | | - def _apply_filter(self, attributes, attribute_filters): |
50 | | - for attribute_name, attribute_filter in attribute_filters.items(): |
51 | | - regex = re.compile(attribute_filter) |
52 | | - if attribute_name == "": # default filter for all attributes |
53 | | - for attribute, values in attributes.items(): |
54 | | - attributes[attribute] = list(filter(regex.search, attributes[attribute])) |
55 | | - elif attribute_name in attributes: |
56 | | - attributes[attribute_name] = list(filter(regex.search, attributes[attribute_name])) |
| 51 | + self._apply_filters(attributes, requester_filters, context, target_provider) |
| 52 | + |
| 53 | + def _apply_filters(self, attributes, attribute_filters, context, target_provider): |
| 54 | + for attribute_name, attribute_filters in attribute_filters.items(): |
| 55 | + if type(attribute_filters) == str: |
| 56 | + # convert simple notation to filter list |
| 57 | + attribute_filters = {'regexp': attribute_filters} |
| 58 | + |
| 59 | + for filter_type, filter_value in attribute_filters.items(): |
| 60 | + |
| 61 | + if filter_type == "regexp": |
| 62 | + filter_func = re.compile(filter_value).search |
| 63 | + elif filter_type == "shibmdscope_match_scope": |
| 64 | + mdstore = context.get_decoration(Context.KEY_METADATA_STORE) |
| 65 | + md_scopes = list(mdstore.shibmd_scopes(target_provider,"idpsso_descriptor")) if mdstore else [] |
| 66 | + filter_func = lambda v: self._shibmdscope_match_scope(v, md_scopes) |
| 67 | + elif filter_type == "shibmdscope_match_value": |
| 68 | + mdstore = context.get_decoration(Context.KEY_METADATA_STORE) |
| 69 | + md_scopes = list(mdstore.shibmd_scopes(target_provider,"idpsso_descriptor")) if mdstore else [] |
| 70 | + filter_func = lambda v: self._shibmdscope_match_value(v, md_scopes) |
| 71 | + else: |
| 72 | + raise SATOSAError("Unknown filter type") |
| 73 | + |
| 74 | + if attribute_name == "": # default filter for all attributes |
| 75 | + for attribute, values in attributes.items(): |
| 76 | + attributes[attribute] = list(filter(filter_func, attributes[attribute])) |
| 77 | + elif attribute_name in attributes: |
| 78 | + attributes[attribute_name] = list(filter(filter_func, attributes[attribute_name])) |
| 79 | + |
| 80 | + def _shibmdscope_match_value(self, value, md_scopes): |
| 81 | + for md_scope in md_scopes: |
| 82 | + if not md_scope['regexp'] and md_scope['text'] == value: |
| 83 | + return True |
| 84 | + elif md_scope['regexp'] and re.fullmatch(md_scope['text'], value): |
| 85 | + return True |
| 86 | + return False |
| 87 | + |
| 88 | + def _shibmdscope_match_scope(self, value, md_scopes): |
| 89 | + split_value = value.split('@') |
| 90 | + if len(split_value) != 2: |
| 91 | + logger.info(f"Discarding invalid scoped value {value}") |
| 92 | + return False |
| 93 | + value_scope = split_value[1] |
| 94 | + return self._shibmdscope_match_value(value_scope, md_scopes) |
0 commit comments