Skip to content

Commit e8989a9

Browse files
committed
chore: Consolidate auth header creation logic
Post adding the feature of adding client-level auth headers (#178), we have the logic for creating an auth header, from the given auth token getter name, in 3 different places. This PR unifies all of that logic into a single helper to improve maintenance, and make it easier to change the way we add suffix/prefix, and reduces WET code.
1 parent 9a5f481 commit e8989a9

File tree

1 file changed

+9
-3
lines changed
  • packages/toolbox-core/src/toolbox_core

1 file changed

+9
-3
lines changed

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
# Validate conflicting Headers/Auth Tokens
9595
request_header_names = client_headers.keys()
9696
auth_token_names = [
97-
auth_token_name + "_token"
97+
self.__get_auth_header(auth_token_name)
9898
for auth_token_name in auth_service_token_getters.keys()
9999
]
100100
duplicates = request_header_names & auth_token_names
@@ -159,6 +159,11 @@ def __copy(
159159
client_headers=check(client_headers, self.__client_headers),
160160
)
161161

162+
def __get_auth_header(self, auth_token_name: str) -> str:
163+
"""Returns the formatted auth token header name."""
164+
return f"{auth_token_name}_token"
165+
166+
162167
async def __call__(self, *args: Any, **kwargs: Any) -> str:
163168
"""
164169
Asynchronously calls the remote tool with the provided arguments.
@@ -200,7 +205,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
200205
# create headers for auth services
201206
headers = {}
202207
for auth_service, token_getter in self.__auth_service_token_getters.items():
203-
headers[f"{auth_service}_token"] = await resolve_value(token_getter)
208+
headers[self.__get_auth_header(auth_service)] = await resolve_value(token_getter)
204209
for client_header_name, client_header_val in self.__client_headers.items():
205210
headers[client_header_name] = await resolve_value(client_header_val)
206211

@@ -248,7 +253,8 @@ def add_auth_token_getters(
248253
# Validate duplicates with client headers
249254
request_header_names = self.__client_headers.keys()
250255
auth_token_names = [
251-
auth_token_name + "_token" for auth_token_name in incoming_services
256+
self.__get_auth_header(auth_token_name)
257+
for auth_token_name in incoming_services
252258
]
253259
duplicates = request_header_names & auth_token_names
254260
if duplicates:

0 commit comments

Comments
 (0)