|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 | 2 | # Copyright 2022 Atlan Pte. Ltd.
|
| 3 | +import threading |
| 4 | +import time |
3 | 5 | from importlib.resources import read_text
|
4 | 6 | from json import load, loads
|
5 | 7 | from pathlib import Path
|
@@ -2072,6 +2074,62 @@ def test_atlan_call_api_server_error_messages_with_causes(
|
2072 | 2074 | client.asset.save(glossary)
|
2073 | 2075 |
|
2074 | 2076 |
|
| 2077 | +@pytest.mark.parametrize("thread_count", [3]) # Run with three threads |
| 2078 | +def test_atlan_client_tls(thread_count): |
| 2079 | + """Tests that AtlanClient instances remain isolated across multiple threads.""" |
| 2080 | + validation_results = {} |
| 2081 | + results_lock = threading.Lock() |
| 2082 | + |
| 2083 | + def _test_atlan_client_isolation(name, api_key1, api_key2, api_key3): |
| 2084 | + """Creates three AtlanClient instances within the same thread and verifies isolation.""" |
| 2085 | + # Instantiate three separate AtlanClient instances |
| 2086 | + client1 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key1) |
| 2087 | + time.sleep(0.2) |
| 2088 | + observed1 = client1.get_default_client().api_key # Should match api_key1 |
| 2089 | + |
| 2090 | + client2 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key2) |
| 2091 | + time.sleep(0.2) |
| 2092 | + observed2 = client2.get_default_client().api_key # Should match api_key2 |
| 2093 | + |
| 2094 | + client3 = AtlanClient(base_url="https://test.atlan.com", api_key=api_key3) |
| 2095 | + time.sleep(0.2) |
| 2096 | + observed3 = client3.get_default_client().api_key # Should match api_key3 |
| 2097 | + |
| 2098 | + # Store results in a thread-safe way |
| 2099 | + with results_lock: |
| 2100 | + validation_results[name] = (observed1, observed2, observed3) |
| 2101 | + |
| 2102 | + # Define unique API keys for each thread |
| 2103 | + api_keys = [ |
| 2104 | + ("API_KEY_1A", "API_KEY_1B", "API_KEY_1C"), |
| 2105 | + ("API_KEY_2A", "API_KEY_2B", "API_KEY_2C"), |
| 2106 | + ("API_KEY_3A", "API_KEY_3B", "API_KEY_3C"), |
| 2107 | + ] |
| 2108 | + |
| 2109 | + threads = [] |
| 2110 | + for i in range(thread_count): |
| 2111 | + thread = threading.Thread( |
| 2112 | + target=_test_atlan_client_isolation, |
| 2113 | + args=(f"thread{i + 1}", *api_keys[i]), |
| 2114 | + ) |
| 2115 | + threads.append(thread) |
| 2116 | + thread.start() |
| 2117 | + |
| 2118 | + # Wait for all threads to finish |
| 2119 | + for thread in threads: |
| 2120 | + thread.join() |
| 2121 | + |
| 2122 | + # Validate that each thread's clients retained their assigned API keys |
| 2123 | + for i in range(thread_count): |
| 2124 | + thread_name = f"thread{i + 1}" |
| 2125 | + expected_keys = api_keys[i] |
| 2126 | + |
| 2127 | + assert validation_results[thread_name] == expected_keys, ( |
| 2128 | + f"Clients were overwritten across threads! " |
| 2129 | + f"{thread_name} saw {validation_results[thread_name]} instead of {expected_keys}" |
| 2130 | + ) |
| 2131 | + |
| 2132 | + |
2075 | 2133 | class TestBatch:
|
2076 | 2134 | def test_init(self, mock_atlan_client):
|
2077 | 2135 | sut = Batch(client=mock_atlan_client, max_size=10)
|
|
0 commit comments