Skip to content

Commit dca6791

Browse files
committed
add optional arg
1 parent 8d89e60 commit dca6791

File tree

2 files changed

+46
-8
lines changed

2 files changed

+46
-8
lines changed

packages/toolbox-core/src/toolbox_core/auth_methods.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
import asyncio
3838
from datetime import datetime, timedelta, timezone
39-
from typing import Any, Dict
39+
from typing import Any, Dict, Optional
4040

4141
import google.auth
4242
from google.auth.exceptions import GoogleAuthError
@@ -94,7 +94,7 @@ def _update_cache(new_token: str) -> None:
9494
raise ValueError(f"Failed to validate and cache the new token: {e}") from e
9595

9696

97-
def get_google_id_token(audience: str) -> str:
97+
def get_google_id_token(audience: Optional[str] = None) -> str:
9898
"""
9999
Synchronously fetches a Google ID token for a specific audience.
100100
This function uses Application Default Credentials for local systems
@@ -125,8 +125,12 @@ def get_google_id_token(audience: str) -> str:
125125
if new_id_token:
126126
_update_cache(new_id_token)
127127
return BEARER_TOKEN_PREFIX + new_id_token
128+
129+
if audience is None:
130+
raise Exception('You are not authenticating using User Credentials.'
131+
' Please set the audience string to the Toolbox service URL to get the Google ID token.')
128132

129-
# Get credentials for Google Cloud environments
133+
# Get credentials for Google Cloud environments or for service account key files
130134
try:
131135
request = Request()
132136
new_token = id_token.fetch_id_token(request, audience)
@@ -139,7 +143,7 @@ def get_google_id_token(audience: str) -> str:
139143
) from e
140144

141145

142-
async def aget_google_id_token(audience: str) -> str:
146+
async def aget_google_id_token(audience: Optional[str] = None) -> str:
143147
"""
144148
Asynchronously fetches a Google ID token for a specific audience.
145149
This function uses Application Default Credentials for local systems

packages/toolbox-core/tests/test_auth_methods.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,8 @@ async def test_aget_google_id_token_success_first_call(
6969
assert auth_methods._token_cache["token"] == MOCK_ID_TOKEN
7070
assert auth_methods._token_cache["expires_at"] == MOCK_EXPIRY_DATETIME
7171

72-
@pytest.mark.asyncio
7372
@patch("toolbox_core.auth_methods.google.auth.default")
74-
async def test_aget_google_id_token_success_uses_cache(self, mock_get_token):
73+
async def test_aget_google_id_token_success_uses_cache(self, mock_default):
7574
"""Tests that subsequent calls use the cached token if valid."""
7675
# Prime the cache with a valid token
7776
auth_methods._token_cache["token"] = MOCK_ID_TOKEN
@@ -81,8 +80,8 @@ async def test_aget_google_id_token_success_uses_cache(self, mock_get_token):
8180

8281
token = await auth_methods.aget_google_id_token(MOCK_AUDIENCE)
8382

84-
# The underlying sync function should not be called if cache is valid
85-
mock_get_token.assert_not_called()
83+
# The underlying auth function should not be called if cache is valid
84+
mock_default.assert_not_called()
8685
assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_ID_TOKEN}"
8786

8887
@patch("toolbox_core.auth_methods.id_token.verify_oauth2_token")
@@ -112,6 +111,23 @@ async def test_aget_google_id_token_refreshes_expired_cache(
112111
assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_ID_TOKEN}"
113112
assert auth_methods._token_cache["token"] == MOCK_ID_TOKEN
114113

114+
@patch("toolbox_core.auth_methods.id_token.fetch_id_token")
115+
@patch(
116+
"toolbox_core.auth_methods.google.auth.default",
117+
return_value=(MagicMock(id_token=None), MOCK_PROJECT_ID),
118+
)
119+
async def test_aget_raises_if_no_audience_and_no_local_token(
120+
self, mock_default, mock_fetch
121+
):
122+
"""Tests that the async function propagates the missing audience exception."""
123+
error_msg = "You are not authenticating using User Credentials."
124+
with pytest.raises(Exception, match=error_msg):
125+
# Call without audience to trigger the error path
126+
await auth_methods.aget_google_id_token()
127+
128+
mock_default.assert_called_once()
129+
mock_fetch.assert_not_called()
130+
115131

116132
class TestSyncAuthMethods:
117133
"""Tests for synchronous Google ID token fetching."""
@@ -196,3 +212,21 @@ def test_get_google_id_token_validation_failure(
196212

197213
# Verify cache is cleared on validation failure
198214
assert auth_methods._token_cache["token"] is None
215+
216+
@patch("toolbox_core.auth_methods.id_token.fetch_id_token")
217+
@patch(
218+
"toolbox_core.auth_methods.google.auth.default",
219+
# Simulate credentials that DON'T have a pre-existing id_token
220+
return_value=(MagicMock(id_token=None), MOCK_PROJECT_ID),
221+
)
222+
def test_get_raises_if_no_audience_and_no_local_token(
223+
self, mock_default, mock_fetch
224+
):
225+
"""Tests exception is raised if audience is required but not provided."""
226+
error_msg = "You are not authenticating using User Credentials."
227+
with pytest.raises(Exception, match=error_msg):
228+
# Call without an audience to trigger the error path
229+
auth_methods.get_google_id_token()
230+
231+
mock_default.assert_called_once()
232+
mock_fetch.assert_not_called()

0 commit comments

Comments
 (0)