Skip to content

Commit 25fdc32

Browse files
committed
[test] Added test_atlan_client_isolation to verify thread-local storage for the default client
1 parent 9df8cc5 commit 25fdc32

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

tests/unit/test_client.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# Copyright 2022 Atlan Pte. Ltd.
3+
import threading
4+
import time
35
from importlib.resources import read_text
46
from json import load, loads
57
from pathlib import Path
@@ -2072,6 +2074,62 @@ def test_atlan_call_api_server_error_messages_with_causes(
20722074
client.asset.save(glossary)
20732075

20742076

2077+
@pytest.mark.parametrize("thread_count", [3]) # Run with three threads
2078+
def test_atlan_client_tls(thread_count):
2079+
"""Tests that AtlanClient instances remain isolated across multiple threads."""
2080+
validation_results = {}
2081+
results_lock = threading.Lock()
2082+
2083+
def _test_atlan_client_isolation(name, api_key1, api_key2, api_key3):
2084+
"""Creates three AtlanClient instances within the same thread and verifies isolation."""
2085+
# Instantiate three separate AtlanClient instances
2086+
client1 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key1)
2087+
time.sleep(0.2)
2088+
observed1 = client1.get_default_client().api_key # Should match api_key1
2089+
2090+
client2 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key2)
2091+
time.sleep(0.2)
2092+
observed2 = client2.get_default_client().api_key # Should match api_key2
2093+
2094+
client3 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key3)
2095+
time.sleep(0.2)
2096+
observed3 = client3.get_default_client().api_key # Should match api_key3
2097+
2098+
# Store results in a thread-safe way
2099+
with results_lock:
2100+
validation_results[name] = (observed1, observed2, observed3)
2101+
2102+
# Define unique API keys for each thread
2103+
api_keys = [
2104+
("API_KEY_1A", "API_KEY_1B", "API_KEY_1C"),
2105+
("API_KEY_2A", "API_KEY_2B", "API_KEY_2C"),
2106+
("API_KEY_3A", "API_KEY_3B", "API_KEY_3C"),
2107+
]
2108+
2109+
threads = []
2110+
for i in range(thread_count):
2111+
thread = threading.Thread(
2112+
target=_test_atlan_client_isolation,
2113+
args=(f"thread{i + 1}", *api_keys[i]),
2114+
)
2115+
threads.append(thread)
2116+
thread.start()
2117+
2118+
# Wait for all threads to finish
2119+
for thread in threads:
2120+
thread.join()
2121+
2122+
# Validate that each thread's clients retained their assigned API keys
2123+
for i in range(thread_count):
2124+
thread_name = f"thread{i + 1}"
2125+
expected_keys = api_keys[i]
2126+
2127+
assert validation_results[thread_name] == expected_keys, (
2128+
f"Clients were overwritten across threads! "
2129+
f"{thread_name} saw {validation_results[thread_name]} instead of {expected_keys}"
2130+
)
2131+
2132+
20752133
class TestBatch:
20762134
def test_init(self, mock_atlan_client):
20772135
sut = Batch(client=mock_atlan_client, max_size=10)

0 commit comments

Comments
 (0)