Skip to content

Commit 3c59dbf

Browse files
committed
Merge branch 'christo/sdm-cli-dockerfile' of https://github.com/airbytehq/airbyte-python-cdk into christo/sdm-cli-dockerfile
2 parents 79f28c0 + 99ed234 commit 3c59dbf

File tree

9 files changed

+273
-92
lines changed

9 files changed

+273
-92
lines changed

airbyte_cdk/cli/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ This guide explains how to install and use the Source Declarative Manifest (SDM)
1010
pipx install airbyte-cdk
1111
```
1212

13-
If you encounter an error related to a missing `distutils` module, very that you are running Python version `<=3.11` and try running:
13+
If you encounter an error related to a missing `distutils` module, verify that you are running Python version `<=3.11` and try running:
1414

1515
```bash
1616
python -m pipx install airbyte-cdk

airbyte_cdk/sources/declarative/auth/oauth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ def get_token_expiry_date(self) -> AirbyteDateTime:
239239
def _has_access_token_been_initialized(self) -> bool:
240240
return self._access_token is not None
241241

242-
def set_token_expiry_date(self, value: Union[str, int]) -> None:
243-
self._token_expiry_date = self._parse_token_expiration_date(value)
242+
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
243+
self._token_expiry_date = value
244244

245245
def get_assertion_name(self) -> str:
246246
return self.assertion_name

airbyte_cdk/sources/declarative/concurrent_declarative_source.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
from airbyte_cdk.sources.declarative.extractors.record_filter import (
2020
ClientSideIncrementalRecordFilterDecorator,
2121
)
22-
from airbyte_cdk.sources.declarative.incremental import ConcurrentPerPartitionCursor
22+
from airbyte_cdk.sources.declarative.incremental import (
23+
ConcurrentPerPartitionCursor,
24+
GlobalSubstreamCursor,
25+
)
2326
from airbyte_cdk.sources.declarative.incremental.datetime_based_cursor import DatetimeBasedCursor
2427
from airbyte_cdk.sources.declarative.incremental.per_partition_with_global import (
2528
PerPartitionWithGlobalCursor,
@@ -361,7 +364,8 @@ def _group_streams(
361364
== DatetimeBasedCursorModel.__name__
362365
and hasattr(declarative_stream.retriever, "stream_slicer")
363366
and isinstance(
364-
declarative_stream.retriever.stream_slicer, PerPartitionWithGlobalCursor
367+
declarative_stream.retriever.stream_slicer,
368+
(GlobalSubstreamCursor, PerPartitionWithGlobalCursor),
365369
)
366370
):
367371
stream_state = self._connector_state_manager.get_stream_state(

airbyte_cdk/sources/declarative/parsers/model_to_component_factory.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1439,7 +1439,9 @@ def create_concurrent_cursor_from_perpartition_cursor(
14391439
stream_state = self.apply_stream_state_migrations(stream_state_migrations, stream_state)
14401440

14411441
# Per-partition state doesn't make sense for GroupingPartitionRouter, so force the global state
1442-
use_global_cursor = isinstance(partition_router, GroupingPartitionRouter)
1442+
use_global_cursor = isinstance(
1443+
partition_router, GroupingPartitionRouter
1444+
) or component_definition.get("global_substream_cursor", False)
14431445

14441446
# Return the concurrent cursor and state converter
14451447
return ConcurrentPerPartitionCursor(

airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
130130
headers = self.get_refresh_request_headers()
131131
return headers if headers else None
132132

133-
def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
133+
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
134134
"""
135135
Returns the refresh token and its expiration datetime
136136
@@ -148,6 +148,14 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
148148
# PRIVATE METHODS
149149
# ----------------
150150

151+
def _default_token_expiry_date(self) -> AirbyteDateTime:
152+
"""
153+
Returns the default token expiry date
154+
"""
155+
# 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
156+
default_token_expiry_duration_hours = 1 # 1 hour
157+
return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
158+
151159
def _wrap_refresh_token_exception(
152160
self, exception: requests.exceptions.RequestException
153161
) -> bool:
@@ -257,14 +265,10 @@ def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) ->
257265

258266
def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
259267
"""
260-
Return the expiration datetime of the refresh token
268+
Parse a string or integer token expiration date into a datetime object
261269
262270
:return: expiration datetime
263271
"""
264-
if not value and not self.token_has_expired():
265-
# No expiry token was provided but the previous one is not expired so it's fine
266-
return self.get_token_expiry_date()
267-
268272
if self.token_expiry_is_time_of_expiration:
269273
if not self.token_expiry_date_format:
270274
raise ValueError(
@@ -308,17 +312,30 @@ def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
308312
"""
309313
return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
310314

311-
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any:
315+
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
312316
"""
313317
Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
314318
319+
If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
320+
315321
Args:
316322
response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
317323
318324
Returns:
319-
str: The extracted token_expiry_date.
325+
The extracted token_expiry_date or None if not found.
320326
"""
321-
return self._find_and_get_value_from_response(response_data, self.get_expires_in_name())
327+
expires_in = self._find_and_get_value_from_response(
328+
response_data, self.get_expires_in_name()
329+
)
330+
if expires_in is not None:
331+
return self._parse_token_expiration_date(expires_in)
332+
333+
# expires_in is None
334+
existing_expiry_date = self.get_token_expiry_date()
335+
if existing_expiry_date and not self.token_has_expired():
336+
return existing_expiry_date
337+
338+
return self._default_token_expiry_date()
322339

323340
def _find_and_get_value_from_response(
324341
self,
@@ -344,7 +361,7 @@ def _find_and_get_value_from_response(
344361
"""
345362
if current_depth > max_depth:
346363
# this is needed to avoid an inf loop, possible with a very deep nesting observed.
347-
message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response."
364+
message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
348365
raise ResponseKeysMaxRecurtionReached(
349366
internal_message=message, message=message, failure_type=FailureType.config_error
350367
)
@@ -441,7 +458,7 @@ def get_token_expiry_date(self) -> AirbyteDateTime:
441458
"""Expiration date of the access token"""
442459

443460
@abstractmethod
444-
def set_token_expiry_date(self, value: Union[str, int]) -> None:
461+
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
445462
"""Setter for access token expiration date"""
446463

447464
@abstractmethod

airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ def get_grant_type(self) -> str:
120120
def get_token_expiry_date(self) -> AirbyteDateTime:
121121
return self._token_expiry_date
122122

123-
def set_token_expiry_date(self, value: Union[str, int]) -> None:
124-
self._token_expiry_date = self._parse_token_expiration_date(value)
123+
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
124+
self._token_expiry_date = value
125125

126126
@property
127127
def token_expiry_is_time_of_expiration(self) -> bool:
@@ -316,26 +316,6 @@ def token_has_expired(self) -> bool:
316316
"""Returns True if the token is expired"""
317317
return ab_datetime_now() > self.get_token_expiry_date()
318318

319-
@staticmethod
320-
def get_new_token_expiry_date(
321-
access_token_expires_in: str,
322-
token_expiry_date_format: str | None = None,
323-
) -> AirbyteDateTime:
324-
"""
325-
Calculate the new token expiry date based on the provided expiration duration or format.
326-
327-
Args:
328-
access_token_expires_in (str): The duration (in seconds) until the access token expires, or the expiry date in a specific format.
329-
token_expiry_date_format (str | None, optional): The format of the expiry date if provided. Defaults to None.
330-
331-
Returns:
332-
AirbyteDateTime: The calculated expiry date of the access token.
333-
"""
334-
if token_expiry_date_format:
335-
return ab_datetime_parse(access_token_expires_in)
336-
else:
337-
return ab_datetime_now() + timedelta(seconds=int(access_token_expires_in))
338-
339319
def get_access_token(self) -> str:
340320
"""Retrieve new access and refresh token if the access token has expired.
341321
The new refresh token is persisted with the set_refresh_token function
@@ -346,16 +326,13 @@ def get_access_token(self) -> str:
346326
new_access_token, access_token_expires_in, new_refresh_token = (
347327
self.refresh_access_token()
348328
)
349-
new_token_expiry_date: AirbyteDateTime = self.get_new_token_expiry_date(
350-
access_token_expires_in, self._token_expiry_date_format
351-
)
352329
self.access_token = new_access_token
353330
self.set_refresh_token(new_refresh_token)
354-
self.set_token_expiry_date(new_token_expiry_date)
331+
self.set_token_expiry_date(access_token_expires_in)
355332
self._emit_control_message()
356333
return self.access_token
357334

358-
def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override]
335+
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime, str]: # type: ignore[override]
359336
"""
360337
Refreshes the access token by making a handled request and extracting the necessary token information.
361338

unit_tests/sources/declarative/auth/test_oauth.py

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def test_error_on_refresh_token_grant_without_refresh_token(self):
203203
grant_type="refresh_token",
204204
)
205205

206+
@freezegun.freeze_time("2022-01-01")
206207
def test_refresh_access_token(self, mocker):
207208
oauth = DeclarativeOauth2Authenticator(
208209
token_refresh_endpoint="{{ config['refresh_endpoint'] }}",
@@ -225,13 +226,15 @@ def test_refresh_access_token(self, mocker):
225226
resp, "json", return_value={"access_token": "access_token", "expires_in": 1000}
226227
)
227228
mocker.patch.object(requests, "request", side_effect=mock_request, autospec=True)
228-
token = oauth.refresh_access_token()
229+
access_token, token_expiry_date = oauth.refresh_access_token()
229230

230-
assert ("access_token", 1000) == token
231+
assert access_token == "access_token"
232+
assert token_expiry_date == ab_datetime_now() + timedelta(seconds=1000)
231233

232234
filtered = filter_secrets("access_token")
233235
assert filtered == "****"
234236

237+
@freezegun.freeze_time("2022-01-01")
235238
def test_refresh_access_token_when_headers_provided(self, mocker):
236239
expected_headers = {
237240
"Authorization": "Bearer some_access_token",
@@ -256,9 +259,10 @@ def test_refresh_access_token_when_headers_provided(self, mocker):
256259
mocked_request = mocker.patch.object(
257260
requests, "request", side_effect=mock_request, autospec=True
258261
)
259-
token = oauth.refresh_access_token()
262+
access_token, token_expiry_date = oauth.refresh_access_token()
260263

261-
assert ("access_token", 1000) == token
264+
assert access_token == "access_token"
265+
assert token_expiry_date == ab_datetime_now() + timedelta(seconds=1000)
262266

263267
assert mocked_request.call_args.kwargs["headers"] == expected_headers
264268

@@ -314,6 +318,7 @@ def test_initialize_declarative_oauth_with_token_expiry_date_as_timestamp(
314318
assert isinstance(oauth._token_expiry_date, AirbyteDateTime)
315319
assert oauth.get_token_expiry_date() == ab_datetime_parse(expected_date)
316320

321+
@freezegun.freeze_time("2022-01-01")
317322
def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_fetch_access_token(
318323
self,
319324
) -> None:
@@ -335,12 +340,65 @@ def test_given_no_access_token_but_expiry_in_the_future_when_refresh_token_then_
335340
url="https://refresh_endpoint.com/",
336341
body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token",
337342
),
338-
HttpResponse(body=json.dumps({"access_token": "new_access_token"})),
343+
HttpResponse(
344+
body=json.dumps({"access_token": "new_access_token", "expires_in": 1000})
345+
),
339346
)
340347
oauth.get_access_token()
341348

342349
assert oauth.access_token == "new_access_token"
343-
assert oauth._token_expiry_date == expiry_date
350+
assert oauth._token_expiry_date == ab_datetime_now() + timedelta(seconds=1000)
351+
352+
@freezegun.freeze_time("2022-01-01")
353+
@pytest.mark.parametrize(
354+
"initial_expiry_date_delta, expected_new_expiry_date_delta, expected_access_token",
355+
[
356+
(timedelta(days=1), timedelta(days=1), "some_access_token"),
357+
(timedelta(days=-1), timedelta(hours=1), "new_access_token"),
358+
(None, timedelta(hours=1), "new_access_token"),
359+
],
360+
ids=[
361+
"initial_expiry_date_in_future",
362+
"initial_expiry_date_in_past",
363+
"no_initial_expiry_date",
364+
],
365+
)
366+
def test_no_expiry_date_provided_by_auth_server(
367+
self,
368+
initial_expiry_date_delta,
369+
expected_new_expiry_date_delta,
370+
expected_access_token,
371+
) -> None:
372+
initial_expiry_date = (
373+
ab_datetime_now().add(initial_expiry_date_delta).isoformat()
374+
if initial_expiry_date_delta
375+
else None
376+
)
377+
expected_new_expiry_date = ab_datetime_now().add(expected_new_expiry_date_delta)
378+
oauth = DeclarativeOauth2Authenticator(
379+
token_refresh_endpoint="https://refresh_endpoint.com/",
380+
client_id="some_client_id",
381+
client_secret="some_client_secret",
382+
token_expiry_date=initial_expiry_date,
383+
access_token_value="some_access_token",
384+
refresh_token="some_refresh_token",
385+
config={},
386+
parameters={},
387+
grant_type="client",
388+
)
389+
390+
with HttpMocker() as http_mocker:
391+
http_mocker.post(
392+
HttpRequest(
393+
url="https://refresh_endpoint.com/",
394+
body="grant_type=client&client_id=some_client_id&client_secret=some_client_secret&refresh_token=some_refresh_token",
395+
),
396+
HttpResponse(body=json.dumps({"access_token": "new_access_token"})),
397+
)
398+
oauth.get_access_token()
399+
400+
assert oauth.access_token == expected_access_token
401+
assert oauth._token_expiry_date == expected_new_expiry_date
344402

345403
@pytest.mark.parametrize(
346404
"expires_in_response, token_expiry_date_format",
@@ -443,6 +501,7 @@ def test_set_token_expiry_date_no_format(self, mocker, expires_in_response, next
443501
assert "access_token" == token
444502
assert oauth.get_token_expiry_date() == ab_datetime_parse(next_day)
445503

504+
@freezegun.freeze_time("2022-01-01")
446505
def test_profile_assertion(self, mocker):
447506
with HttpMocker() as http_mocker:
448507
jwt = JwtAuthenticator(
@@ -477,7 +536,7 @@ def test_profile_assertion(self, mocker):
477536

478537
token = oauth.refresh_access_token()
479538

480-
assert ("access_token", 1000) == token
539+
assert ("access_token", ab_datetime_now().add(timedelta(seconds=1000))) == token
481540

482541
filtered = filter_secrets("access_token")
483542
assert filtered == "****"

unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3449,3 +3449,48 @@ def test_semaphore_cleanup():
34493449
assert '{"id":"2"}' not in cursor._semaphore_per_partition
34503450
assert len(cursor._partition_parent_state_map) == 0 # All parent states should be popped
34513451
assert cursor._parent_state == {"parent": {"state": "state2"}} # Last parent state
3452+
3453+
3454+
def test_given_global_state_when_read_then_state_is_not_per_partition() -> None:
3455+
manifest = deepcopy(SUBSTREAM_MANIFEST)
3456+
manifest["definitions"]["post_comments_stream"]["incremental_sync"][
3457+
"global_substream_cursor"
3458+
] = True
3459+
manifest["streams"].remove({"$ref": "#/definitions/post_comment_votes_stream"})
3460+
record = {
3461+
"id": 9,
3462+
"post_id": 1,
3463+
"updated_at": COMMENT_10_UPDATED_AT,
3464+
}
3465+
mock_requests = [
3466+
(
3467+
f"https://api.example.com/community/posts?per_page=100&start_time={START_DATE}",
3468+
{
3469+
"posts": [
3470+
{"id": 1, "updated_at": POST_1_UPDATED_AT},
3471+
],
3472+
},
3473+
),
3474+
# Fetch the first page of comments for post 1
3475+
(
3476+
"https://api.example.com/community/posts/1/comments?per_page=100",
3477+
{
3478+
"comments": [record],
3479+
},
3480+
),
3481+
]
3482+
3483+
run_mocked_test(
3484+
mock_requests,
3485+
manifest,
3486+
CONFIG,
3487+
"post_comments",
3488+
{},
3489+
[record],
3490+
{
3491+
"lookback_window": 1,
3492+
"parent_state": {"posts": {"updated_at": "2024-01-30T00:00:00Z"}},
3493+
"state": {"updated_at": "2024-01-25T00:00:00Z"},
3494+
"use_global_cursor": True, # ensures that it is running the Concurrent CDK version as this is not populated in the declarative implementation
3495+
}, # this state does have per partition which would be under `states`
3496+
)

0 commit comments

Comments
 (0)