diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index f2ff014c..93ed7b33 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -18,7 +18,7 @@ from .protocol import ManifestSchema, ToolSchema from .tool import ToolboxTool -from .utils import identify_required_authn_params, resolve_value +from .utils import identify_auth_requirements, resolve_value class ToolboxClient: @@ -79,10 +79,9 @@ def __parse_tool( else: # regular parameter params.append(p) - authn_params, _, used_auth_keys = identify_required_authn_params( - # TODO: Add schema.authRequired as second arg + authn_params, authz_tokens, used_auth_keys = identify_auth_requirements( authn_params, - [], + schema.authRequired, auth_token_getters.keys(), ) @@ -94,6 +93,7 @@ def __parse_tool( # create a read-only values to prevent mutation params=tuple(params), required_authn_params=types.MappingProxyType(authn_params), + required_authz_tokens=authz_tokens, auth_service_token_getters=types.MappingProxyType(auth_token_getters), bound_params=types.MappingProxyType(bound_params), client_headers=types.MappingProxyType(client_headers), diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 0b9c5ce5..a0e5eb2c 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -22,7 +22,7 @@ from .protocol import ParameterSchema from .utils import ( create_func_docstring, - identify_required_authn_params, + identify_auth_requirements, params_to_pydantic_model, resolve_value, ) @@ -49,6 +49,7 @@ def __init__( description: str, params: Sequence[ParameterSchema], required_authn_params: Mapping[str, list[str]], + required_authz_tokens: Sequence[str], auth_service_token_getters: Mapping[str, Callable[[], str]], bound_params: Mapping[str, Union[Callable[[], Any], Any]], client_headers: Mapping[str, Union[Callable, Coroutine, str]], @@ -63,12 +64,14 @@ def __init__( name: The name of the remote tool. description: The description of the remote tool. params: The args of the tool. - required_authn_params: A map of required authenticated parameters to a list - of alternative services that can provide values for them. - auth_service_token_getters: A dict of authService -> token (or callables that - produce a token) - bound_params: A mapping of parameter names to bind to specific values or - callables that are called to produce values as needed. + required_authn_params: A map of required authenticated parameters to + a list of alternative services that can provide values for them. + required_authz_tokens: A sequence of alternative services for + providing authorization token for the tool invocation. + auth_service_token_getters: A dict of authService -> token (or + callables that produce a token) + bound_params: A mapping of parameter names to bind to specific + values or callables that are called to produce values as needed. client_headers: Client specific headers bound to the tool. """ # used to invoke the toolbox API @@ -106,6 +109,8 @@ def __init__( # map of parameter name to auth service required by it self.__required_authn_params = required_authn_params + # sequence of authorization tokens required by it + self.__required_authz_tokens = required_authz_tokens # map of authService -> token_getter self.__auth_service_token_getters = auth_service_token_getters # map of parameter name to value (or callable that produces that value) @@ -149,6 +154,7 @@ def __copy( description: Optional[str] = None, params: Optional[Sequence[ParameterSchema]] = None, required_authn_params: Optional[Mapping[str, list[str]]] = None, + required_authz_tokens: Optional[Sequence[str]] = None, auth_service_token_getters: Optional[Mapping[str, Callable[[], str]]] = None, bound_params: Optional[Mapping[str, Union[Callable[[], Any], Any]]] = None, client_headers: Optional[Mapping[str, Union[Callable, Coroutine, str]]] = None, @@ -162,12 +168,14 @@ def __copy( name: The name of the remote tool. description: The description of the remote tool. params: The args of the tool. - required_authn_params: A map of required authenticated parameters to a list - of alternative services that can provide values for them. - auth_service_token_getters: A dict of authService -> token (or callables - that produce a token) - bound_params: A mapping of parameter names to bind to specific values or - callables that are called to produce values as needed. + required_authn_params: A map of required authenticated parameters to + a list of alternative services that can provide values for them. + required_authz_tokens: A sequence of alternative services for + providing authorization token for the tool invocation. + auth_service_token_getters: A dict of authService -> token (or + callables that produce a token) + bound_params: A mapping of parameter names to bind to specific + values or callables that are called to produce values as needed. client_headers: Client specific headers bound to the tool. """ check = lambda val, default: val if val is not None else default @@ -180,6 +188,9 @@ def __copy( required_authn_params=check( required_authn_params, self.__required_authn_params ), + required_authz_tokens=check( + required_authz_tokens, self.__required_authz_tokens + ), auth_service_token_getters=check( auth_service_token_getters, self.__auth_service_token_getters ), @@ -207,11 +218,15 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str: """ # check if any auth services need to be specified yet - if len(self.__required_authn_params) > 0: + if ( + len(self.__required_authn_params) > 0 + or len(self.__required_authz_tokens) > 0 + ): # Gather all the required auth services into a set req_auth_services = set() for s in self.__required_authn_params.values(): req_auth_services.update(s) + req_auth_services.update(self.__required_authz_tokens) raise ValueError( f"One or more of the following authn services are required to invoke this tool" f": {','.join(req_auth_services)}" @@ -292,23 +307,24 @@ def add_auth_token_getters( f"Cannot register client the same headers in the client as well as tool." ) - # create a read-only updated value for new_getters - new_getters = MappingProxyType( - dict(self.__auth_service_token_getters, **auth_token_getters) - ) - # create a read-only updated for params that are still required - new_req_authn_params = MappingProxyType( - identify_required_authn_params( - # TODO: Add authRequired + new_getters = dict(self.__auth_service_token_getters, **auth_token_getters) + + # find the updated requirements + new_req_authn_params, new_req_authz_tokens, used_auth_token_getters = ( + identify_auth_requirements( self.__required_authn_params, - [], + self.__required_authz_tokens, auth_token_getters.keys(), - )[0] + ) ) + # TODO: Add validation for used_auth_token_getters + return self.__copy( - auth_service_token_getters=new_getters, - required_authn_params=new_req_authn_params, + # create a read-only map for updated getters, params and tokens that are still required + auth_service_token_getters=MappingProxyType(new_getters), + required_authn_params=MappingProxyType(new_req_authn_params), + required_authz_tokens=tuple(new_req_authz_tokens), ) def bind_params( diff --git a/packages/toolbox-core/src/toolbox_core/utils.py b/packages/toolbox-core/src/toolbox_core/utils.py index b2954d37..de3b728d 100644 --- a/packages/toolbox-core/src/toolbox_core/utils.py +++ b/packages/toolbox-core/src/toolbox_core/utils.py @@ -44,9 +44,9 @@ def create_func_docstring(description: str, params: Sequence[ParameterSchema]) - return docstring -def identify_required_authn_params( +def identify_auth_requirements( req_authn_params: Mapping[str, list[str]], - req_authz_tokens: list[str], + req_authz_tokens: Sequence[str], auth_service_names: Iterable[str], ) -> tuple[dict[str, list[str]], list[str], set[str]]: """ @@ -100,7 +100,7 @@ def identify_required_authn_params( if matched_authz_services: used_services.update(matched_authz_services) else: - required_authz_tokens = req_authz_tokens + required_authz_tokens = list(req_authz_tokens) return required_authn_params, required_authz_tokens, used_services diff --git a/packages/toolbox-core/tests/test_e2e.py b/packages/toolbox-core/tests/test_e2e.py index 2a3ad34b..c8111b6f 100644 --- a/packages/toolbox-core/tests/test_e2e.py +++ b/packages/toolbox-core/tests/test_e2e.py @@ -148,7 +148,7 @@ async def test_run_tool_no_auth(self, toolbox: ToolboxClient): tool = await toolbox.load_tool("get-row-by-id-auth") with pytest.raises( Exception, - match="tool invocation not authorized. Please make sure your specify correct auth headers", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): await tool(id="2") diff --git a/packages/toolbox-core/tests/test_sync_e2e.py b/packages/toolbox-core/tests/test_sync_e2e.py index 1801ac58..885724e9 100644 --- a/packages/toolbox-core/tests/test_sync_e2e.py +++ b/packages/toolbox-core/tests/test_sync_e2e.py @@ -130,7 +130,7 @@ def test_run_tool_no_auth(self, toolbox: ToolboxSyncClient): tool = toolbox.load_tool("get-row-by-id-auth") with pytest.raises( Exception, - match="tool invocation not authorized. Please make sure your specify correct auth headers", + match="One or more of the following authn services are required to invoke this tool: my-test-auth", ): tool(id="2") diff --git a/packages/toolbox-core/tests/test_tool.py b/packages/toolbox-core/tests/test_tool.py index 03690c08..cbf64efd 100644 --- a/packages/toolbox-core/tests/test_tool.py +++ b/packages/toolbox-core/tests/test_tool.py @@ -205,6 +205,7 @@ async def test_tool_creation_callable_and_run( description=sample_tool_description, params=sample_tool_params, required_authn_params={}, + required_authz_tokens=[], auth_service_token_getters={}, bound_params={}, client_headers={}, @@ -250,6 +251,7 @@ async def test_tool_run_with_pydantic_validation_error( description=sample_tool_description, params=sample_tool_params, required_authn_params={}, + required_authz_tokens=[], auth_service_token_getters={}, bound_params={}, client_headers={}, @@ -337,6 +339,7 @@ def test_tool_init_basic(http_session, sample_tool_params, sample_tool_descripti description=sample_tool_description, params=sample_tool_params, required_authn_params={}, + required_authz_tokens=[], auth_service_token_getters={}, bound_params={}, client_headers={}, @@ -361,6 +364,7 @@ def test_tool_init_with_client_headers( description=sample_tool_description, params=sample_tool_params, required_authn_params={}, + required_authz_tokens=[], auth_service_token_getters={}, bound_params={}, client_headers=static_client_header, @@ -388,6 +392,7 @@ def test_tool_init_header_auth_conflict( description=sample_tool_description, params=sample_tool_auth_params, required_authn_params={}, + required_authz_tokens=[], auth_service_token_getters=auth_getters, bound_params={}, client_headers=conflicting_client_header, @@ -410,6 +415,7 @@ def test_tool_add_auth_token_getters_conflict_with_existing_client_header( description=sample_tool_description, params=sample_tool_params, required_authn_params={}, + required_authz_tokens=[], auth_service_token_getters={}, bound_params={}, client_headers={ diff --git a/packages/toolbox-core/tests/test_utils.py b/packages/toolbox-core/tests/test_utils.py index 8c41e2e8..c07f44cb 100644 --- a/packages/toolbox-core/tests/test_utils.py +++ b/packages/toolbox-core/tests/test_utils.py @@ -23,7 +23,7 @@ from toolbox_core.protocol import ParameterSchema from toolbox_core.utils import ( create_func_docstring, - identify_required_authn_params, + identify_auth_requirements, params_to_pydantic_model, resolve_value, ) @@ -82,7 +82,7 @@ def test_create_func_docstring_empty_description(): assert create_func_docstring(description, params) == expected_docstring -def test_identify_required_authn_params_none_required(): +def test_identify_auth_requirements_none_required(): """Test when no authentication parameters or authorization tokens are required initially.""" req_authn_params: dict[str, list[str]] = {} req_authz_tokens: list[str] = [] @@ -90,7 +90,7 @@ def test_identify_required_authn_params_none_required(): expected_params = {} expected_authz: list[str] = [] expected_used = set() - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == ( @@ -100,7 +100,7 @@ def test_identify_required_authn_params_none_required(): ) -def test_identify_required_authn_params_all_covered(): +def test_identify_auth_requirements_all_covered(): """Test when all required authn parameters are covered, no authz tokens.""" req_authn_params = { "token_a": ["service_a"], @@ -111,7 +111,7 @@ def test_identify_required_authn_params_all_covered(): expected_params = {} expected_authz: list[str] = [] expected_used = {"service_a", "service_b"} - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == ( @@ -121,7 +121,7 @@ def test_identify_required_authn_params_all_covered(): ) -def test_identify_required_authn_params_some_covered(): +def test_identify_auth_requirements_some_covered(): """Test when some authn parameters are covered, and some are not, no authz tokens.""" req_authn_params = { "token_a": ["service_a"], @@ -138,7 +138,7 @@ def test_identify_required_authn_params_some_covered(): expected_authz: list[str] = [] expected_used = {"service_a", "service_b"} - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == ( @@ -148,7 +148,7 @@ def test_identify_required_authn_params_some_covered(): ) -def test_identify_required_authn_params_none_covered(): +def test_identify_auth_requirements_none_covered(): """Test when none of the required authn parameters are covered, no authz tokens.""" req_authn_params = { "token_d": ["service_d"], @@ -162,7 +162,7 @@ def test_identify_required_authn_params_none_covered(): } expected_authz: list[str] = [] expected_used = set() - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == ( @@ -172,7 +172,7 @@ def test_identify_required_authn_params_none_covered(): ) -def test_identify_required_authn_params_no_available_services(): +def test_identify_auth_requirements_no_available_services(): """Test when no authn services are available, no authz tokens.""" req_authn_params = { "token_a": ["service_a"], @@ -186,7 +186,7 @@ def test_identify_required_authn_params_no_available_services(): } expected_authz: list[str] = [] expected_used = set() - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == ( @@ -196,7 +196,7 @@ def test_identify_required_authn_params_no_available_services(): ) -def test_identify_required_authn_params_empty_services_for_param(): +def test_identify_auth_requirements_empty_services_for_param(): """Test edge case: authn param requires an empty list of services, no authz tokens.""" req_authn_params = { "token_x": [], @@ -208,7 +208,7 @@ def test_identify_required_authn_params_empty_services_for_param(): } expected_authz: list[str] = [] expected_used = set() - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == ( @@ -223,7 +223,7 @@ def test_identify_auth_params_only_authz_empty(): req_authn_params: dict[str, list[str]] = {} req_authz_tokens: list[str] = [] auth_service_names = ["s1"] - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == ({}, [], set()) @@ -234,7 +234,7 @@ def test_identify_auth_params_authz_all_covered(): req_authn_params: dict[str, list[str]] = {} req_authz_tokens = ["s1", "s2"] auth_service_names = ["s1", "s2", "s3"] - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == ({}, [], {"s1", "s2"}) @@ -245,7 +245,7 @@ def test_identify_auth_params_authz_partially_covered_by_available(): req_authn_params: dict[str, list[str]] = {} req_authz_tokens = ["s1", "s2"] auth_service_names = ["s1", "s3"] - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == ({}, [], {"s1"}) @@ -256,7 +256,7 @@ def test_identify_auth_params_authz_none_covered(): req_authn_params: dict[str, list[str]] = {} req_authz_tokens = ["s1", "s2"] auth_service_names = ["s3", "s4"] - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == ({}, ["s1", "s2"], set()) @@ -267,7 +267,7 @@ def test_identify_auth_params_authz_none_covered_empty_available(): req_authn_params: dict[str, list[str]] = {} req_authz_tokens = ["s1", "s2"] auth_service_names: list[str] = [] - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == ({}, ["s1", "s2"], set()) @@ -281,7 +281,7 @@ def test_identify_auth_params_authn_covered_authz_uncovered(): expected_params = {} expected_authz: list[str] = ["s_authz_needed1", "s_authz_needed2"] expected_used = {"s_authn1"} - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == (expected_params, expected_authz, expected_used) @@ -296,7 +296,7 @@ def test_identify_auth_params_authn_uncovered_authz_covered(): expected_authz: list[str] = [] expected_used = {"s_authz1"} - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == (expected_params, expected_authz, expected_used) @@ -310,7 +310,7 @@ def test_identify_auth_params_authn_and_authz_covered_no_overlap(): expected_params = {} expected_authz: list[str] = [] expected_used = {"s_authn1", "s_authz1"} - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == (expected_params, expected_authz, expected_used) @@ -328,7 +328,7 @@ def test_identify_auth_params_authn_and_authz_covered_with_overlap(): expected_params = {} expected_authz: list[str] = [] expected_used = {"s_common", "s_authz_specific_avail", "s_authn_specific_avail"} - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == (expected_params, expected_authz, expected_used) @@ -346,7 +346,7 @@ def test_identify_auth_params_authn_and_authz_covered_with_overlap_same_param(): expected_params = {} expected_authz: list[str] = [] expected_used = {"s_common", "s_authz_specific_avail", "s_authn_specific_avail"} - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == (expected_params, expected_authz, expected_used) @@ -360,7 +360,7 @@ def test_identify_auth_params_complex_scenario(): expected_params = {"p2": ["s3"]} expected_authz: list[str] = [] expected_used = {"s1", "s4"} - result = identify_required_authn_params( + result = identify_auth_requirements( req_authn_params, req_authz_tokens, auth_service_names ) assert result == (