Skip to content

Commit 3859c5b

Browse files
author
Oleksandr Bazarnov
committed
Merge remote-tracking branch 'origin/main' into baz/cdk/extract-common-manifest-parts
2 parents d929167 + 24cbc51 commit 3859c5b

37 files changed

+1097
-115
lines changed

airbyte_cdk/sources/declarative/auth/oauth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,8 @@ def get_token_expiry_date(self) -> AirbyteDateTime:
239239
def _has_access_token_been_initialized(self) -> bool:
240240
return self._access_token is not None
241241

242-
def set_token_expiry_date(self, value: Union[str, int]) -> None:
243-
self._token_expiry_date = self._parse_token_expiration_date(value)
242+
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
243+
self._token_expiry_date = value
244244

245245
def get_assertion_name(self) -> str:
246246
return self.assertion_name

airbyte_cdk/sources/streams/http/requests_native_auth/abstract_oauth.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def build_refresh_request_headers(self) -> Mapping[str, Any] | None:
130130
headers = self.get_refresh_request_headers()
131131
return headers if headers else None
132132

133-
def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
133+
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime]:
134134
"""
135135
Returns the refresh token and its expiration datetime
136136
@@ -148,6 +148,14 @@ def refresh_access_token(self) -> Tuple[str, Union[str, int]]:
148148
# PRIVATE METHODS
149149
# ----------------
150150

151+
def _default_token_expiry_date(self) -> AirbyteDateTime:
152+
"""
153+
Returns the default token expiry date
154+
"""
155+
# 1 hour was chosen as a middle ground to avoid unnecessary frequent refreshes and token expiration
156+
default_token_expiry_duration_hours = 1 # 1 hour
157+
return ab_datetime_now() + timedelta(hours=default_token_expiry_duration_hours)
158+
151159
def _wrap_refresh_token_exception(
152160
self, exception: requests.exceptions.RequestException
153161
) -> bool:
@@ -257,14 +265,10 @@ def _ensure_access_token_in_response(self, response_data: Mapping[str, Any]) ->
257265

258266
def _parse_token_expiration_date(self, value: Union[str, int]) -> AirbyteDateTime:
259267
"""
260-
Return the expiration datetime of the refresh token
268+
Parse a string or integer token expiration date into a datetime object
261269
262270
:return: expiration datetime
263271
"""
264-
if not value and not self.token_has_expired():
265-
# No expiry token was provided but the previous one is not expired so it's fine
266-
return self.get_token_expiry_date()
267-
268272
if self.token_expiry_is_time_of_expiration:
269273
if not self.token_expiry_date_format:
270274
raise ValueError(
@@ -308,17 +312,30 @@ def _extract_refresh_token(self, response_data: Mapping[str, Any]) -> Any:
308312
"""
309313
return self._find_and_get_value_from_response(response_data, self.get_refresh_token_name())
310314

311-
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> Any:
315+
def _extract_token_expiry_date(self, response_data: Mapping[str, Any]) -> AirbyteDateTime:
312316
"""
313317
Extracts the token_expiry_date, like `expires_in` or `expires_at`, etc from the given response data.
314318
319+
If the token_expiry_date is not found, it will return an existing token expiry date if set, or a default token expiry date.
320+
315321
Args:
316322
response_data (Mapping[str, Any]): The response data from which to extract the token_expiry_date.
317323
318324
Returns:
319-
str: The extracted token_expiry_date.
325+
The extracted token_expiry_date or None if not found.
320326
"""
321-
return self._find_and_get_value_from_response(response_data, self.get_expires_in_name())
327+
expires_in = self._find_and_get_value_from_response(
328+
response_data, self.get_expires_in_name()
329+
)
330+
if expires_in is not None:
331+
return self._parse_token_expiration_date(expires_in)
332+
333+
# expires_in is None
334+
existing_expiry_date = self.get_token_expiry_date()
335+
if existing_expiry_date and not self.token_has_expired():
336+
return existing_expiry_date
337+
338+
return self._default_token_expiry_date()
322339

323340
def _find_and_get_value_from_response(
324341
self,
@@ -344,7 +361,7 @@ def _find_and_get_value_from_response(
344361
"""
345362
if current_depth > max_depth:
346363
# this is needed to avoid an inf loop, possible with a very deep nesting observed.
347-
message = f"The maximum level of recursion is reached. Couldn't find the speficied `{key_name}` in the response."
364+
message = f"The maximum level of recursion is reached. Couldn't find the specified `{key_name}` in the response."
348365
raise ResponseKeysMaxRecurtionReached(
349366
internal_message=message, message=message, failure_type=FailureType.config_error
350367
)
@@ -441,7 +458,7 @@ def get_token_expiry_date(self) -> AirbyteDateTime:
441458
"""Expiration date of the access token"""
442459

443460
@abstractmethod
444-
def set_token_expiry_date(self, value: Union[str, int]) -> None:
461+
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
445462
"""Setter for access token expiration date"""
446463

447464
@abstractmethod

airbyte_cdk/sources/streams/http/requests_native_auth/oauth.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ def get_grant_type(self) -> str:
120120
def get_token_expiry_date(self) -> AirbyteDateTime:
121121
return self._token_expiry_date
122122

123-
def set_token_expiry_date(self, value: Union[str, int]) -> None:
124-
self._token_expiry_date = self._parse_token_expiration_date(value)
123+
def set_token_expiry_date(self, value: AirbyteDateTime) -> None:
124+
self._token_expiry_date = value
125125

126126
@property
127127
def token_expiry_is_time_of_expiration(self) -> bool:
@@ -316,26 +316,6 @@ def token_has_expired(self) -> bool:
316316
"""Returns True if the token is expired"""
317317
return ab_datetime_now() > self.get_token_expiry_date()
318318

319-
@staticmethod
320-
def get_new_token_expiry_date(
321-
access_token_expires_in: str,
322-
token_expiry_date_format: str | None = None,
323-
) -> AirbyteDateTime:
324-
"""
325-
Calculate the new token expiry date based on the provided expiration duration or format.
326-
327-
Args:
328-
access_token_expires_in (str): The duration (in seconds) until the access token expires, or the expiry date in a specific format.
329-
token_expiry_date_format (str | None, optional): The format of the expiry date if provided. Defaults to None.
330-
331-
Returns:
332-
AirbyteDateTime: The calculated expiry date of the access token.
333-
"""
334-
if token_expiry_date_format:
335-
return ab_datetime_parse(access_token_expires_in)
336-
else:
337-
return ab_datetime_now() + timedelta(seconds=int(access_token_expires_in))
338-
339319
def get_access_token(self) -> str:
340320
"""Retrieve new access and refresh token if the access token has expired.
341321
The new refresh token is persisted with the set_refresh_token function
@@ -346,16 +326,13 @@ def get_access_token(self) -> str:
346326
new_access_token, access_token_expires_in, new_refresh_token = (
347327
self.refresh_access_token()
348328
)
349-
new_token_expiry_date: AirbyteDateTime = self.get_new_token_expiry_date(
350-
access_token_expires_in, self._token_expiry_date_format
351-
)
352329
self.access_token = new_access_token
353330
self.set_refresh_token(new_refresh_token)
354-
self.set_token_expiry_date(new_token_expiry_date)
331+
self.set_token_expiry_date(access_token_expires_in)
355332
self._emit_control_message()
356333
return self.access_token
357334

358-
def refresh_access_token(self) -> Tuple[str, str, str]: # type: ignore[override]
335+
def refresh_access_token(self) -> Tuple[str, AirbyteDateTime, str]: # type: ignore[override]
359336
"""
360337
Refreshes the access token by making a handled request and extracting the necessary token information.
361338

airbyte_cdk/test/entrypoint_wrapper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ def records(self) -> List[AirbyteMessage]:
8282
def state_messages(self) -> List[AirbyteMessage]:
8383
return self._get_message_by_types([Type.STATE])
8484

85+
@property
86+
def connection_status_messages(self) -> List[AirbyteMessage]:
87+
return self._get_message_by_types([Type.CONNECTION_STATUS])
88+
8589
@property
8690
def most_recent_state(self) -> Any:
8791
state_messages = self._get_message_by_types([Type.STATE])
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright (c) 2024 Airbyte, Inc., all rights reserved.
2+
'''FAST Airbyte Standard Tests
3+
4+
This module provides a set of base classes for declarative connector test suites.
5+
The goal of this module is to provide a robust and extensible framework for testing Airbyte
6+
connectors.
7+
8+
Example usage:
9+
10+
```python
11+
# `test_airbyte_standards.py`
12+
from airbyte_cdk.test import standard_tests
13+
14+
pytest_plugins = [
15+
"airbyte_cdk.test.standard_tests.pytest_hooks",
16+
]
17+
18+
19+
class TestSuiteSourcePokeAPI(standard_tests.DeclarativeSourceTestSuite):
20+
"""Test suite for the source."""
21+
```
22+
23+
Available test suites base classes:
24+
- `DeclarativeSourceTestSuite`: A test suite for declarative sources.
25+
- `SourceTestSuiteBase`: A test suite for sources.
26+
- `DestinationTestSuiteBase`: A test suite for destinations.
27+
28+
'''
29+
30+
from airbyte_cdk.test.standard_tests.connector_base import (
31+
ConnectorTestScenario,
32+
ConnectorTestSuiteBase,
33+
)
34+
from airbyte_cdk.test.standard_tests.declarative_sources import (
35+
DeclarativeSourceTestSuite,
36+
)
37+
from airbyte_cdk.test.standard_tests.destination_base import DestinationTestSuiteBase
38+
from airbyte_cdk.test.standard_tests.source_base import SourceTestSuiteBase
39+
40+
__all__ = [
41+
"ConnectorTestScenario",
42+
"ConnectorTestSuiteBase",
43+
"DeclarativeSourceTestSuite",
44+
"DestinationTestSuiteBase",
45+
"SourceTestSuiteBase",
46+
]
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
# Copyright (c) 2025 Airbyte, Inc., all rights reserved.
2+
"""Job runner for Airbyte Standard Tests."""
3+
4+
import logging
5+
import tempfile
6+
import uuid
7+
from dataclasses import asdict
8+
from pathlib import Path
9+
from typing import Any, Callable, Literal
10+
11+
import orjson
12+
from typing_extensions import Protocol, runtime_checkable
13+
14+
from airbyte_cdk.models import (
15+
ConfiguredAirbyteCatalog,
16+
Status,
17+
)
18+
from airbyte_cdk.test import entrypoint_wrapper
19+
from airbyte_cdk.test.standard_tests.models import (
20+
ConnectorTestScenario,
21+
)
22+
23+
24+
def _errors_to_str(
25+
entrypoint_output: entrypoint_wrapper.EntrypointOutput,
26+
) -> str:
27+
"""Convert errors from entrypoint output to a string."""
28+
if not entrypoint_output.errors:
29+
# If there are no errors, return an empty string.
30+
return ""
31+
32+
return "\n" + "\n".join(
33+
[
34+
str(error.trace.error).replace(
35+
"\\n",
36+
"\n",
37+
)
38+
for error in entrypoint_output.errors
39+
if error.trace
40+
],
41+
)
42+
43+
44+
@runtime_checkable
45+
class IConnector(Protocol):
46+
"""A connector that can be run in a test scenario.
47+
48+
Note: We currently use 'spec' to determine if we have a connector object.
49+
In the future, it would be preferred to leverage a 'launch' method instead,
50+
directly on the connector (which doesn't yet exist).
51+
"""
52+
53+
def spec(self, logger: logging.Logger) -> Any:
54+
"""Connectors should have a `spec` method."""
55+
56+
57+
def run_test_job(
58+
connector: IConnector | type[IConnector] | Callable[[], IConnector],
59+
verb: Literal["read", "check", "discover"],
60+
test_scenario: ConnectorTestScenario,
61+
*,
62+
catalog: ConfiguredAirbyteCatalog | dict[str, Any] | None = None,
63+
) -> entrypoint_wrapper.EntrypointOutput:
64+
"""Run a test scenario from provided CLI args and return the result."""
65+
if not connector:
66+
raise ValueError("Connector is required")
67+
68+
if catalog and isinstance(catalog, ConfiguredAirbyteCatalog):
69+
# Convert the catalog to a dict if it's already a ConfiguredAirbyteCatalog.
70+
catalog = asdict(catalog)
71+
72+
connector_obj: IConnector
73+
if isinstance(connector, type) or callable(connector):
74+
# If the connector is a class or a factory lambda, instantiate it.
75+
connector_obj = connector()
76+
elif isinstance(connector, IConnector):
77+
connector_obj = connector
78+
else:
79+
raise ValueError(
80+
f"Invalid connector input: {type(connector)}",
81+
)
82+
83+
args: list[str] = [verb]
84+
if test_scenario.config_path:
85+
args += ["--config", str(test_scenario.config_path)]
86+
elif test_scenario.config_dict:
87+
config_path = (
88+
Path(tempfile.gettempdir()) / "airbyte-test" / f"temp_config_{uuid.uuid4().hex}.json"
89+
)
90+
config_path.parent.mkdir(parents=True, exist_ok=True)
91+
config_path.write_text(orjson.dumps(test_scenario.config_dict).decode())
92+
args += ["--config", str(config_path)]
93+
94+
catalog_path: Path | None = None
95+
if verb not in ["discover", "check"]:
96+
# We need a catalog for read.
97+
if catalog:
98+
# Write the catalog to a temp json file and pass the path to the file as an argument.
99+
catalog_path = (
100+
Path(tempfile.gettempdir())
101+
/ "airbyte-test"
102+
/ f"temp_catalog_{uuid.uuid4().hex}.json"
103+
)
104+
catalog_path.parent.mkdir(parents=True, exist_ok=True)
105+
catalog_path.write_text(orjson.dumps(catalog).decode())
106+
elif test_scenario.configured_catalog_path:
107+
catalog_path = Path(test_scenario.configured_catalog_path)
108+
109+
if catalog_path:
110+
args += ["--catalog", str(catalog_path)]
111+
112+
# This is a bit of a hack because the source needs the catalog early.
113+
# Because it *also* can fail, we have to redundantly wrap it in a try/except block.
114+
115+
result: entrypoint_wrapper.EntrypointOutput = entrypoint_wrapper._run_command( # noqa: SLF001 # Non-public API
116+
source=connector_obj, # type: ignore [arg-type]
117+
args=args,
118+
expecting_exception=test_scenario.expect_exception,
119+
)
120+
if result.errors and not test_scenario.expect_exception:
121+
raise AssertionError(
122+
f"Expected no errors but got {len(result.errors)}: \n" + _errors_to_str(result)
123+
)
124+
125+
if verb == "check":
126+
# Check is expected to fail gracefully without an exception.
127+
# Instead, we assert that we have a CONNECTION_STATUS message with
128+
# a failure status.
129+
assert len(result.connection_status_messages) == 1, (
130+
"Expected exactly one CONNECTION_STATUS message. Got "
131+
f"{len(result.connection_status_messages)}:\n"
132+
+ "\n".join([str(msg) for msg in result.connection_status_messages])
133+
+ _errors_to_str(result)
134+
)
135+
if test_scenario.expect_exception:
136+
conn_status = result.connection_status_messages[0].connectionStatus
137+
assert conn_status, (
138+
"Expected CONNECTION_STATUS message to be present. Got: \n"
139+
+ "\n".join([str(msg) for msg in result.connection_status_messages])
140+
)
141+
assert conn_status.status == Status.FAILED, (
142+
"Expected CONNECTION_STATUS message to be FAILED. Got: \n"
143+
+ "\n".join([str(msg) for msg in result.connection_status_messages])
144+
)
145+
146+
return result
147+
148+
# For all other verbs, we assert check that an exception is raised (or not).
149+
if test_scenario.expect_exception:
150+
if not result.errors:
151+
raise AssertionError("Expected exception but got none.")
152+
153+
return result
154+
155+
assert not result.errors, (
156+
f"Expected no errors but got {len(result.errors)}: \n" + _errors_to_str(result)
157+
)
158+
159+
return result

0 commit comments

Comments
 (0)