Skip to content

Commit f4819eb

Browse files
committed
[change] Removed bool contextvar for managing 401 token refresh
- Also removed `PyAtlanThreadPoolExecutor` as it was not used anymore.
1 parent f040a7f commit f4819eb

File tree

3 files changed

+8
-127
lines changed

3 files changed

+8
-127
lines changed

pyatlan/client/atlan.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,8 @@
1212
from contextvars import ContextVar
1313
from http import HTTPStatus
1414
from importlib.resources import read_text
15-
from threading import local
1615
from types import SimpleNamespace
17-
from typing import (
18-
Any,
19-
ClassVar,
20-
Dict,
21-
Generator,
22-
List,
23-
Literal,
24-
Optional,
25-
Set,
26-
Type,
27-
Union,
28-
)
16+
from typing import Any, Dict, Generator, List, Literal, Optional, Set, Type, Union
2917
from urllib.parse import urljoin
3018
from warnings import warn
3119

@@ -150,17 +138,14 @@ def get_session():
150138

151139

152140
class AtlanClient(BaseSettings):
153-
_401_has_retried_ctx: ClassVar[ContextVar] = ContextVar(
154-
"_401_has_retried_ctx", default=False
155-
)
156141
base_url: Union[Literal["INTERNAL"], HttpUrl]
157142
api_key: str
158143
connect_timeout: float = 30.0 # 30 secs
159144
read_timeout: float = 900.0 # 15 mins
160145
retry: Retry = DEFAULT_RETRY
146+
_401_has_retried: bool = PrivateAttr(default=False)
161147
_session: requests.Session = PrivateAttr(default_factory=get_session)
162148
_request_params: dict = PrivateAttr()
163-
_401_tls: local = local()
164149
_user_id: Optional[str] = PrivateAttr(default=None)
165150
_workflow_client: Optional[WorkflowClient] = PrivateAttr(default=None)
166151
_credential_client: Optional[CredentialClient] = PrivateAttr(default=None)
@@ -203,7 +188,7 @@ def __init__(self, **data):
203188
adapter = HTTPAdapter(max_retries=self.retry)
204189
session.mount(HTTPS_PREFIX, adapter)
205190
session.mount(HTTP_PREFIX, adapter)
206-
self._401_has_retried_ctx.set(False)
191+
self._401_has_retried = False
207192

208193
@property
209194
def admin(self) -> AdminClient:
@@ -426,11 +411,11 @@ def _call_api_internal(
426411
# - But if the next response is != 401 (e.g. 403), and `has_retried = True`,
427412
# then we should reset `has_retried = False` so that future 401s can trigger a new token refresh.
428413
if (
429-
self._401_has_retried_ctx.get()
414+
self._401_has_retried
430415
and response.status_code
431416
!= ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code
432417
):
433-
self._401_has_retried_ctx.set(False)
418+
self._401_has_retried = False
434419

435420
if response.status_code == api.expected_status:
436421
try:
@@ -521,7 +506,7 @@ def _call_api_internal(
521506
# on authentication failure (token may have expired)
522507
if (
523508
self._user_id
524-
and not self._401_has_retried_ctx.get()
509+
and not self._401_has_retried
525510
and response.status_code
526511
== ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code
527512
):
@@ -680,7 +665,7 @@ def _handle_401_token_refresh(
680665
)
681666
raise
682667
self.api_key = new_token
683-
self._401_has_retried_ctx.set(True)
668+
self._401_has_retried = True
684669
params["headers"]["authorization"] = f"Bearer {self.api_key}"
685670
self._request_params["headers"]["authorization"] = f"Bearer {self.api_key}"
686671
LOGGER.debug("Successfully completed 401 automatic token refresh.")

pyatlan/utils.py

Lines changed: 1 addition & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
import random
99
import re
1010
import time
11-
from concurrent.futures import ThreadPoolExecutor
12-
from contextvars import ContextVar, copy_context
11+
from contextvars import ContextVar
1312
from datetime import datetime
1413
from enum import Enum, EnumMeta
1514
from functools import reduce, wraps
@@ -474,53 +473,6 @@ def validate_single_required_field(field_names: List[str], values: List[Any]):
474473
)
475474

476475

477-
class PyAtlanThreadPoolExecutor(ThreadPoolExecutor):
478-
"""
479-
A ThreadPoolExecutor that preserves context variables (e.g: `AtlanClient`)
480-
across threads—useful when running SDK methods in multithreading and async environments.
481-
482-
For example:
483-
484-
```
485-
import asyncio
486-
from functools import partial
487-
from pyatlan.model.assets import Table
488-
from pyatlan.client.atlan import AtlanClient
489-
from pyatlan.utils import PyAtlanThreadPoolExecutor
490-
491-
client = AtlanClient()
492-
493-
async def fetch_asset():
494-
loop = asyncio.get_running_loop()
495-
sdk_func = partial(
496-
client.asset.get_by_guid, "ef1ffe2c-8fc9-433a-8cf8-b4583f2d2375", Table
497-
)
498-
return await loop.run_in_executor(
499-
executor=PyAtlanThreadPoolExecutor(), func=sdk_func
500-
)
501-
502-
result = asyncio.run(fetch_asset())
503-
```
504-
"""
505-
506-
_SDK_CONTEXT_VAR_NAMES = [
507-
"_current_client_ctx",
508-
"_401_has_retried_ctx",
509-
]
510-
511-
def submit(self, fn, /, *args, **kwargs):
512-
ctx_vars = copy_context().items()
513-
514-
def _fn():
515-
for var, value in ctx_vars:
516-
# Only set the context variables that are used by the SDK
517-
if var.name in self._SDK_CONTEXT_VAR_NAMES:
518-
var.set(value)
519-
return fn(*args, **kwargs)
520-
521-
return super().submit(_fn)
522-
523-
524476
class ExtendableEnumMeta(EnumMeta):
525477
def __init__(cls, name, bases, namespace):
526478
super().__init__(name, bases, namespace)

tests/unit/test_client.py

Lines changed: 0 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2152,62 +2152,6 @@ def test_atlan_call_api_server_error_messages_with_causes(
21522152
client.asset.save(glossary)
21532153

21542154

2155-
# @pytest.mark.parametrize("thread_count", [3]) # Run with three threads
2156-
# def test_atlan_client_tls(thread_count):
2157-
# """Tests that AtlanClient instances remain isolated across multiple threads."""
2158-
# validation_results = {}
2159-
# results_lock = threading.Lock()
2160-
2161-
# def _test_atlan_client_isolation(name, api_key1, api_key2, api_key3):
2162-
# """Creates three AtlanClient instances within the same thread and verifies isolation."""
2163-
# # Instantiate three separate AtlanClient instances
2164-
# client1 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key1)
2165-
# time.sleep(0.2)
2166-
# observed1 = client1.get_current_client().api_key # Should match api_key1
2167-
2168-
# client2 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key2)
2169-
# time.sleep(0.2)
2170-
# observed2 = client2.get_current_client().api_key # Should match api_key2
2171-
2172-
# client3 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key3)
2173-
# time.sleep(0.2)
2174-
# observed3 = client3.get_current_client().api_key # Should match api_key3
2175-
2176-
# # Store results in a thread-safe way
2177-
# with results_lock:
2178-
# validation_results[name] = (observed1, observed2, observed3)
2179-
2180-
# # Define unique API keys for each thread
2181-
# api_keys = [
2182-
# ("API_KEY_1A", "API_KEY_1B", "API_KEY_1C"),
2183-
# ("API_KEY_2A", "API_KEY_2B", "API_KEY_2C"),
2184-
# ("API_KEY_3A", "API_KEY_3B", "API_KEY_3C"),
2185-
# ]
2186-
2187-
# threads = []
2188-
# for i in range(thread_count):
2189-
# thread = threading.Thread(
2190-
# target=_test_atlan_client_isolation,
2191-
# args=(f"thread{i + 1}", *api_keys[i]),
2192-
# )
2193-
# threads.append(thread)
2194-
# thread.start()
2195-
2196-
# # Wait for all threads to finish
2197-
# for thread in threads:
2198-
# thread.join()
2199-
2200-
# # Validate that each thread's clients retained their assigned API keys
2201-
# for i in range(thread_count):
2202-
# thread_name = f"thread{i + 1}"
2203-
# expected_keys = api_keys[i]
2204-
2205-
# assert validation_results[thread_name] == expected_keys, (
2206-
# f"Clients were overwritten across threads! "
2207-
# f"{thread_name} saw {validation_results[thread_name]} instead of {expected_keys}"
2208-
# )
2209-
2210-
22112155
class TestBatch:
22122156
def test_init(self, mock_atlan_client):
22132157
sut = Batch(client=mock_atlan_client, max_size=10)

0 commit comments

Comments
 (0)