
Commit 12f8675

add new local class to add data like stream_slice
1 parent: 3f42896

6 files changed: +264 -6 lines changed
airbyte_cdk/sources/declarative/request_local/__init__.py (new file)

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
from .request_local import RequestLocal

__all__ = ["RequestLocal"]
airbyte_cdk/sources/declarative/request_local/request_local.py (new file)

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
from threading import local, Lock

class RequestLocal(local):
    _instance = None
    _lock = Lock()  # Thread-safe singleton creation

    def __new__(cls, *args, **kwargs):
        # Use double-checked locking for thread safety
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super(RequestLocal, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ will be called every time the class is instantiated,
        # but the object itself is only created once by __new__.
        # Use a flag to prevent re-initialization
        if not hasattr(self, '_initialized'):
            self._stream_slice = None  # Initialize _stream_slice
            self._initialized = True

    @property
    def stream_slice(self):
        return self._stream_slice

    @stream_slice.setter
    def stream_slice(self, stream_slice):
        self._stream_slice = stream_slice

    @classmethod
    def get_instance(cls):
        """
        Get the singleton instance of RequestLocal.
        This is the recommended way to get the instance.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance
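
Taken together, this class is a process-wide singleton whose attribute storage is nevertheless per-thread: __new__ guarantees a single object, while the threading.local base class gives each thread its own _stream_slice (the _initialized flag is per-thread too, which is what re-runs initialization in each new thread). A minimal sketch, not part of the commit, of the intended behavior:

import threading

from airbyte_cdk.sources.declarative.request_local import RequestLocal


def worker() -> None:
    store = RequestLocal()             # the same object as in the main thread
    assert store.stream_slice is None  # but its data starts out empty here
    store.stream_slice = {"slice": "worker"}


main_store = RequestLocal()
main_store.stream_slice = {"slice": "main"}

t = threading.Thread(target=worker)
t.start()
t.join()

# The worker's write stayed in its own thread; the main thread's view is intact.
assert main_store.stream_slice == {"slice": "main"}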

airbyte_cdk/sources/declarative/requesters/http_requester.py

Lines changed: 4 additions & 0 deletions

@@ -31,6 +31,7 @@
     combine_mappings,
     get_interpolation_context,
 )
+from airbyte_cdk.sources.declarative.request_local import RequestLocal


 @dataclass
@@ -449,6 +450,9 @@ def send_request(
         request_body_json: Optional[Mapping[str, Any]] = None,
         log_formatter: Optional[Callable[[requests.Response], Any]] = None,
     ) -> Optional[requests.Response]:
+        request_local = RequestLocal()
+        request_local.stream_slice = stream_slice
+
         request, response = self._http_client.send_request(
             http_method=self.get_method().value,
             url=self._get_url(
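
The effect of this hunk: every send_request call now publishes the slice it was given into the thread-local singleton before the HTTP call is dispatched, so code deeper in the client stack, which never receives stream_slice as an argument, can still read it on the same thread. A hedged sketch of the two sides of the hand-off, with an illustrative slice value:

from airbyte_cdk.sources.declarative.request_local import RequestLocal

# Producer side: what send_request now does before invoking the HTTP client.
RequestLocal().stream_slice = {"start": "2012"}

# Consumer side: what the rate-limiting handlers in the next file do, without
# the slice ever being threaded through their signatures.
current_slice = RequestLocal().stream_slice
assert current_slice == {"start": "2012"}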

airbyte_cdk/sources/streams/http/rate_limiting.py

Lines changed: 11 additions & 4 deletions

@@ -17,6 +17,7 @@
     RateLimitBackoffException,
     UserDefinedBackoffException,
 )
+from airbyte_cdk.sources.declarative.request_local import RequestLocal

 TRANSIENT_EXCEPTIONS = (
     DefaultBackoffException,
@@ -120,8 +121,9 @@ def sleep_on_ratelimit(details: Mapping[str, Any]) -> None:
             logging_message = (
                 f"Retrying. Sleeping for {retry_after} seconds at {ab_datetime_now()} UTC"
             )
-            if stream_slice:
-                logging_message += f" for slice: {stream_slice}"
+            request_local = RequestLocal()
+            if request_local.stream_slice:
+                logging_message += f" for slice: {request_local.stream_slice}"
             logger.info(logging_message)
             time.sleep(retry_after + 1)  # extra second to cover any fractions of second

@@ -156,9 +158,14 @@ def log_retry_attempt(details: Mapping[str, Any]) -> None:
             logger.info(
                 f"Status code: {exc.response.status_code!r}, Response Content: {exc.response.content!r}"
             )
-        logger.info(
-            f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
+        logger_slice_info = ""
+        request_local = RequestLocal()
+        if request_local.stream_slice:
+            logger_slice_info = f" for slice: {request_local.stream_slice}"
+        logger_info_message = (
+            f"Caught retryable error '{str(exc)}' after {details['tries']} tries. Waiting {details['wait']} seconds then retrying{logger_slice_info}..."
         )
+        logger.info(logger_info_message)

     return backoff.on_exception(  # type: ignore # Decorator function returns a function with a different signature than the input function, so mypy can't infer the type of the returned function
         backoff.expo,
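
Both handlers run as backoff callbacks, which only receive a details mapping: sleep_on_ratelimit previously depended on a stream_slice from its enclosing scope, and log_retry_attempt had no slice context at all. Reading the slice from RequestLocal means it is available wherever on the thread the retry fires. A minimal standalone sketch of the pattern, assuming only the backoff package and RequestLocal itself (log_attempt and fetch are illustrative names, not CDK code):

import backoff
import requests

from airbyte_cdk.sources.declarative.request_local import RequestLocal


def log_attempt(details) -> None:
    # The backoff `details` mapping cannot carry the slice, so recover it
    # out-of-band from the thread-local singleton.
    slice_info = ""
    current_slice = RequestLocal().stream_slice
    if current_slice:
        slice_info = f" for slice: {current_slice}"
    print(f"Retry {details['tries']}{slice_info}, waiting {details['wait']:.1f} seconds")


@backoff.on_exception(backoff.expo, requests.RequestException, max_tries=3, on_backoff=log_attempt)
def fetch(url: str) -> requests.Response:
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    return response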
Lines changed: 167 additions & 0 deletions

@@ -0,0 +1,167 @@
import threading
import time
from concurrent.futures import ThreadPoolExecutor

from airbyte_cdk.sources.declarative.request_local.request_local import RequestLocal

STREAM_SLICE_KEY = "stream_slice"
INSTANCE_ID_KEY = "instance_id"


def test_basic_singleton():
    """Test basic singleton behavior"""
    # Multiple instantiations return the same instance
    instance1 = RequestLocal()
    instance2 = RequestLocal()
    instance3 = RequestLocal()

    assert instance1 is instance2, "All instances should be the same singleton instance"
    assert instance1 is instance3, "All instances should be the same singleton instance"
    assert instance2 is instance3, "All instances should be the same singleton instance"

    # The get_instance class method returns the same singleton
    instance4 = RequestLocal.get_instance()

    # The stream_slice property is shared through the single instance
    instance1.stream_slice = {"test": "data"}
    assert instance1.stream_slice is instance4.stream_slice
    assert instance2.stream_slice is instance4.stream_slice


def create_instance_in_thread(thread_id, results):
    """Create an instance in a separate thread and record its identity"""
    instance = RequestLocal()

    results[thread_id] = {
        INSTANCE_ID_KEY: id(instance),
        "thread_id": threading.get_ident(),
    }
    time.sleep(0.1)  # Small delay to ensure threads overlap


def test_thread_safety():
    """Ensure that RequestLocal is thread-safe and behaves as a singleton across threads"""
    results = {}
    threads = []
    total_threads = 5

    # Create multiple threads that instantiate RequestLocal
    for i in range(total_threads):
        thread = threading.Thread(target=create_instance_in_thread, args=(i, results))
        threads.append(thread)
        thread.start()

    # Wait for all threads to complete
    for thread in threads:
        thread.join()

    # Analyze results
    instance_ids = [result[INSTANCE_ID_KEY] for result in results.values()]
    unique_ids = set(instance_ids)

    assert len(results) == total_threads, "All threads should have created an instance"
    assert len(unique_ids) == 1, "All threads should see the same singleton instance"


def test_threading_local_behavior():
    """Test how threading.local affects the singleton"""

    def thread_func(thread_name, shared_results, time_sleep):
        instance = RequestLocal()
        assert instance.stream_slice is None, "Initial stream_slice should be empty"
        instance.stream_slice = {f"data_from_{thread_name}": True}

        shared_results[thread_name] = {
            INSTANCE_ID_KEY: id(instance),
            STREAM_SLICE_KEY: instance.stream_slice.copy(),
            "thread_id": threading.get_ident(),
        }

        # Sleep, then re-read: data written by other threads must not be visible
        # here, because RequestLocal stores attributes per thread
        time.sleep(time_sleep)
        shared_results[f"{thread_name}_after_sleep"] = {
            INSTANCE_ID_KEY: id(instance),
            STREAM_SLICE_KEY: instance.stream_slice.copy(),
            "end_time": time.time(),
        }

    results = {}
    threads = {}
    threads_amount = 3
    time_sleep = 0.9
    thread_names = []
    for i in range(threads_amount):
        thread_name = f"thread_{i}"
        thread_names.append(thread_name)
        thread = threading.Thread(target=thread_func, args=(thread_name, results, time_sleep))
        time_sleep /= 3  # Decrease the sleep time for each thread so their lifetimes overlap
        threads[thread_name] = thread
        thread.start()

    for thread in threads.values():
        thread.join()

    end_times = [results[f"{name}_after_sleep"]["end_time"] for name in thread_names]
    last_end_time = end_times.pop()
    while end_times:
        current_end_time = end_times.pop()
        # Each later-created thread sleeps less and must finish before the ones
        # created earlier, which proves the first (longest-sleeping) thread was
        # still alive while the other threads wrote their slices
        assert last_end_time < current_end_time, "Later-created threads should finish first"
        last_end_time = current_end_time

    assert len(thread_names) > 1
    assert len(set(thread_names)) == len(thread_names), "Thread names should be unique"
    for current_name in thread_names:
        after_sleep_name = f"{current_name}_after_sleep"
        assert results[current_name][STREAM_SLICE_KEY] == results[after_sleep_name][STREAM_SLICE_KEY], \
            f"Stream slice should remain consistent across thread {current_name} before and after sleep"
        assert results[current_name][INSTANCE_ID_KEY] == results[after_sleep_name][INSTANCE_ID_KEY], \
            f"Instance ID should remain consistent across thread {current_name} before and after sleep"

        # Stream slices differ across threads even though the instance ID is the same
        for other_name in [name for name in thread_names if name != current_name]:
            assert results[current_name][STREAM_SLICE_KEY] != results[other_name][STREAM_SLICE_KEY], \
                f"Stream slices from different threads should not be the same: {current_name} vs {other_name}"
            assert results[current_name][INSTANCE_ID_KEY] == results[other_name][INSTANCE_ID_KEY]


def test_concurrent_access():
    """Test concurrent access using ThreadPoolExecutor"""

    def worker(worker_id):
        instance = RequestLocal()
        return {
            "worker_id": worker_id,
            INSTANCE_ID_KEY: id(instance),
            "thread_id": threading.get_ident(),
        }

    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(worker, i) for i in range(20)]
        results = [future.result() for future in futures]

    # Every worker must observe the same singleton instance
    instance_ids = [result[INSTANCE_ID_KEY] for result in results]
    unique_ids = set(instance_ids)

    assert len(results) == 20, "All workers should have produced a result"
    assert len(unique_ids) == 1, "Singleton behavior should be maintained across pool threads"
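
One caveat worth noting when extending these tests: RequestLocal keeps its process-wide _instance for the life of the interpreter, so state set in one test (for example, the stream_slice written in test_basic_singleton) survives into later tests in the same run. A hypothetical autouse fixture, not part of the commit, that would isolate them:

import pytest

from airbyte_cdk.sources.declarative.request_local.request_local import RequestLocal


@pytest.fixture(autouse=True)
def reset_request_local():
    yield
    # Drop the cached singleton so each test starts from a clean instance.
    RequestLocal._instance = None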

unit_tests/sources/declarative/requesters/test_http_requester.py

Lines changed: 40 additions & 2 deletions

@@ -2,7 +2,8 @@
 # Copyright (c) 2023 Airbyte, Inc., all rights reserved.
 #

-from datetime import timedelta
+import logging
+
 from typing import Any, Mapping, Optional
 from unittest import mock
 from unittest.mock import MagicMock
@@ -35,7 +36,7 @@
     MovingWindowCallRatePolicy,
     Rate,
 )
-from airbyte_cdk.sources.streams.http.error_handlers.response_models import ResponseAction
+from airbyte_cdk.sources.streams.http.exceptions import RateLimitBackoffException
 from airbyte_cdk.sources.streams.http.exceptions import (
     RequestBodyException,
     UserDefinedBackoffException,
@@ -216,6 +217,30 @@ def create_requester(
     requester._http_client._session.send.return_value = req
     return requester

+def create_requester_rate_limited(
+    url_base: Optional[str] = None,
+    parameters: Optional[Mapping[str, Any]] = {},
+    config: Optional[Config] = None,
+    path: Optional[str] = None,
+    authenticator: Optional[DeclarativeAuthenticator] = None,
+    error_handler: Optional[ErrorHandler] = None,
+) -> HttpRequester:
+    requester = HttpRequester(
+        name="name",
+        url_base=url_base or "https://example.com",
+        path=path or "deals",
+        http_method=HttpMethod.GET,
+        request_options_provider=None,
+        authenticator=authenticator,
+        error_handler=error_handler,
+        config=config or {},
+        parameters=parameters or {},
+    )
+    requester._http_client._session.send = MagicMock()
+    req = requests.Response()
+    req.status_code = 429  # Simulating a rate limit response
+    requester._http_client._session.send.return_value = req
+    return requester

 def test_basic_send_request():
     options_provider = MagicMock()
@@ -229,6 +254,19 @@ def test_basic_send_request():
     assert sent_request.headers["my_header"] == "my_value"
     assert sent_request.body is None

+@pytest.mark.usefixtures("mock_sleep")
+def test_send_request_rate_limited(caplog):
+    options_provider = MagicMock()
+    options_provider.get_request_headers.return_value = {"my_header": "my_value"}
+    requester = create_requester_rate_limited()
+    requester._request_options_provider = options_provider
+    with caplog.at_level(logging.INFO, logger="airbyte"):
+        with pytest.raises(RateLimitBackoffException):
+            requester.send_request(stream_slice={"start": "2012"})
+
+    logged_messages = [record.message for record in caplog.records]
+    assert "Caught retryable error 'Too many requests.' after 1 tries. Waiting 1 seconds then retrying for slice: {'start': '2012'}..." in logged_messages

 @pytest.mark.parametrize(
     "provider_data, provider_json, param_data, param_json, authenticator_data, authenticator_json, expected_exception, expected_body",
