Skip to content

Commit 9c21d46

Browse files
authored
Merge pull request #592 from atlanhq/APP-6218
APP-6218: Improve 401 token refresh handling
2 parents d401231 + a143939 commit 9c21d46

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

pyatlan/client/atlan.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -462,8 +462,24 @@ def _call_api_internal(
462462
LOGGER.debug("HTTP Status: %s", response.status_code)
463463
if response is None:
464464
return None
465-
if response.status_code == api.expected_status:
465+
466+
# Reset `has_retried` flag if:
467+
# - SDK already attempted a 401 token refresh (`has_retried = True`)
468+
# - and the current response status code is NOT 401
469+
#
470+
# Real-world scenario:
471+
# - First 401 triggers `_handle_401_token_refresh`, setting `has_retried = True`
472+
# - If the next response is also 401 → SDK returns 401 (won’t retry again)
473+
# - But if the next response is != 401 (e.g. 403), and `has_retried = True`,
474+
# then we should reset `has_retried = False` so that future 401s can trigger a new token refresh.
475+
if (
476+
self._401_tls.has_retried
477+
and response.status_code
478+
!= ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code
479+
):
466480
self._401_tls.has_retried = False
481+
482+
if response.status_code == api.expected_status:
467483
try:
468484
if (
469485
response.content is None
@@ -562,10 +578,10 @@ def _call_api_internal(
562578
)
563579
except Exception as e:
564580
LOGGER.debug(
565-
"Failed to impersonate user %s for 401 token refresh. Not retrying. Error: %s",
566-
self._user_id,
581+
"API call failed after a successful 401 token refresh. Error details: %s",
567582
e,
568583
)
584+
raise
569585

570586
if error_code and error_message:
571587
error = ERROR_CODE_FOR_HTTP_STATUS.get(
@@ -697,12 +713,31 @@ def _handle_401_token_refresh(
697713
698714
returns: HTTP response received after retrying the request with the refreshed token
699715
"""
700-
new_token = self.impersonate.user(user_id=self._user_id)
716+
try:
717+
new_token = self.impersonate.user(user_id=self._user_id)
718+
except Exception as e:
719+
LOGGER.debug(
720+
"Failed to impersonate user %s for 401 token refresh. Not retrying. Error: %s",
721+
self._user_id,
722+
e,
723+
)
724+
raise
701725
self.api_key = new_token
702726
self._401_tls.has_retried = True
703727
params["headers"]["authorization"] = f"Bearer {self.api_key}"
704728
self._request_params["headers"]["authorization"] = f"Bearer {self.api_key}"
705729
LOGGER.debug("Successfully completed 401 automatic token refresh.")
730+
731+
# Adding a short delay after token refresh
732+
# This helps ensure that when we fetch typedefs using the new token,
733+
# the backend has fully recognized the token as valid.
734+
# Without this delay, we occasionally get an empty response `[]` from the API,
735+
# likely because the backend hasn’t fully propagated token validity yet.
736+
import time
737+
738+
time.sleep(5)
739+
740+
# Retry the API call with the new token
706741
return self._call_api_internal(
707742
api,
708743
path,

tests/integration/test_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
SearchLogResults,
1616
SearchLogViewResults,
1717
)
18-
from pyatlan.errors import AuthenticationError, NotFoundError
18+
from pyatlan.errors import AuthenticationError, InvalidRequestError, NotFoundError
1919
from pyatlan.model.api_tokens import ApiToken
2020
from pyatlan.model.assets import (
2121
Asset,
@@ -1506,8 +1506,8 @@ def test_client_401_token_refresh(
15061506
# Test that providing an invalid user ID results in the same authentication error
15071507
client._user_id = "invalid-user-id"
15081508
with pytest.raises(
1509-
AuthenticationError,
1510-
match="Server responded with an authentication error 401",
1509+
InvalidRequestError,
1510+
match="Missing privileged credentials to impersonate users",
15111511
):
15121512
FluentSearch().where(CompoundQuery.active_assets()).where(
15131513
CompoundQuery.asset_type(AtlasGlossary)

0 commit comments

Comments
 (0)