Skip to content

Commit 7ff6355

Browse files
committed
chore: Consolidate auth header creation logic
Post adding the feature of adding client-level auth headers (#178), we have the logic for creating an auth header, from the given auth token getter name, in 3 different places. This PR unifies all of that logic into a single helper to improve maintenance, and make it easier to change the way we add suffix/prefix, and reduces WET code.
1 parent ebd317b commit 7ff6355

File tree

1 file changed

+9
-3
lines changed
  • packages/toolbox-core/src/toolbox_core

1 file changed

+9
-3
lines changed

packages/toolbox-core/src/toolbox_core/tool.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def __init__(
9494
# Validate conflicting Headers/Auth Tokens
9595
request_header_names = client_headers.keys()
9696
auth_token_names = [
97-
auth_token_name + "_token"
97+
self.__get_auth_header(auth_token_name)
9898
for auth_token_name in auth_service_token_getters.keys()
9999
]
100100
duplicates = request_header_names & auth_token_names
@@ -187,6 +187,11 @@ def __copy(
187187
client_headers=check(client_headers, self.__client_headers),
188188
)
189189

190+
def __get_auth_header(self, auth_token_name: str) -> str:
191+
"""Returns the formatted auth token header name."""
192+
return f"{auth_token_name}_token"
193+
194+
190195
async def __call__(self, *args: Any, **kwargs: Any) -> str:
191196
"""
192197
Asynchronously calls the remote tool with the provided arguments.
@@ -228,7 +233,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> str:
228233
# create headers for auth services
229234
headers = {}
230235
for auth_service, token_getter in self.__auth_service_token_getters.items():
231-
headers[f"{auth_service}_token"] = await resolve_value(token_getter)
236+
headers[self.__get_auth_header(auth_service)] = await resolve_value(token_getter)
232237
for client_header_name, client_header_val in self.__client_headers.items():
233238
headers[client_header_name] = await resolve_value(client_header_val)
234239

@@ -276,7 +281,8 @@ def add_auth_token_getters(
276281
# Validate duplicates with client headers
277282
request_header_names = self.__client_headers.keys()
278283
auth_token_names = [
279-
auth_token_name + "_token" for auth_token_name in incoming_services
284+
self.__get_auth_header(auth_token_name)
285+
for auth_token_name in incoming_services
280286
]
281287
duplicates = request_header_names & auth_token_names
282288
if duplicates:

0 commit comments

Comments
 (0)