Skip to content

Commit 80d320b

Browse files
authored
Merge pull request #624 from atlanhq/APP-6579
APP-6156: Migrated from `TLS` to `ContextVars` to support both multithreading and async environments
2 parents 06fe81f + 20a6324 commit 80d320b

File tree

2 files changed

+65
-26
lines changed

2 files changed

+65
-26
lines changed

pyatlan/client/atlan.py

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,12 @@ def get_session():
150150

151151

152152
class AtlanClient(BaseSettings):
153-
_current_client_tls: ClassVar[local] = local() # Thread-local storage (TLS)
153+
_current_client_ctx: ClassVar[ContextVar] = ContextVar(
154+
"_current_client_ctx", default=None
155+
)
156+
_401_has_retried_ctx: ClassVar[ContextVar] = ContextVar(
157+
"_401_has_retried_ctx", default=False
158+
)
154159
base_url: Union[Literal["INTERNAL"], HttpUrl]
155160
api_key: str
156161
connect_timeout: float = 30.0 # 30 secs
@@ -190,37 +195,24 @@ class AtlanClient(BaseSettings):
190195
class Config:
191196
env_prefix = "atlan_"
192197

193-
@classmethod
194-
def init_for_multithreading(cls, client: AtlanClient):
195-
"""
196-
Prepares the given client for use in multi-threaded environments.
197-
198-
This sets the thread-local context and resets internal retry flags
199-
to ensure correct behavior when using the client across multiple threads.
200-
"""
201-
AtlanClient.set_current_client(client)
202-
client._401_tls.has_retried = False
203-
204198
@classmethod
205199
def set_current_client(cls, client: AtlanClient):
206200
"""
207-
Sets the current client to thread-local storage (TLS)
201+
Sets the current client context
208202
"""
209203
if not isinstance(client, AtlanClient):
210204
raise ErrorCode.MISSING_ATLAN_CLIENT.exception_with_parameters()
211-
cls._current_client_tls.client = client
205+
cls._current_client_ctx.set(client)
212206

213207
@classmethod
214208
def get_current_client(cls) -> AtlanClient:
215209
"""
216-
Retrieves the current client
210+
Retrieves the current client context
217211
"""
218-
if (
219-
not hasattr(cls._current_client_tls, "client")
220-
or not cls._current_client_tls.client
221-
):
212+
client = cls._current_client_ctx.get()
213+
if not client:
222214
raise ErrorCode.NO_ATLAN_CLIENT_AVAILABLE.exception_with_parameters()
223-
return cls._current_client_tls.client
215+
return client
224216

225217
def __init__(self, **data):
226218
super().__init__(**data)
@@ -233,7 +225,8 @@ def __init__(self, **data):
233225
adapter = HTTPAdapter(max_retries=self.retry)
234226
session.mount(HTTPS_PREFIX, adapter)
235227
session.mount(HTTP_PREFIX, adapter)
236-
AtlanClient.init_for_multithreading(self)
228+
AtlanClient.set_current_client(self)
229+
self._401_has_retried_ctx.set(False)
237230

238231
@property
239232
def admin(self) -> AdminClient:
@@ -482,11 +475,11 @@ def _call_api_internal(
482475
# - But if the next response is != 401 (e.g. 403), and `has_retried = True`,
483476
# then we should reset `has_retried = False` so that future 401s can trigger a new token refresh.
484477
if (
485-
self._401_tls.has_retried
478+
self._401_has_retried_ctx.get()
486479
and response.status_code
487480
!= ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code
488481
):
489-
self._401_tls.has_retried = False
482+
self._401_has_retried_ctx.set(False)
490483

491484
if response.status_code == api.expected_status:
492485
try:
@@ -571,7 +564,7 @@ def _call_api_internal(
571564
# on authentication failure (token may have expired)
572565
if (
573566
self._user_id
574-
and not self._401_tls.has_retried
567+
and not self._401_has_retried_ctx.get()
575568
and response.status_code
576569
== ErrorCode.AUTHENTICATION_PASSTHROUGH.http_error_code
577570
):
@@ -732,7 +725,7 @@ def _handle_401_token_refresh(
732725
)
733726
raise
734727
self.api_key = new_token
735-
self._401_tls.has_retried = True
728+
self._401_has_retried_ctx.set(True)
736729
params["headers"]["authorization"] = f"Bearer {self.api_key}"
737730
self._request_params["headers"]["authorization"] = f"Bearer {self.api_key}"
738731
LOGGER.debug("Successfully completed 401 automatic token refresh.")

pyatlan/utils.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
import random
99
import re
1010
import time
11-
from contextvars import ContextVar
11+
from concurrent.futures import ThreadPoolExecutor
12+
from contextvars import ContextVar, copy_context
1213
from datetime import datetime
1314
from enum import Enum, EnumMeta
1415
from functools import reduce, wraps
@@ -473,6 +474,51 @@ def validate_single_required_field(field_names: List[str], values: List[Any]):
473474
)
474475

475476

477+
class PyAtlanThreadPoolExecutor(ThreadPoolExecutor):
478+
"""
479+
A ThreadPoolExecutor that preserves context variables (e.g: `AtlanClient`)
480+
across threads—useful when running SDK methods in multithreading and async environments.
481+
482+
For example:
483+
484+
```
485+
import asyncio
486+
from functools import partial
487+
from pyatlan.model.assets import Table
488+
from pyatlan.client.atlan import AtlanClient
489+
from pyatlan.utils import PyAtlanThreadPoolExecutor
490+
491+
client = AtlanClient()
492+
493+
async def fetch_asset():
494+
loop = asyncio.get_event_loop()
495+
sdk_func = partial(
496+
client.asset.get_by_guid, "ef1ffe2c-8fc9-433a-8cf8-b4583f2d2375", Table
497+
)
498+
return await loop.run_in_executor(
499+
executor=PyAtlanThreadPoolExecutor(), func=sdk_func
500+
)
501+
502+
result = asyncio.run(fetch_asset())
503+
```
504+
"""
505+
506+
_SDK_CONTEXT_VAR_NAMES = [
507+
"_current_client_ctx",
508+
"_401_has_retried_ctx",
509+
]
510+
511+
def submit(self, fn, /, *args, **kwargs):
512+
ctx_vars = copy_context().items()
513+
514+
def _fn():
515+
for var, value in ctx_vars:
516+
# Only set the context variables that are used by the SDK
517+
if var.name in self._SDK_CONTEXT_VAR_NAMES:
518+
var.set(value)
519+
return fn(*args, **kwargs)
520+
521+
476522
class ExtendableEnumMeta(EnumMeta):
477523
def __init__(cls, name, bases, namespace):
478524
super().__init__(name, bases, namespace)

0 commit comments

Comments
 (0)