1
1
# SPDX-License-Identifier: Apache-2.0
2
2
# Copyright 2022 Atlan Pte. Ltd.
3
- from threading import Lock
4
- from typing import Dict , Iterable , Optional
3
+ from __future__ import annotations
4
+
5
+ from threading import Lock , local
6
+ from typing import TYPE_CHECKING , Dict , Iterable , Optional
5
7
6
- from pyatlan .client .token import TokenClient
7
- from pyatlan .client .user import UserClient
8
8
from pyatlan .errors import ErrorCode
9
9
from pyatlan .model .constants import SERVICE_ACCOUNT_
10
10
11
+ if TYPE_CHECKING :
12
+ from pyatlan .client .atlan import AtlanClient
13
+
11
14
lock = Lock ()
15
+ user_cache_tls = local () # Thread-local storage (TLS)
12
16
13
17
14
18
class UserCache :
15
19
"""
16
20
Lazily-loaded cache for translating Atlan-internal users into their various IDs.
17
21
"""
18
22
19
- caches : Dict [int , "UserCache" ] = {}
23
+ def __init__ (self , client : AtlanClient ):
24
+ self .client : AtlanClient = client
25
+ self .map_id_to_name : Dict [str , str ] = {}
26
+ self .map_name_to_id : Dict [str , str ] = {}
27
+ self .map_email_to_id : Dict [str , str ] = {}
28
+ self .lock : Lock = Lock ()
20
29
21
30
@classmethod
22
- def get_cache (cls ) -> " UserCache" :
31
+ def get_cache (cls , client : Optional [ AtlanClient ] = None ) -> UserCache :
23
32
from pyatlan .client .atlan import AtlanClient
24
33
25
34
with lock :
26
- client = AtlanClient .get_default_client ()
35
+ client = client or AtlanClient .get_default_client ()
27
36
cache_key = client .cache_key
28
- if cache_key not in cls .caches :
29
- cls .caches [cache_key ] = UserCache (
30
- user_client = client .user , token_client = client .token
31
- )
32
- return cls .caches [cache_key ]
37
+
38
+ if not hasattr (user_cache_tls , "caches" ):
39
+ user_cache_tls .caches = {}
40
+
41
+ if cache_key not in user_cache_tls .caches :
42
+ cache_instance = UserCache (client = client )
43
+ cache_instance ._refresh_cache () # Refresh on new cache instance
44
+ user_cache_tls .caches [cache_key ] = cache_instance
45
+
46
+ return user_cache_tls .caches [cache_key ]
33
47
34
48
@classmethod
35
49
def get_id_for_name (cls , name : str ) -> Optional [str ]:
@@ -70,28 +84,21 @@ def validate_names(cls, names: Iterable[str]):
70
84
"""
71
85
return cls .get_cache ()._validate_names (names )
72
86
73
- def __init__ (self , user_client : UserClient , token_client : TokenClient ):
74
- self .user_client : UserClient = user_client
75
- self .token_client : TokenClient = token_client
76
- self .map_id_to_name : Dict [str , str ] = {}
77
- self .map_name_to_id : Dict [str , str ] = {}
78
- self .map_email_to_id : Dict [str , str ] = {}
79
- self .lock : Lock = Lock ()
80
-
81
87
def _refresh_cache (self ) -> None :
82
88
with self .lock :
83
- users = self .user_client .get_all ()
84
- if users is not None :
85
- self .map_id_to_name = {}
86
- self .map_name_to_id = {}
87
- self .map_email_to_id = {}
88
- for user in users :
89
- user_id = str (user .id )
90
- username = str (user .username )
91
- user_email = str (user .email )
92
- self .map_id_to_name [user_id ] = username
93
- self .map_name_to_id [username ] = user_id
94
- self .map_email_to_id [user_email ] = user_id
89
+ users = self .client .user .get_all ()
90
+ if not users :
91
+ return
92
+ self .map_id_to_name = {}
93
+ self .map_name_to_id = {}
94
+ self .map_email_to_id = {}
95
+ for user in users :
96
+ user_id = str (user .id )
97
+ username = str (user .username )
98
+ user_email = str (user .email )
99
+ self .map_id_to_name [user_id ] = username
100
+ self .map_name_to_id [username ] = user_id
101
+ self .map_email_to_id [user_email ] = user_id
95
102
96
103
def _get_id_for_name (self , name : str ) -> Optional [str ]:
97
104
"""
@@ -105,7 +112,7 @@ def _get_id_for_name(self, name: str) -> Optional[str]:
105
112
# If we are translating an API token,
106
113
# short-circuit any further cache refresh
107
114
if name .startswith (SERVICE_ACCOUNT_ ):
108
- token = self .token_client .get_by_id (client_id = name )
115
+ token = self .client . token .get_by_id (client_id = name )
109
116
if token and token .guid :
110
117
self .map_name_to_id [name ] = token .guid
111
118
return token .guid
@@ -138,7 +145,7 @@ def _get_name_for_id(self, idstr: str) -> Optional[str]:
138
145
if username := self .map_id_to_name .get (idstr ):
139
146
return username
140
147
# If the username isn't found, check if it is an API token
141
- token = self .token_client .get_by_guid (guid = idstr )
148
+ token = self .client . token .get_by_guid (guid = idstr )
142
149
if token and token .client_id :
143
150
return token .username
144
151
else :
@@ -152,7 +159,7 @@ def _validate_names(self, names: Iterable[str]):
152
159
:param names: a collection of usernames to be checked
153
160
"""
154
161
for username in names :
155
- if not self .get_id_for_name (username ) and not self .token_client .get_by_id (
162
+ if not self .get_id_for_name (username ) and not self .client . token .get_by_id (
156
163
username
157
164
):
158
165
raise ValueError (
0 commit comments