Skip to content

Commit 278854e

Browse files
committed
[change] Bound RoleCache to AtlanClient, moved class var to TLS
1 parent 745e6ac commit 278854e

File tree

2 files changed

+50
-35
lines changed

2 files changed

+50
-35
lines changed

pyatlan/cache/user_cache.py

Lines changed: 42 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,49 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# Copyright 2022 Atlan Pte. Ltd.
3-
from threading import Lock
4-
from typing import Dict, Iterable, Optional
3+
from __future__ import annotations
4+
5+
from threading import Lock, local
6+
from typing import TYPE_CHECKING, Dict, Iterable, Optional
57

6-
from pyatlan.client.token import TokenClient
7-
from pyatlan.client.user import UserClient
88
from pyatlan.errors import ErrorCode
99
from pyatlan.model.constants import SERVICE_ACCOUNT_
1010

11+
if TYPE_CHECKING:
12+
from pyatlan.client.atlan import AtlanClient
13+
1114
lock = Lock()
15+
user_cache_tls = local() # Thread-local storage (TLS)
1216

1317

1418
class UserCache:
1519
"""
1620
Lazily-loaded cache for translating Atlan-internal users into their various IDs.
1721
"""
1822

19-
caches: Dict[int, "UserCache"] = {}
23+
def __init__(self, client: AtlanClient):
24+
self.client: AtlanClient = client
25+
self.map_id_to_name: Dict[str, str] = {}
26+
self.map_name_to_id: Dict[str, str] = {}
27+
self.map_email_to_id: Dict[str, str] = {}
28+
self.lock: Lock = Lock()
2029

2130
@classmethod
22-
def get_cache(cls) -> "UserCache":
31+
def get_cache(cls, client: Optional[AtlanClient] = None) -> UserCache:
2332
from pyatlan.client.atlan import AtlanClient
2433

2534
with lock:
26-
client = AtlanClient.get_default_client()
35+
client = client or AtlanClient.get_default_client()
2736
cache_key = client.cache_key
28-
if cache_key not in cls.caches:
29-
cls.caches[cache_key] = UserCache(
30-
user_client=client.user, token_client=client.token
31-
)
32-
return cls.caches[cache_key]
37+
38+
if not hasattr(user_cache_tls, "caches"):
39+
user_cache_tls.caches = {}
40+
41+
if cache_key not in user_cache_tls.caches:
42+
cache_instance = UserCache(client=client)
43+
cache_instance._refresh_cache() # Refresh on new cache instance
44+
user_cache_tls.caches[cache_key] = cache_instance
45+
46+
return user_cache_tls.caches[cache_key]
3347

3448
@classmethod
3549
def get_id_for_name(cls, name: str) -> Optional[str]:
@@ -70,28 +84,21 @@ def validate_names(cls, names: Iterable[str]):
7084
"""
7185
return cls.get_cache()._validate_names(names)
7286

73-
def __init__(self, user_client: UserClient, token_client: TokenClient):
74-
self.user_client: UserClient = user_client
75-
self.token_client: TokenClient = token_client
76-
self.map_id_to_name: Dict[str, str] = {}
77-
self.map_name_to_id: Dict[str, str] = {}
78-
self.map_email_to_id: Dict[str, str] = {}
79-
self.lock: Lock = Lock()
80-
8187
def _refresh_cache(self) -> None:
8288
with self.lock:
83-
users = self.user_client.get_all()
84-
if users is not None:
85-
self.map_id_to_name = {}
86-
self.map_name_to_id = {}
87-
self.map_email_to_id = {}
88-
for user in users:
89-
user_id = str(user.id)
90-
username = str(user.username)
91-
user_email = str(user.email)
92-
self.map_id_to_name[user_id] = username
93-
self.map_name_to_id[username] = user_id
94-
self.map_email_to_id[user_email] = user_id
89+
users = self.client.user.get_all()
90+
if not users:
91+
return
92+
self.map_id_to_name = {}
93+
self.map_name_to_id = {}
94+
self.map_email_to_id = {}
95+
for user in users:
96+
user_id = str(user.id)
97+
username = str(user.username)
98+
user_email = str(user.email)
99+
self.map_id_to_name[user_id] = username
100+
self.map_name_to_id[username] = user_id
101+
self.map_email_to_id[user_email] = user_id
95102

96103
def _get_id_for_name(self, name: str) -> Optional[str]:
97104
"""
@@ -105,7 +112,7 @@ def _get_id_for_name(self, name: str) -> Optional[str]:
105112
# If we are translating an API token,
106113
# short-circuit any further cache refresh
107114
if name.startswith(SERVICE_ACCOUNT_):
108-
token = self.token_client.get_by_id(client_id=name)
115+
token = self.client.token.get_by_id(client_id=name)
109116
if token and token.guid:
110117
self.map_name_to_id[name] = token.guid
111118
return token.guid
@@ -138,7 +145,7 @@ def _get_name_for_id(self, idstr: str) -> Optional[str]:
138145
if username := self.map_id_to_name.get(idstr):
139146
return username
140147
# If the username isn't found, check if it is an API token
141-
token = self.token_client.get_by_guid(guid=idstr)
148+
token = self.client.token.get_by_guid(guid=idstr)
142149
if token and token.client_id:
143150
return token.username
144151
else:
@@ -152,7 +159,7 @@ def _validate_names(self, names: Iterable[str]):
152159
:param names: a collection of usernames to be checked
153160
"""
154161
for username in names:
155-
if not self.get_id_for_name(username) and not self.token_client.get_by_id(
162+
if not self.get_id_for_name(username) and not self.client.token.get_by_id(
156163
username
157164
):
158165
raise ValueError(

pyatlan/client/atlan.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from pyatlan.cache.enum_cache import EnumCache
4747
from pyatlan.cache.group_cache import GroupCache
4848
from pyatlan.cache.role_cache import RoleCache
49+
from pyatlan.cache.user_cache import UserCache
4950
from pyatlan.client.admin import AdminClient
5051
from pyatlan.client.asset import A, AssetClient, IndexSearchResults, LineageListResults
5152
from pyatlan.client.audit import AuditClient
@@ -179,6 +180,7 @@ class AtlanClient(BaseSettings):
179180
_enum_cache: Optional[EnumCache] = PrivateAttr(default=None)
180181
_group_cache: Optional[GroupCache] = PrivateAttr(default=None)
181182
_role_cache: Optional[RoleCache] = PrivateAttr(default=None)
183+
_user_cache: Optional[UserCache] = PrivateAttr(default=None)
182184
_custom_metadata_cache: Optional[CustomMetadataCache] = PrivateAttr(default=None)
183185

184186
class Config:
@@ -353,6 +355,12 @@ def role_cache(self) -> RoleCache:
353355
self._role_cache = RoleCache.get_cache(client=self)
354356
return self._role_cache
355357

358+
@property
359+
def user_cache(self) -> UserCache:
360+
if self._user_cache is None:
361+
self._user_cache = UserCache.get_cache(client=self)
362+
return self._user_cache
363+
356364
@property
357365
def custom_metadata_cache(self) -> CustomMetadataCache:
358366
if self._custom_metadata_cache is None:

0 commit comments

Comments
 (0)