1
1
# SPDX-License-Identifier: Apache-2.0
2
- # Copyright 2024 Atlan Pte. Ltd.
2
+ # Copyright 2025 Atlan Pte. Ltd.
3
3
from __future__ import annotations
4
4
5
5
import logging
6
6
import threading
7
- from typing import Dict , Optional , Union
7
+ from threading import local
8
+ from typing import TYPE_CHECKING , Optional , Union
8
9
9
10
from pyatlan .cache .abstract_asset_cache import AbstractAssetCache , AbstractAssetName
10
- from pyatlan .client .atlan import AtlanClient
11
11
from pyatlan .model .assets import Asset , Connection
12
12
from pyatlan .model .enums import AtlanConnectorType
13
13
from pyatlan .model .fluent_search import FluentSearch
14
14
from pyatlan .model .search import Term
15
15
16
+ if TYPE_CHECKING :
17
+ from pyatlan .client .atlan import AtlanClient
16
18
LOGGER = logging .getLogger (__name__ )
17
19
18
20
lock = threading .Lock ()
21
+ connection_cache_tls = local () # Thread-local storage (TLS)
19
22
20
23
21
24
class ConnectionCache (AbstractAssetCache ):
@@ -37,21 +40,26 @@ class ConnectionCache(AbstractAssetCache):
37
40
Connection .CONNECTOR_NAME ,
38
41
]
39
42
SEARCH_ATTRIBUTES = [field .atlan_field_name for field in _SEARCH_FIELDS ]
40
- caches : Dict [int , ConnectionCache ] = dict ()
41
43
42
44
def __init__ (self , client : AtlanClient ):
43
45
super ().__init__ (client )
44
46
45
47
@classmethod
46
- def get_cache (cls ) -> ConnectionCache :
48
+ def get_cache (cls , client : Optional [ AtlanClient ] = None ) -> ConnectionCache :
47
49
from pyatlan .client .atlan import AtlanClient
48
50
49
51
with lock :
50
- default_client = AtlanClient .get_default_client ()
51
- cache_key = default_client .cache_key
52
- if cache_key not in cls .caches :
53
- cls .caches [cache_key ] = ConnectionCache (client = default_client )
54
- return cls .caches [cache_key ]
52
+ client = client or AtlanClient .get_default_client ()
53
+ cache_key = client .cache_key
54
+
55
+ if not hasattr (connection_cache_tls , "caches" ):
56
+ connection_cache_tls .caches = {}
57
+
58
+ if cache_key not in connection_cache_tls .caches :
59
+ cache_instance = ConnectionCache (client = client )
60
+ connection_cache_tls .caches [cache_key ] = cache_instance
61
+
62
+ return connection_cache_tls .caches [cache_key ]
55
63
56
64
@classmethod
57
65
def get_by_guid (cls , guid : str , allow_refresh : bool = True ) -> Connection :
@@ -139,21 +147,22 @@ def lookup_by_qualified_name(self, connection_qn: str) -> None:
139
147
def lookup_by_name (self , name : ConnectionName ) -> None :
140
148
if not isinstance (name , ConnectionName ):
141
149
return
142
- results = self .client .asset .find_connections_by_name (
143
- name = name .name ,
144
- connector_type = name .type ,
145
- attributes = self .SEARCH_ATTRIBUTES ,
146
- )
147
- if not results :
148
- return
149
- if len (results ) > 1 :
150
- LOGGER .warning (
151
- (
152
- "Found multiple connections of the same type with the same name, caching only the first: %s"
153
- ),
154
- name ,
150
+ with self .lock :
151
+ results = self .client .asset .find_connections_by_name (
152
+ name = name .name ,
153
+ connector_type = name .type ,
154
+ attributes = self .SEARCH_ATTRIBUTES ,
155
155
)
156
- self .cache (results [0 ])
156
+ if not results :
157
+ return
158
+ if len (results ) > 1 :
159
+ LOGGER .warning (
160
+ (
161
+ "Found multiple connections of the same type with the same name, caching only the first: %s"
162
+ ),
163
+ name ,
164
+ )
165
+ self .cache (results [0 ])
157
166
158
167
def get_name (self , asset : Asset ):
159
168
if not isinstance (asset , Connection ):
0 commit comments