@@ -150,7 +150,12 @@ def get_session():
150
150
151
151
152
152
class AtlanClient (BaseSettings ):
153
- _current_client_tls : ClassVar [local ] = local () # Thread-local storage (TLS)
153
+ _current_client_ctx : ClassVar [ContextVar ] = ContextVar (
154
+ "_current_client_ctx" , default = None
155
+ )
156
+ _401_has_retried_ctx : ClassVar [ContextVar ] = ContextVar (
157
+ "_401_has_retried_ctx" , default = False
158
+ )
154
159
base_url : Union [Literal ["INTERNAL" ], HttpUrl ]
155
160
api_key : str
156
161
connect_timeout : float = 30.0 # 30 secs
@@ -190,37 +195,24 @@ class AtlanClient(BaseSettings):
190
195
class Config :
191
196
env_prefix = "atlan_"
192
197
193
- @classmethod
194
- def init_for_multithreading (cls , client : AtlanClient ):
195
- """
196
- Prepares the given client for use in multi-threaded environments.
197
-
198
- This sets the thread-local context and resets internal retry flags
199
- to ensure correct behavior when using the client across multiple threads.
200
- """
201
- AtlanClient .set_current_client (client )
202
- client ._401_tls .has_retried = False
203
-
204
198
@classmethod
205
199
def set_current_client (cls , client : AtlanClient ):
206
200
"""
207
- Sets the current client to thread-local storage (TLS)
201
+ Sets the current client context
208
202
"""
209
203
if not isinstance (client , AtlanClient ):
210
204
raise ErrorCode .MISSING_ATLAN_CLIENT .exception_with_parameters ()
211
- cls ._current_client_tls . client = client
205
+ cls ._current_client_ctx . set ( client )
212
206
213
207
@classmethod
214
208
def get_current_client (cls ) -> AtlanClient :
215
209
"""
216
- Retrieves the current client
210
+ Retrieves the current client context
217
211
"""
218
- if (
219
- not hasattr (cls ._current_client_tls , "client" )
220
- or not cls ._current_client_tls .client
221
- ):
212
+ client = cls ._current_client_ctx .get ()
213
+ if not client :
222
214
raise ErrorCode .NO_ATLAN_CLIENT_AVAILABLE .exception_with_parameters ()
223
- return cls . _current_client_tls . client
215
+ return client
224
216
225
217
def __init__ (self , ** data ):
226
218
super ().__init__ (** data )
@@ -233,7 +225,8 @@ def __init__(self, **data):
233
225
adapter = HTTPAdapter (max_retries = self .retry )
234
226
session .mount (HTTPS_PREFIX , adapter )
235
227
session .mount (HTTP_PREFIX , adapter )
236
- AtlanClient .init_for_multithreading (self )
228
+ AtlanClient .set_current_client (self )
229
+ self ._401_has_retried_ctx .set (False )
237
230
238
231
@property
239
232
def admin (self ) -> AdminClient :
@@ -482,11 +475,11 @@ def _call_api_internal(
482
475
# - But if the next response is != 401 (e.g. 403), and `has_retried = True`,
483
476
# then we should reset `has_retried = False` so that future 401s can trigger a new token refresh.
484
477
if (
485
- self ._401_tls . has_retried
478
+ self ._401_has_retried_ctx . get ()
486
479
and response .status_code
487
480
!= ErrorCode .AUTHENTICATION_PASSTHROUGH .http_error_code
488
481
):
489
- self ._401_tls . has_retried = False
482
+ self ._401_has_retried_ctx . set ( False )
490
483
491
484
if response .status_code == api .expected_status :
492
485
try :
@@ -571,7 +564,7 @@ def _call_api_internal(
571
564
# on authentication failure (token may have expired)
572
565
if (
573
566
self ._user_id
574
- and not self ._401_tls . has_retried
567
+ and not self ._401_has_retried_ctx . get ()
575
568
and response .status_code
576
569
== ErrorCode .AUTHENTICATION_PASSTHROUGH .http_error_code
577
570
):
@@ -732,7 +725,7 @@ def _handle_401_token_refresh(
732
725
)
733
726
raise
734
727
self .api_key = new_token
735
- self ._401_tls . has_retried = True
728
+ self ._401_has_retried_ctx . set ( True )
736
729
params ["headers" ]["authorization" ] = f"Bearer { self .api_key } "
737
730
self ._request_params ["headers" ]["authorization" ] = f"Bearer { self .api_key } "
738
731
LOGGER .debug ("Successfully completed 401 automatic token refresh." )
0 commit comments