Skip to content

Commit 344a495

Browse files
committed
tri-state ExpectedOutcome
1 parent 76aa0e7 commit 344a495

File tree

8 files changed

+113
-39
lines changed

8 files changed

+113
-39
lines changed

airbyte_cdk/test/entrypoint_wrapper.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from typing import Any, List, Mapping, Optional, Union
2525

2626
import orjson
27+
from langsmith import expect
2728
from pydantic import ValidationError as V2ValidationError
2829
from serpyco_rs import SchemaValidationError
2930

@@ -44,6 +45,7 @@
4445
Type,
4546
)
4647
from airbyte_cdk.sources import Source
48+
from airbyte_cdk.test.standard_tests.models.scenario import ExpectedOutcome
4749

4850

4951
class EntrypointOutput:
@@ -157,7 +159,9 @@ def is_not_in_logs(self, pattern: str) -> bool:
157159

158160

159161
def _run_command(
160-
source: Source, args: List[str], expecting_exception: bool = False
162+
source: Source,
163+
args: list[str],
164+
expected_outcome: ExpectedOutcome,
161165
) -> EntrypointOutput:
162166
log_capture_buffer = StringIO()
163167
stream_handler = logging.StreamHandler(log_capture_buffer)
@@ -175,35 +179,39 @@ def _run_command(
175179
for message in source_entrypoint.run(parsed_args):
176180
messages.append(message)
177181
except Exception as exception:
178-
if not expecting_exception:
182+
if expected_outcome.expect_success():
179183
print("Printing unexpected error from entrypoint_wrapper")
180184
print("".join(traceback.format_exception(None, exception, exception.__traceback__)))
185+
181186
uncaught_exception = exception
182187

183188
captured_logs = log_capture_buffer.getvalue().split("\n")[:-1]
184189

185190
parent_logger.removeHandler(stream_handler)
186191

187-
return EntrypointOutput(messages + captured_logs, uncaught_exception)
192+
return EntrypointOutput(messages + captured_logs, uncaught_exception=uncaught_exception)
188193

189194

190195
def discover(
191196
source: Source,
192197
config: Mapping[str, Any],
193-
expecting_exception: bool = False,
198+
*,
199+
expected_outcome: ExpectedOutcome = ExpectedOutcome.EXPECT_SUCCESS,
194200
) -> EntrypointOutput:
195201
"""
196202
config must be json serializable
197-
:param expecting_exception: By default if there is an uncaught exception, the exception will be printed out. If this is expected, please
198-
provide expecting_exception=True so that the test output logs are cleaner
203+
:param expected_outcome: By default if there is an uncaught exception, the exception will be printed out. If this is expected, please
204+
provide `expected_outcome=ExpectedOutcome.EXPECT_FAILURE` so that the test output logs are cleaner
199205
"""
200206

201207
with tempfile.TemporaryDirectory() as tmp_directory:
202208
tmp_directory_path = Path(tmp_directory)
203209
config_file = make_file(tmp_directory_path / "config.json", config)
204210

205211
return _run_command(
206-
source, ["discover", "--config", config_file, "--debug"], expecting_exception
212+
source,
213+
["discover", "--config", config_file, "--debug"],
214+
expected_outcome=expected_outcome,
207215
)
208216

209217

@@ -212,13 +220,14 @@ def read(
212220
config: Mapping[str, Any],
213221
catalog: ConfiguredAirbyteCatalog,
214222
state: Optional[List[AirbyteStateMessage]] = None,
215-
expecting_exception: bool = False,
223+
*,
224+
expected_outcome: ExpectedOutcome = ExpectedOutcome.EXPECT_SUCCESS,
216225
) -> EntrypointOutput:
217226
"""
218227
config and state must be json serializable
219228
220-
:param expecting_exception: By default if there is an uncaught exception, the exception will be printed out. If this is expected, please
221-
provide expecting_exception=True so that the test output logs are cleaner
229+
:param expected_outcome: By default if there is an uncaught exception, the exception will be printed out. If this is expected, please
230+
provide `expected_outcome=ExpectedOutcome.EXPECT_FAILURE` so that the test output logs are cleaner.
222231
"""
223232
with tempfile.TemporaryDirectory() as tmp_directory:
224233
tmp_directory_path = Path(tmp_directory)
@@ -245,7 +254,7 @@ def read(
245254
]
246255
)
247256

248-
return _run_command(source, args, expecting_exception)
257+
return _run_command(source, args, expected_outcome=expected_outcome)
249258

250259

251260
def make_file(

airbyte_cdk/test/standard_tests/_job_runner.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ def run_test_job(
122122
result: entrypoint_wrapper.EntrypointOutput = entrypoint_wrapper._run_command( # noqa: SLF001 # Non-public API
123123
source=connector_obj, # type: ignore [arg-type]
124124
args=args,
125-
expecting_exception=test_scenario.expect_exception,
125+
expected_outcome=test_scenario.expected_outcome,
126126
)
127-
if result.errors and not test_scenario.expect_exception:
127+
if result.errors and test_scenario.expected_outcome.expect_success():
128128
raise AssertionError(
129129
f"Expected no errors but got {len(result.errors)}: \n" + _errors_to_str(result)
130130
)
@@ -139,7 +139,7 @@ def run_test_job(
139139
+ "\n".join([str(msg) for msg in result.connection_status_messages])
140140
+ _errors_to_str(result)
141141
)
142-
if test_scenario.expect_exception:
142+
if test_scenario.expected_outcome.expect_exception():
143143
conn_status = result.connection_status_messages[0].connectionStatus
144144
assert conn_status, (
145145
"Expected CONNECTION_STATUS message to be present. Got: \n"
@@ -153,14 +153,15 @@ def run_test_job(
153153
return result
154154

155155
# For all other verbs, we assert check that an exception is raised (or not).
156-
if test_scenario.expect_exception:
156+
if test_scenario.expected_outcome.expect_exception():
157157
if not result.errors:
158158
raise AssertionError("Expected exception but got none.")
159159

160160
return result
161161

162-
assert not result.errors, (
163-
f"Expected no errors but got {len(result.errors)}: \n" + _errors_to_str(result)
164-
)
162+
if test_scenario.expected_outcome.expect_success():
163+
assert not result.errors, (
164+
f"Expected no errors but got {len(result.errors)}: \n" + _errors_to_str(result)
165+
)
165166

166167
return result

airbyte_cdk/test/standard_tests/models/scenario.py

Lines changed: 58 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,47 @@
99

1010
from __future__ import annotations
1111

12-
from pathlib import Path
12+
from enum import Enum, auto
13+
from pathlib import Path # noqa: TC003 # Pydantic needs this (don't move to 'if typing' block)
1314
from typing import Any, Literal, cast
1415

1516
import yaml
1617
from pydantic import BaseModel, ConfigDict
1718

1819

20+
class ExpectedOutcome(Enum):
21+
"""Enum to represent the expected outcome of a test scenario.
22+
23+
Class supports comparisons to a boolean or None.
24+
"""
25+
26+
EXPECT_EXCEPTION = auto()
27+
EXPECT_SUCCESS = auto()
28+
ALLOW_ANY = auto()
29+
30+
@classmethod
31+
def from_status_str(cls, status: str | None) -> ExpectedOutcome:
32+
"""Convert a status string to an ExpectedOutcome."""
33+
if status is None:
34+
return ExpectedOutcome.ALLOW_ANY
35+
36+
try:
37+
return {
38+
"succeed": ExpectedOutcome.EXPECT_SUCCESS,
39+
"failed": ExpectedOutcome.EXPECT_EXCEPTION,
40+
}[status]
41+
except KeyError as ex:
42+
raise ValueError(f"Invalid status '{status}'. Expected 'succeed' or 'failed'.") from ex
43+
44+
def expect_exception(self) -> bool:
45+
"""Return whether the expectation is that an exception should be raised."""
46+
return self == ExpectedOutcome.EXPECT_EXCEPTION
47+
48+
def expect_success(self) -> bool:
49+
"""Return whether the expectation is that the test should succeed without exceptions."""
50+
return self == ExpectedOutcome.EXPECT_SUCCESS
51+
52+
1953
class ConnectorTestScenario(BaseModel):
2054
"""Acceptance test scenario, as a Pydantic model.
2155
@@ -82,8 +116,13 @@ def get_config_dict(
82116
raise ValueError("No config dictionary or path provided.")
83117

84118
@property
85-
def expect_exception(self) -> bool:
86-
return self.status and self.status == "failed" or False
119+
def expected_outcome(self) -> ExpectedOutcome:
120+
"""Whether the test scenario expects an exception to be raised.
121+
122+
Returns True if the scenario expects an exception, False if it does not,
123+
and None if there is no set expectation.
124+
"""
125+
return ExpectedOutcome.from_status_str(self.status)
87126

88127
@property
89128
def instance_name(self) -> str:
@@ -97,15 +136,11 @@ def __str__(self) -> str:
97136

98137
return f"'{hash(self)}' Test Scenario"
99138

100-
def without_expecting_failure(self) -> ConnectorTestScenario:
101-
"""Return a copy of the scenario that does not expect failure.
139+
def without_expected_outcome(self) -> ConnectorTestScenario:
140+
"""Return a copy of the scenario that does not expect failure or success.
102141
103-
This is useful when you need to run multiple steps and you
104-
want to defer failure expectation for one or more steps.
142+
This is useful when running multiple steps, to defer the expectations to a later step.
105143
"""
106-
if self.status != "failed":
107-
return self
108-
109144
return ConnectorTestScenario(
110145
**self.model_dump(exclude={"status"}),
111146
)
@@ -122,3 +157,16 @@ def with_expecting_failure(self) -> ConnectorTestScenario:
122157
**self.model_dump(exclude={"status"}),
123158
status="failed",
124159
)
160+
161+
def with_expecting_success(self) -> ConnectorTestScenario:
162+
"""Return a copy of the scenario that expects success.
163+
164+
This is useful when deriving new scenarios from existing ones.
165+
"""
166+
if self.status == "succeed":
167+
return self
168+
169+
return ConnectorTestScenario(
170+
**self.model_dump(exclude={"status"}),
171+
status="succeed",
172+
)

airbyte_cdk/test/standard_tests/source_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ def test_basic_read(
106106
self.create_connector(scenario),
107107
"discover",
108108
connector_root=self.get_connector_root_dir(),
109-
test_scenario=scenario.without_expecting_failure(),
109+
test_scenario=scenario.without_expected_outcome(),
110110
)
111-
if scenario.expect_exception and discover_result.errors:
111+
if scenario.expected_outcome.expect_exception() and discover_result.errors:
112112
# Failed as expected; we're done.
113113
return
114114

airbyte_cdk/test/utils/reading.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from airbyte_cdk.models import AirbyteStateMessage, ConfiguredAirbyteCatalog, SyncMode
77
from airbyte_cdk.test.catalog_builder import CatalogBuilder
88
from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, read
9+
from airbyte_cdk.test.standard_tests.models.scenario import ExpectedOutcome
910

1011

1112
def catalog(stream_name: str, sync_mode: SyncMode) -> ConfiguredAirbyteCatalog:
@@ -19,8 +20,8 @@ def read_records(
1920
stream_name: str,
2021
sync_mode: SyncMode,
2122
state: Optional[List[AirbyteStateMessage]] = None,
22-
expecting_exception: bool = False,
23+
expected_outcome: ExpectedOutcome = ExpectedOutcome.EXPECT_SUCCESS,
2324
) -> EntrypointOutput:
2425
"""Read records from a stream."""
2526
_catalog = catalog(stream_name, sync_mode)
26-
return read(source, config, _catalog, state, expecting_exception)
27+
return read(source, config, _catalog, state, expected_outcome=expected_outcome)

unit_tests/sources/declarative/file/test_file_stream.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from airbyte_cdk.test.entrypoint_wrapper import read as entrypoint_read
1818
from airbyte_cdk.test.mock_http import HttpMocker, HttpRequest, HttpResponse
1919
from airbyte_cdk.test.mock_http.response_builder import find_binary_response, find_template
20+
from airbyte_cdk.test.standard_tests.models.scenario import ExpectedOutcome
2021
from airbyte_cdk.test.state_builder import StateBuilder
2122

2223

@@ -53,8 +54,9 @@ def read(
5354
config_builder: ConfigBuilder,
5455
catalog: ConfiguredAirbyteCatalog,
5556
state_builder: Optional[StateBuilder] = None,
56-
expecting_exception: bool = False,
5757
yaml_file: Optional[str] = None,
58+
*,
59+
expected_outcome: ExpectedOutcome = ExpectedOutcome.EXPECT_SUCCESS,
5860
) -> EntrypointOutput:
5961
config = config_builder.build()
6062
state = state_builder.build() if state_builder else StateBuilder().build()
@@ -63,14 +65,19 @@ def read(
6365
config,
6466
catalog,
6567
state,
66-
expecting_exception,
68+
expected_outcome=expected_outcome,
6769
)
6870

6971

70-
def discover(config_builder: ConfigBuilder, expecting_exception: bool = False) -> EntrypointOutput:
72+
def discover(
73+
config_builder: ConfigBuilder,
74+
expected_outcome: ExpectedOutcome = ExpectedOutcome.EXPECT_SUCCESS,
75+
) -> EntrypointOutput:
7176
config = config_builder.build()
7277
return entrypoint_discover(
73-
_source(CatalogBuilder().build(), config), config, expecting_exception
78+
_source(CatalogBuilder().build(), config),
79+
config,
80+
expected_outcome=expected_outcome,
7481
)
7582

7683

unit_tests/sources/mock_server_tests/test_resumable_full_refresh.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
create_record_builder,
2828
create_response_builder,
2929
)
30+
from airbyte_cdk.test.standard_tests.models.scenario import ExpectedOutcome
3031
from airbyte_cdk.test.state_builder import StateBuilder
3132
from unit_tests.sources.mock_server_tests.mock_source_fixture import SourceFixture
3233
from unit_tests.sources.mock_server_tests.test_helpers import (
@@ -344,7 +345,7 @@ def test_resumable_full_refresh_failure(self, http_mocker):
344345
source,
345346
config=config,
346347
catalog=_create_catalog([("justice_songs", SyncMode.full_refresh, {})]),
347-
expecting_exception=True,
348+
expected_outcome=ExpectedOutcome.EXPECT_EXCEPTION,
348349
)
349350

350351
status_messages = actual_messages.get_stream_statuses("justice_songs")

unit_tests/test/test_entrypoint_wrapper.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from airbyte_cdk.sources.abstract_source import AbstractSource
3434
from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, discover, read
35+
from airbyte_cdk.test.standard_tests.models.scenario import ExpectedOutcome
3536
from airbyte_cdk.test.state_builder import StateBuilder
3637

3738

@@ -229,7 +230,7 @@ def test_given_unexpected_exception_when_discover_then_print(self, entrypoint, p
229230
@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
230231
def test_given_expected_exception_when_discover_then_do_not_print(self, entrypoint, print_mock):
231232
entrypoint.return_value.run.side_effect = ValueError("This error should not be printed")
232-
discover(self._a_source, _A_CONFIG, expecting_exception=True)
233+
discover(self._a_source, _A_CONFIG, expected_outcome=ExpectedOutcome.EXPECT_EXCEPTION)
233234
assert print_mock.call_count == 0
234235

235236
@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
@@ -380,7 +381,13 @@ def test_given_unexpected_exception_when_read_then_print(self, entrypoint, print
380381
@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")
381382
def test_given_expected_exception_when_read_then_do_not_print(self, entrypoint, print_mock):
382383
entrypoint.return_value.run.side_effect = ValueError("This error should not be printed")
383-
read(self._a_source, _A_CONFIG, _A_CATALOG, _A_STATE, expecting_exception=True)
384+
read(
385+
self._a_source,
386+
_A_CONFIG,
387+
_A_CATALOG,
388+
_A_STATE,
389+
expected_outcome=ExpectedOutcome.EXPECT_EXCEPTION,
390+
)
384391
assert print_mock.call_count == 0
385392

386393
@patch("airbyte_cdk.test.entrypoint_wrapper.AirbyteEntrypoint")

0 commit comments

Comments
 (0)