diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index e59a934b..a5082163 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -79,7 +79,7 @@ def __parse_tool( else: # regular parameter params.append(p) - authn_params = identify_required_authn_params( + authn_params, _ = identify_required_authn_params( authn_params, auth_token_getters.keys() ) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index ba006e5d..8029d396 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -293,7 +293,7 @@ def add_auth_token_getters( new_req_authn_params = MappingProxyType( identify_required_authn_params( self.__required_authn_params, auth_token_getters.keys() - ) + )[0] ) return self.__copy( diff --git a/packages/toolbox-core/src/toolbox_core/utils.py b/packages/toolbox-core/src/toolbox_core/utils.py index 3612d00e..a30dc66b 100644 --- a/packages/toolbox-core/src/toolbox_core/utils.py +++ b/packages/toolbox-core/src/toolbox_core/utils.py @@ -46,29 +46,40 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) - def identify_required_authn_params( req_authn_params: Mapping[str, list[str]], auth_service_names: Iterable[str] -) -> dict[str, list[str]]: +) -> tuple[dict[str, list[str]], set[str]]: """ Identifies authentication parameters that are still required; because they - are not covered by the provided `auth_service_names`. + are not covered by the provided `auth_service_names`, and also returns a + set of all authentication services that were found to be matching. Args: - req_authn_params: A mapping of parameter names to sets of required + req_authn_params: A mapping of parameter names to lists of required authentication services. auth_service_names: An iterable of authentication service names for which token getters are available. Returns: - A new dictionary representing the subset of required authentication parameters - that are not covered by the provided `auth_service_names`. + A tuple containing: + - A new dictionary representing the subset of required + authentication parameters that are not covered by the provided + `auth_service_names`. + - A list of authentication service names from `auth_service_names` + that were found to satisfy at least one parameter's requirements. """ - required_params = {} # params that are still required with provided auth_services + required_params: dict[str, list[str]] = {} + used_services: set[str] = set() + for param, services in req_authn_params.items(): # if we don't have a token_getter for any of the services required by the param, # the param is still required - required = not any(s in services for s in auth_service_names) - if required: + matched_services = [s for s in services if s in auth_service_names] + + if matched_services: + used_services.update(matched_services) + else: required_params[param] = services - return required_params + + return required_params, used_services def params_to_pydantic_model( diff --git a/packages/toolbox-core/tests/test_utils.py b/packages/toolbox-core/tests/test_utils.py index b71284b6..52cdb38f 100644 --- a/packages/toolbox-core/tests/test_utils.py +++ b/packages/toolbox-core/tests/test_utils.py @@ -87,8 +87,10 @@ def test_identify_required_authn_params_none_required(): req_authn_params = {} auth_service_names = ["service_a", "service_b"] expected = {} - assert ( - identify_required_authn_params(req_authn_params, auth_service_names) == expected + expected_used = set() + assert identify_required_authn_params(req_authn_params, auth_service_names) == ( + expected, + expected_used, ) @@ -100,8 +102,10 @@ def test_identify_required_authn_params_all_covered(): } auth_service_names = ["service_a", "service_b"] expected = {} - assert ( - identify_required_authn_params(req_authn_params, auth_service_names) == expected + expected_used = set(auth_service_names) + assert identify_required_authn_params(req_authn_params, auth_service_names) == ( + expected, + expected_used, ) @@ -118,8 +122,10 @@ def test_identify_required_authn_params_some_covered(): "token_d": ["service_d"], "token_e": ["service_e", "service_f"], } - assert ( - identify_required_authn_params(req_authn_params, auth_service_names) == expected + expected_used = set(auth_service_names) + assert identify_required_authn_params(req_authn_params, auth_service_names) == ( + expected, + expected_used, ) @@ -134,8 +140,10 @@ def test_identify_required_authn_params_none_covered(): "token_d": ["service_d"], "token_e": ["service_e", "service_f"], } - assert ( - identify_required_authn_params(req_authn_params, auth_service_names) == expected + expected_used = set() + assert identify_required_authn_params(req_authn_params, auth_service_names) == ( + expected, + expected_used, ) @@ -150,8 +158,10 @@ def test_identify_required_authn_params_no_available_services(): "token_a": ["service_a"], "token_b": ["service_b", "service_c"], } - assert ( - identify_required_authn_params(req_authn_params, auth_service_names) == expected + expected_used = set() + assert identify_required_authn_params(req_authn_params, auth_service_names) == ( + expected, + expected_used, ) @@ -164,8 +174,10 @@ def test_identify_required_authn_params_empty_services_for_param(): expected = { "token_x": [], } - assert ( - identify_required_authn_params(req_authn_params, auth_service_names) == expected + expected_used = set() + assert identify_required_authn_params(req_authn_params, auth_service_names) == ( + expected, + expected_used, )