1
1
# SPDX-License-Identifier: Apache-2.0
2
- # Copyright 2024 Atlan Pte. Ltd.
2
+ # Copyright 2025 Atlan Pte. Ltd.
3
3
from __future__ import annotations
4
4
5
5
import logging
6
6
import threading
7
- from typing import Dict , Union
7
+ from threading import local
8
+ from typing import TYPE_CHECKING , Optional , Union
8
9
9
10
from pyatlan .cache .abstract_asset_cache import AbstractAssetCache , AbstractAssetName
10
11
from pyatlan .cache .connection_cache import ConnectionCache , ConnectionName
11
- from pyatlan .client .atlan import AtlanClient
12
12
from pyatlan .errors import AtlanError
13
13
from pyatlan .model .assets import Asset , Tag
14
14
from pyatlan .model .fluent_search import FluentSearch
15
15
from pyatlan .model .search import Term
16
16
17
- LOGGER = logging .getLogger (__name__ )
17
+ if TYPE_CHECKING :
18
+ from pyatlan .client .atlan import AtlanClient
18
19
19
20
lock = threading .Lock ()
21
+ source_tag_cache_tls = local () # Thread-local storage (TLS)
22
+ LOGGER = logging .getLogger (__name__ )
20
23
21
24
22
25
class SourceTagCache (AbstractAssetCache ):
@@ -34,21 +37,26 @@ class SourceTagCache(AbstractAssetCache):
34
37
35
38
_SEARCH_FIELDS = [Asset .NAME ]
36
39
SEARCH_ATTRIBUTES = [field .atlan_field_name for field in _SEARCH_FIELDS ]
37
- caches : Dict [int , SourceTagCache ] = dict ()
38
40
39
41
def __init__ (self , client : AtlanClient ):
40
42
super ().__init__ (client )
41
43
42
44
@classmethod
43
- def get_cache (cls ) -> SourceTagCache :
45
+ def get_cache (cls , client : Optional [ AtlanClient ] = None ) -> SourceTagCache :
44
46
from pyatlan .client .atlan import AtlanClient
45
47
46
48
with lock :
47
- default_client = AtlanClient .get_default_client ()
48
- cache_key = default_client .cache_key
49
- if cache_key not in cls .caches :
50
- cls .caches [cache_key ] = SourceTagCache (client = default_client )
51
- return cls .caches [cache_key ]
49
+ client = client or AtlanClient .get_default_client ()
50
+ cache_key = client .cache_key
51
+
52
+ if not hasattr (source_tag_cache_tls , "caches" ):
53
+ source_tag_cache_tls .caches = {}
54
+
55
+ if cache_key not in source_tag_cache_tls .caches :
56
+ cache_instance = SourceTagCache (client = client )
57
+ source_tag_cache_tls .caches [cache_key ] = cache_instance
58
+
59
+ return source_tag_cache_tls .caches [cache_key ]
52
60
53
61
@classmethod
54
62
def get_by_guid (cls , guid : str , allow_refresh : bool = True ) -> Tag :
0 commit comments