Skip to content

Commit aa0664f

Browse files
authored
fix: add optional clock_skew parameter to token helpers (#333)
* fix: fix token clock skew issue * lint * fix unit tests * change clock skew * change clock skew to 10 instead of 60 * Made clock skew an optional arg * added check for clock skew values * make default clock skew = 0 * fix tests * fix arg name * rename arg * add suggestion
1 parent 5db97ae commit aa0664f

File tree

2 files changed

+33
-12
lines changed

2 files changed

+33
-12
lines changed

packages/toolbox-core/src/toolbox_core/auth_methods.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
# --- Constants ---
4242
BEARER_TOKEN_PREFIX = "Bearer "
4343
CACHE_REFRESH_MARGIN = timedelta(seconds=60)
44+
DEFAULT_CLOCK_SKEW = 0
4445

4546
_token_cache: Dict[str, Any] = {
4647
"token": None,
@@ -57,7 +58,7 @@ def _is_token_valid() -> bool:
5758
)
5859

5960

60-
def _update_cache(new_token: str) -> None:
61+
def _update_cache(new_token: str, clock_skew_in_seconds: int) -> None:
6162
"""
6263
Validates a new token, extracts its expiry, and updates the cache.
6364
@@ -71,7 +72,9 @@ def _update_cache(new_token: str) -> None:
7172
# verify_oauth2_token not only decodes but also validates the token's
7273
# signature and claims against Google's public keys.
7374
# It's a synchronous, CPU-bound operation, safe for async contexts.
74-
claims = id_token.verify_oauth2_token(new_token, Request())
75+
claims = id_token.verify_oauth2_token(
76+
new_token, Request(), clock_skew_in_seconds=clock_skew_in_seconds
77+
)
7578

7679
expiry_timestamp = claims.get("exp")
7780
if not expiry_timestamp:
@@ -89,7 +92,15 @@ def _update_cache(new_token: str) -> None:
8992
raise ValueError(f"Failed to validate and cache the new token: {e}") from e
9093

9194

92-
def get_google_token_from_aud(audience: Optional[str] = None) -> str:
95+
def get_google_token_from_aud(
96+
clock_skew_in_seconds: int = 0, audience: Optional[str] = None
97+
) -> str:
98+
if clock_skew_in_seconds < 0 or clock_skew_in_seconds > 60:
99+
raise ValueError(
100+
f"Illegal clock_skew_in_seconds value: {clock_skew_in_seconds}. Must be between 0 and 60"
101+
", inclusive."
102+
)
103+
93104
if _is_token_valid():
94105
return BEARER_TOKEN_PREFIX + _token_cache["token"]
95106

@@ -102,7 +113,7 @@ def get_google_token_from_aud(audience: Optional[str] = None) -> str:
102113
if hasattr(credentials, "id_token"):
103114
new_id_token = getattr(credentials, "id_token", None)
104115
if new_id_token:
105-
_update_cache(new_id_token)
116+
_update_cache(new_id_token, clock_skew_in_seconds)
106117
return BEARER_TOKEN_PREFIX + new_id_token
107118

108119
if audience is None:
@@ -115,7 +126,7 @@ def get_google_token_from_aud(audience: Optional[str] = None) -> str:
115126
try:
116127
request = Request()
117128
new_token = id_token.fetch_id_token(request, audience)
118-
_update_cache(new_token)
129+
_update_cache(new_token, clock_skew_in_seconds)
119130
return BEARER_TOKEN_PREFIX + _token_cache["token"]
120131

121132
except GoogleAuthError as e:
@@ -124,15 +135,20 @@ def get_google_token_from_aud(audience: Optional[str] = None) -> str:
124135
) from e
125136

126137

127-
def get_google_id_token(audience: Optional[str] = None) -> Callable[[], str]:
138+
def get_google_id_token(
139+
audience: Optional[str] = None, clock_skew_in_seconds: int = DEFAULT_CLOCK_SKEW
140+
) -> Callable[[], str]:
128141
"""
129142
Returns a SYNC function that, when called, fetches a Google ID token.
130143
This function uses Application Default Credentials for local systems
131144
and standard google auth libraries for Google Cloud environments.
132145
It caches the token in memory.
133146
134147
Args:
135-
audience: The audience for the ID token (e.g., a service URL or client ID).
148+
audience: The audience for the ID token (e.g., a service URL or client
149+
ID).
150+
clock_skew_in_seconds: The number of seconds to tolerate when checking the token.
151+
Must be between 0-60. Defaults to 0.
136152
137153
Returns:
138154
A function that when executed returns string in the format "Bearer <google_id_token>".
@@ -143,13 +159,13 @@ def get_google_id_token(audience: Optional[str] = None) -> Callable[[], str]:
143159
"""
144160

145161
def _token_getter() -> str:
146-
return get_google_token_from_aud(audience)
162+
return get_google_token_from_aud(clock_skew_in_seconds, audience)
147163

148164
return _token_getter
149165

150166

151167
def aget_google_id_token(
152-
audience: Optional[str] = None,
168+
audience: Optional[str] = None, clock_skew_in_seconds: int = DEFAULT_CLOCK_SKEW
153169
) -> Callable[[], Coroutine[Any, Any, str]]:
154170
"""
155171
Returns an ASYNC function that, when called, fetches a Google ID token.
@@ -158,7 +174,10 @@ def aget_google_id_token(
158174
It caches the token in memory.
159175
160176
Args:
161-
audience: The audience for the ID token (e.g., a service URL or client ID).
177+
audience: The audience for the ID token (e.g., a service URL or client
178+
ID).
179+
clock_skew_in_seconds: The number of seconds to tolerate when checking the token.
180+
Must be between 0-60. Defaults to 0.
162181
163182
Returns:
164183
An async function that when executed returns string in the format "Bearer <google_id_token>".
@@ -169,6 +188,8 @@ def aget_google_id_token(
169188
"""
170189

171190
async def _token_getter() -> str:
172-
return await asyncio.to_thread(get_google_token_from_aud, audience)
191+
return await asyncio.to_thread(
192+
get_google_token_from_aud, clock_skew_in_seconds, audience
193+
)
173194

174195
return _token_getter

packages/toolbox-core/tests/test_auth_methods.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def test_get_google_id_token_success_local_creds(
154154
mock_default.assert_called_once_with()
155155
mock_session.assert_called_once_with(mock_creds)
156156
mock_creds.refresh.assert_called_once_with(mock_request_instance)
157-
mock_verify.assert_called_once_with(MOCK_ID_TOKEN, ANY)
157+
mock_verify.assert_called_once_with(MOCK_ID_TOKEN, ANY, clock_skew_in_seconds=0)
158158
assert token == f"{auth_methods.BEARER_TOKEN_PREFIX}{MOCK_ID_TOKEN}"
159159
assert auth_methods._token_cache["token"] == MOCK_ID_TOKEN
160160
assert auth_methods._token_cache["expires_at"] == MOCK_EXPIRY_DATETIME

0 commit comments

Comments
 (0)