Skip to content

Commit a10964f

Browse files
authored
chore(toolbox-core): Consolidate auth header creation logic (#213)
* chore: Add unit test coverage. * chore: Consolidate auth header creation logic Post adding the feature of adding client-level auth headers (#178), we have the logic for creating an auth header, from the given auth token getter name, in 3 different places. This PR unifies all of that logic into a single helper to improve maintenance, and make it easier to change the way we add suffix/prefix, and reduces WET code. * chore: Delint * chore: Delint
1 parent a2810c1 commit a10964f

File tree

2 files changed

+141
-3
lines changed

2 files changed

+141
-3
lines changed

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
# Validate conflicting Headers/Auth Tokens
9595
request_header_names = client_headers.keys()
9696
auth_token_names = [
97-
auth_token_name + "_token"
97+
self.__get_auth_header(auth_token_name)
9898
for auth_token_name in auth_service_token_getters.keys()
9999
]
100100
duplicates = request_header_names & auth_token_names
@@ -187,6 +187,10 @@ def __copy(
187187
client_headers=check(client_headers, self.__client_headers),
188188
)
189189

190+
def __get_auth_header(self, auth_token_name: str) -> str:
191+
"""Returns the formatted auth token header name."""
192+
return f"{auth_token_name}_token"
193+
190194
async def __call__(self, *args: Any, **kwargs: Any) -> str:
191195
"""
192196
Asynchronously calls the remote tool with the provided arguments.
@@ -228,7 +232,9 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
228232
# create headers for auth services
229233
headers = {}
230234
for auth_service, token_getter in self.__auth_service_token_getters.items():
231-
headers[f"{auth_service}_token"] = await resolve_value(token_getter)
235+
headers[self.__get_auth_header(auth_service)] = await resolve_value(
236+
token_getter
237+
)
232238
for client_header_name, client_header_val in self.__client_headers.items():
233239
headers[client_header_name] = await resolve_value(client_header_val)
234240

@@ -276,7 +282,8 @@ def add_auth_token_getters(
276282
# Validate duplicates with client headers
277283
request_header_names = self.__client_headers.keys()
278284
auth_token_names = [
279-
auth_token_name + "_token" for auth_token_name in incoming_services
285+
self.__get_auth_header(auth_token_name)
286+
for auth_token_name in incoming_services
280287
]
281288
duplicates = request_header_names & auth_token_names
282289
if duplicates:

packages/toolbox-core/tests/test_client.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,137 @@ async def test_bind_param_async_callable_value_success(self, tool_name, client):
696696
assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
697697
bound_async_callable.assert_awaited_once()
698698

699+
@pytest.mark.asyncio
700+
async def test_bind_param_success(self, tool_name, client):
701+
"""Tests 'bind_param' with a bound parameter specified."""
702+
tool = await client.load_tool(tool_name)
703+
704+
assert len(tool.__signature__.parameters) == 2
705+
assert "argA" in tool.__signature__.parameters
706+
707+
tool = tool.bind_param("argA", 5)
708+
709+
assert len(tool.__signature__.parameters) == 1
710+
assert "argA" not in tool.__signature__.parameters
711+
712+
res = await tool(True)
713+
assert "argA" in res
714+
715+
@pytest.mark.asyncio
716+
async def test_bind_callable_param_success(self, tool_name, client):
717+
"""Tests 'bind_param' with a bound parameter specified."""
718+
tool = await client.load_tool(tool_name)
719+
720+
assert len(tool.__signature__.parameters) == 2
721+
assert "argA" in tool.__signature__.parameters
722+
723+
tool = tool.bind_param("argA", lambda: 5)
724+
725+
assert len(tool.__signature__.parameters) == 1
726+
assert "argA" not in tool.__signature__.parameters
727+
728+
res = await tool(True)
729+
assert "argA" in res
730+
731+
@pytest.mark.asyncio
732+
async def test_bind_param_fail(self, tool_name, client):
733+
"""Tests 'bind_param' with a bound parameter that doesn't exist."""
734+
tool = await client.load_tool(tool_name)
735+
736+
assert len(tool.__signature__.parameters) == 2
737+
assert "argA" in tool.__signature__.parameters
738+
739+
with pytest.raises(Exception) as e:
740+
tool.bind_param("argC", lambda: 5)
741+
assert "unable to bind parameters: no parameter named argC" in str(e.value)
742+
743+
@pytest.mark.asyncio
744+
async def test_rebind_param_fail(self, tool_name, client):
745+
"""
746+
Tests that 'bind_param' fails when attempting to re-bind a
747+
parameter that has already been bound.
748+
"""
749+
tool = await client.load_tool(tool_name)
750+
751+
assert len(tool.__signature__.parameters) == 2
752+
assert "argA" in tool.__signature__.parameters
753+
754+
tool_with_bound_param = tool.bind_param("argA", lambda: 10)
755+
756+
assert len(tool_with_bound_param.__signature__.parameters) == 1
757+
assert "argA" not in tool_with_bound_param.__signature__.parameters
758+
759+
with pytest.raises(ValueError) as e:
760+
tool_with_bound_param.bind_param("argA", lambda: 20)
761+
762+
assert "cannot re-bind parameter: parameter 'argA' is already bound" in str(
763+
e.value
764+
)
765+
766+
@pytest.mark.asyncio
767+
async def test_bind_param_static_value_success(self, tool_name, client):
768+
"""
769+
Tests bind_param method with a static value.
770+
"""
771+
772+
bound_value = "Test value"
773+
774+
tool = await client.load_tool(tool_name)
775+
bound_tool = tool.bind_param("argB", bound_value)
776+
777+
assert bound_tool is not tool
778+
assert "argB" not in bound_tool.__signature__.parameters
779+
assert "argA" in bound_tool.__signature__.parameters
780+
781+
passed_value_a = 42
782+
res_payload = await bound_tool(argA=passed_value_a)
783+
784+
assert res_payload == {"argA": passed_value_a, "argB": bound_value}
785+
786+
@pytest.mark.asyncio
787+
async def test_bind_param_sync_callable_value_success(self, tool_name, client):
788+
"""
789+
Tests bind_param method with a sync callable value.
790+
"""
791+
792+
bound_value_result = True
793+
bound_sync_callable = Mock(return_value=bound_value_result)
794+
795+
tool = await client.load_tool(tool_name)
796+
bound_tool = tool.bind_param("argB", bound_sync_callable)
797+
798+
assert bound_tool is not tool
799+
assert "argB" not in bound_tool.__signature__.parameters
800+
assert "argA" in bound_tool.__signature__.parameters
801+
802+
passed_value_a = 42
803+
res_payload = await bound_tool(argA=passed_value_a)
804+
805+
assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
806+
bound_sync_callable.assert_called_once()
807+
808+
@pytest.mark.asyncio
809+
async def test_bind_param_async_callable_value_success(self, tool_name, client):
810+
"""
811+
Tests bind_param method with an async callable value.
812+
"""
813+
814+
bound_value_result = True
815+
bound_async_callable = AsyncMock(return_value=bound_value_result)
816+
817+
tool = await client.load_tool(tool_name)
818+
bound_tool = tool.bind_param("argB", bound_async_callable)
819+
820+
assert bound_tool is not tool
821+
assert "argB" not in bound_tool.__signature__.parameters
822+
assert "argA" in bound_tool.__signature__.parameters
823+
824+
passed_value_a = 42
825+
res_payload = await bound_tool(argA=passed_value_a)
826+
827+
assert res_payload == {"argA": passed_value_a, "argB": bound_value_result}
828+
bound_async_callable.assert_awaited_once()
829+
699830

700831
class TestUnusedParameterValidation:
701832
"""

0 commit comments

Comments
 (0)