Skip to content

Commit 84326a4

Browse files
committed
[test/refactor] Refactored connection and source tag unit tests
1 parent bbc60b9 commit 84326a4

File tree

6 files changed

+595
-569
lines changed

6 files changed

+595
-569
lines changed

pyatlan/cache/source_tag_cache.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import logging
66
import threading
7-
from threading import local
87
from typing import TYPE_CHECKING, Union
98

109
from pyatlan.cache.abstract_asset_cache import AbstractAssetCache, AbstractAssetName
@@ -18,7 +17,6 @@
1817
from pyatlan.client.atlan import AtlanClient
1918

2019
lock = threading.Lock()
21-
source_tag_cache_tls = local() # Thread-local storage (TLS)
2220
LOGGER = logging.getLogger(__name__)
2321

2422

pyatlan/client/atlan.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,6 @@ def __init__(self, **data):
224224
session.mount(HTTP_PREFIX, adapter)
225225
AtlanClient.set_current_client(self)
226226

227-
@property
228-
def cache_key(self) -> int:
229-
return f"{self.base_url}/{self.api_key}".__hash__()
230-
231227
@property
232228
def admin(self) -> AdminClient:
233229
if self._admin_client is None:
@@ -357,6 +353,7 @@ def contracts(self) -> ContractClient:
357353
@property
358354
def atlan_tag_cache(self) -> AtlanTagCache:
359355
if self._atlan_tag_cache is None:
356+
print("yes... calling cache")
360357
AtlanClient.set_current_client(self)
361358
self._atlan_tag_cache = AtlanTagCache(client=self)
362359
return self._atlan_tag_cache
@@ -698,7 +695,7 @@ def _handle_401_token_refresh(
698695
699696
returns: HTTP response received after retrying the request with the refreshed token
700697
"""
701-
new_token = self.get_current_client().impersonate.user(user_id=self._user_id)
698+
new_token = self.impersonate.user(user_id=self._user_id)
702699
self.api_key = new_token
703700
self._has_retried_for_401 = True
704701
params["headers"]["authorization"] = f"Bearer {self.api_key}"

pyatlan/model/custom_metadata.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -119,17 +119,17 @@ def __init__(self, business_attributes: Optional[Dict[str, Any]]):
119119
if self._business_attributes is None:
120120
return
121121
self._metadata = {}
122-
self._client = AtlanClient.get_current_client()
122+
client = AtlanClient.get_current_client()
123123
for cm_id, cm_attributes in self._business_attributes.items():
124124
try:
125-
cm_name = self._client.custom_metadata_cache.get_name_for_id(cm_id)
125+
cm_name = client.custom_metadata_cache.get_name_for_id(cm_id)
126126
attribs = CustomMetadataDict(name=cm_name)
127127
for attr_id, properties in cm_attributes.items():
128-
attr_name = self._client.custom_metadata_cache.get_attr_name_for_id(
128+
attr_name = client.custom_metadata_cache.get_attr_name_for_id(
129129
cm_id, attr_id
130130
)
131131
# Only set active custom metadata attributes
132-
if not self._client.custom_metadata_cache.is_attr_archived(
132+
if not client.custom_metadata_cache.is_attr_archived(
133133
attr_id=attr_id
134134
):
135135
attribs[attr_name] = properties
@@ -164,8 +164,10 @@ def modified(self) -> bool:
164164
@property
165165
def business_attributes(self) -> Optional[Dict[str, Any]]:
166166
if self.modified and self._metadata is not None:
167+
from pyatlan.client.atlan import AtlanClient
168+
167169
return {
168-
self._client.custom_metadata_cache.get_id_for_name(
170+
AtlanClient.get_current_client().custom_metadata_cache.get_id_for_name(
169171
key
170172
): value.business_attributes
171173
for key, value in self._metadata.items()

tests/unit/test_client.py

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# Copyright 2022 Atlan Pte. Ltd.
3+
import threading
4+
import time
35
from importlib.resources import read_text
46
from json import load, loads
57
from pathlib import Path
@@ -2092,60 +2094,60 @@ def test_atlan_call_api_server_error_messages_with_causes(
20922094
client.asset.save(glossary)
20932095

20942096

2095-
# @pytest.mark.parametrize("thread_count", [3]) # Run with three threads
2096-
# def test_atlan_client_tls(thread_count):
2097-
# """Tests that AtlanClient instances remain isolated across multiple threads."""
2098-
# validation_results = {}
2099-
# results_lock = threading.Lock()
2100-
2101-
# def _test_atlan_client_isolation(name, api_key1, api_key2, api_key3):
2102-
# """Creates three AtlanClient instances within the same thread and verifies isolation."""
2103-
# # Instantiate three separate AtlanClient instances
2104-
# client1 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key1)
2105-
# time.sleep(0.2)
2106-
# observed1 = client1.get_current_client().api_key # Should match api_key1
2107-
2108-
# client2 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key2)
2109-
# time.sleep(0.2)
2110-
# observed2 = client2.get_current_client().api_key # Should match api_key2
2111-
2112-
# client3 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key3)
2113-
# time.sleep(0.2)
2114-
# observed3 = client3.get_current_client().api_key # Should match api_key3
2115-
2116-
# # Store results in a thread-safe way
2117-
# with results_lock:
2118-
# validation_results[name] = (observed1, observed2, observed3)
2119-
2120-
# # Define unique API keys for each thread
2121-
# api_keys = [
2122-
# ("API_KEY_1A", "API_KEY_1B", "API_KEY_1C"),
2123-
# ("API_KEY_2A", "API_KEY_2B", "API_KEY_2C"),
2124-
# ("API_KEY_3A", "API_KEY_3B", "API_KEY_3C"),
2125-
# ]
2126-
2127-
# threads = []
2128-
# for i in range(thread_count):
2129-
# thread = threading.Thread(
2130-
# target=_test_atlan_client_isolation,
2131-
# args=(f"thread{i + 1}", *api_keys[i]),
2132-
# )
2133-
# threads.append(thread)
2134-
# thread.start()
2135-
2136-
# # Wait for all threads to finish
2137-
# for thread in threads:
2138-
# thread.join()
2139-
2140-
# # Validate that each thread's clients retained their assigned API keys
2141-
# for i in range(thread_count):
2142-
# thread_name = f"thread{i + 1}"
2143-
# expected_keys = api_keys[i]
2144-
2145-
# assert validation_results[thread_name] == expected_keys, (
2146-
# f"Clients were overwritten across threads! "
2147-
# f"{thread_name} saw {validation_results[thread_name]} instead of {expected_keys}"
2148-
# )
2097+
@pytest.mark.parametrize("thread_count", [3]) # Run with three threads
2098+
def test_atlan_client_tls(thread_count):
2099+
"""Tests that AtlanClient instances remain isolated across multiple threads."""
2100+
validation_results = {}
2101+
results_lock = threading.Lock()
2102+
2103+
def _test_atlan_client_isolation(name, api_key1, api_key2, api_key3):
2104+
"""Creates three AtlanClient instances within the same thread and verifies isolation."""
2105+
# Instantiate three separate AtlanClient instances
2106+
client1 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key1)
2107+
time.sleep(0.2)
2108+
observed1 = client1.get_current_client().api_key # Should match api_key1
2109+
2110+
client2 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key2)
2111+
time.sleep(0.2)
2112+
observed2 = client2.get_current_client().api_key # Should match api_key2
2113+
2114+
client3 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key3)
2115+
time.sleep(0.2)
2116+
observed3 = client3.get_current_client().api_key # Should match api_key3
2117+
2118+
# Store results in a thread-safe way
2119+
with results_lock:
2120+
validation_results[name] = (observed1, observed2, observed3)
2121+
2122+
# Define unique API keys for each thread
2123+
api_keys = [
2124+
("API_KEY_1A", "API_KEY_1B", "API_KEY_1C"),
2125+
("API_KEY_2A", "API_KEY_2B", "API_KEY_2C"),
2126+
("API_KEY_3A", "API_KEY_3B", "API_KEY_3C"),
2127+
]
2128+
2129+
threads = []
2130+
for i in range(thread_count):
2131+
thread = threading.Thread(
2132+
target=_test_atlan_client_isolation,
2133+
args=(f"thread{i + 1}", *api_keys[i]),
2134+
)
2135+
threads.append(thread)
2136+
thread.start()
2137+
2138+
# Wait for all threads to finish
2139+
for thread in threads:
2140+
thread.join()
2141+
2142+
# Validate that each thread's clients retained their assigned API keys
2143+
for i in range(thread_count):
2144+
thread_name = f"thread{i + 1}"
2145+
expected_keys = api_keys[i]
2146+
2147+
assert validation_results[thread_name] == expected_keys, (
2148+
f"Clients were overwritten across threads! "
2149+
f"{thread_name} saw {validation_results[thread_name]} instead of {expected_keys}"
2150+
)
21492151

21502152

21512153
class TestBatch:

0 commit comments

Comments
 (0)