|
| 1 | +import sys |
| 2 | +import os |
| 3 | +import threading |
| 4 | +import time |
| 5 | +from concurrent.futures import ThreadPoolExecutor |
| 6 | + |
| 7 | + |
| 8 | +from airbyte_cdk.sources.declarative.request_local.request_local import RequestLocal |
| 9 | + |
# Dict keys used to index the per-thread result records built by the tests below.
STREAM_SLICE_KEY = "stream_slice"
INSTANCE_ID_KEY = "instance_id"
| 12 | + |
def test_basic_singleton():
    """RequestLocal construction always yields one shared singleton instance.

    Verifies that repeated direct construction and the ``get_instance``
    class method return the identical object, and that state written
    through one reference is visible through every other reference.
    """
    # Multiple instantiations return the same instance.
    instance1 = RequestLocal()
    instance2 = RequestLocal()
    instance3 = RequestLocal()

    assert instance1 is instance2, "All instances should be the same singleton instance"
    assert instance1 is instance3, "All instances should be the same singleton instance"
    assert instance2 is instance3, "All instances should be the same singleton instance"

    # get_instance() must hand back the very same singleton.
    instance4 = RequestLocal.get_instance()

    # State set via one reference is observable through all the others.
    instance1.stream_slice = {"test": "data"}
    assert instance1.stream_slice is instance4.stream_slice
    assert instance2.stream_slice is instance4.stream_slice
| 35 | + |
| 36 | + |
def create_instance_in_thread(thread_id, results):
    """Instantiate RequestLocal inside a worker thread and record its identity.

    Stores the instance's ``id()`` and the worker's thread ident in
    ``results`` under ``thread_id`` so the caller can check that every
    thread observed the same singleton object.
    """
    instance = RequestLocal()

    results[thread_id] = {
        INSTANCE_ID_KEY: id(instance),  # use the shared key constant read back by test_thread_safety
        "thread_id": threading.get_ident(),
    }
    time.sleep(0.1)  # Small delay to ensure threads overlap
| 46 | + |
| 47 | + |
def test_thread_safety():
    """Ensure that RequestLocal is thread-safe and behaves as a singleton across threads."""
    results = {}
    threads = []
    total_threads = 5

    # Create multiple threads that instantiate RequestLocal concurrently.
    for i in range(total_threads):
        thread = threading.Thread(target=create_instance_in_thread, args=(i, results))
        threads.append(thread)
        thread.start()

    # Wait for all threads to complete.
    for thread in threads:
        thread.join()

    # Every thread must have recorded a result, and every recorded instance
    # id must be identical (one singleton shared across all threads).
    instance_ids = [result[INSTANCE_ID_KEY] for result in results.values()]
    unique_ids = set(instance_ids)

    assert len(results) == total_threads, "All threads should have created an instance"
    assert len(unique_ids) == 1, "All threads should see the same singleton instance"
| 71 | + |
| 72 | + |
| 73 | + |
def test_threading_local_behavior():
    """Verify the singleton keeps per-thread (threading.local-style) stream_slice state.

    Every thread must see the same instance (same ``id()``) yet its own
    ``stream_slice``, unaffected by writes from concurrently running threads.
    """

    def thread_func(thread_name, shared_results, sleep_seconds):
        instance = RequestLocal()
        assert instance.stream_slice is None, "Initial stream_slice should be empty"
        instance.stream_slice = {f"data_from_{thread_name}": True}

        shared_results[thread_name] = {
            INSTANCE_ID_KEY: id(instance),
            STREAM_SLICE_KEY: instance.stream_slice.copy(),
            "thread_id": threading.get_ident(),
        }

        # Sleep so the other threads run and write their own slices; this
        # thread's slice must be unchanged afterwards (no cross-thread leak,
        # since stream_slice is expected to be thread-local on the singleton).
        time.sleep(sleep_seconds)
        shared_results[f"{thread_name}_after_sleep"] = {
            INSTANCE_ID_KEY: id(instance),
            STREAM_SLICE_KEY: instance.stream_slice.copy(),
            "end_time": time.time(),
        }

    results = {}
    threads = {}
    threads_amount = 3
    time_sleep = 0.9
    thread_names = []
    for i in range(threads_amount):
        thread_name = f"thread_{i}"
        thread_names.append(thread_name)
        thread = threading.Thread(target=thread_func, args=(thread_name, results, time_sleep))
        time_sleep /= 3  # Decrease sleep time for each thread to ensure they overlap
        threads[thread_name] = thread
        thread.start()

    for thread in threads.values():
        thread.join()

    # Threads were created with strictly decreasing sleeps, so later-created
    # threads must finish first. Popping from the end walks newest -> oldest
    # and checks each earlier-created thread ended after the later one, i.e.
    # the longest-sleeping first thread overlapped all the others.
    end_times = [results[name + "_after_sleep"]["end_time"] for name in thread_names]
    last_end_time = end_times.pop()
    while end_times:
        current_end_time = end_times.pop()
        assert last_end_time < current_end_time, "End times should be in increasing order"
        last_end_time = current_end_time

    assert len(thread_names) > 1
    assert len(set(thread_names)) == len(thread_names), "Thread names should be unique"
    for curren_thread_name in thread_names:
        current_thread_name_after_sleep = f"{curren_thread_name}_after_sleep"
        # Each thread's slice and instance id are stable across its own sleep.
        assert results[curren_thread_name][STREAM_SLICE_KEY] == results[current_thread_name_after_sleep][STREAM_SLICE_KEY], \
            f"Stream slice should remain consistent across thread {curren_thread_name} before and after sleep"
        assert results[curren_thread_name][INSTANCE_ID_KEY] == results[current_thread_name_after_sleep][INSTANCE_ID_KEY], \
            f"Instance ID should remain consistent across thread {curren_thread_name} before and after sleep"

        # Same singleton instance everywhere, but each thread sees its own slice.
        for other_tread_name in [name for name in thread_names if name != curren_thread_name]:
            assert results[curren_thread_name][STREAM_SLICE_KEY] != results[other_tread_name][STREAM_SLICE_KEY], \
                f"Stream slices from different threads should not be the same: {curren_thread_name} vs {other_tread_name}"
            assert results[curren_thread_name][INSTANCE_ID_KEY] == results[other_tread_name][INSTANCE_ID_KEY]
| 137 | + |
# FIXME: Uncomment this test, replace the prints with asserts, to verify concurrent access.
| 139 | +# def test_concurrent_access(): |
| 140 | +# """Test concurrent access using ThreadPoolExecutor""" |
| 141 | +# print("\n=== Testing Concurrent Access ===") |
| 142 | +# |
| 143 | +# def worker(worker_id): |
| 144 | +# instance = RequestLocal() |
| 145 | +# return { |
| 146 | +# 'worker_id': worker_id, |
| 147 | +# 'instance_id': id(instance), |
| 148 | +# 'thread_id': threading.get_ident() |
| 149 | +# } |
| 150 | +# |
| 151 | +# with ThreadPoolExecutor(max_workers=10) as executor: |
| 152 | +# futures = [executor.submit(worker, i) for i in range(20)] |
| 153 | +# results = [future.result() for future in futures] |
| 154 | +# |
| 155 | +# # Analyze results |
| 156 | +# instance_ids = [result[INSTANCE_ID_KEY] for result in results] |
| 157 | +# unique_ids = set(instance_ids) |
| 158 | +# |
| 159 | +# print(f"Total workers: {len(results)}") |
| 160 | +# print(f"Unique instance IDs: {len(unique_ids)}") |
| 161 | +# print(f"Singleton behavior maintained: {len(unique_ids) == 1}") |
| 162 | +# |
| 163 | +# # Show first few results |
| 164 | +# print("First 5 results:") |
| 165 | +# for result in results[:5]: |
| 166 | +# print(f" Worker {result['worker_id']}: ID={result[INSTANCE_ID_KEY]}") |
| 167 | + |
0 commit comments