Skip to content

Commit 24c2184

Browse files
committed
move to EntrypointOutput
1 parent 0ec16ed commit 24c2184

File tree

2 files changed

+99
-168
lines changed

2 files changed

+99
-168
lines changed
Lines changed: 38 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
import logging
2-
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
2+
from typing import Any, List, Mapping, Optional, Tuple
33

44
from airbyte_protocol_dataclasses.models import (
55
AirbyteCatalog,
6-
AirbyteConnectionStatus,
7-
AirbyteMessage,
86
Status,
9-
TraceType,
107
)
11-
from airbyte_protocol_dataclasses.models import Type as AirbyteMessageType
128
from fastapi import HTTPException
139

1410
from airbyte_cdk.connector_builder.models import StreamRead
@@ -18,9 +14,11 @@
1814
AirbyteStateMessage,
1915
ConfiguredAirbyteCatalog,
2016
)
17+
from airbyte_cdk.models.airbyte_protocol_serializers import AirbyteMessageSerializer
2118
from airbyte_cdk.sources.declarative.manifest_declarative_source import (
2219
ManifestDeclarativeSource,
2320
)
21+
from airbyte_cdk.test.entrypoint_wrapper import AirbyteEntrypointException, EntrypointOutput
2422

2523

2624
class ManifestCommandProcessor:
@@ -63,20 +61,29 @@ def test_read(
6361
def check_connection(
6462
self,
6563
config: Mapping[str, Any],
66-
) -> Tuple[bool, str]:
64+
) -> Tuple[bool, Optional[str]]:
6765
"""
6866
Check the connection to the source.
6967
"""
7068

7169
spec = self._source.spec(self._logger)
72-
messages = AirbyteEntrypoint(source=self._source).check(spec, config)
73-
messages_by_type = self._get_messages_by_type(messages)
74-
self._raise_on_trace_message(messages_by_type)
75-
connection_status = self._get_connection_status(messages_by_type)
70+
entrypoint = AirbyteEntrypoint(source=self._source)
71+
messages = entrypoint.check(spec, config)
72+
output = EntrypointOutput(
73+
messages=[AirbyteEntrypoint.airbyte_message_to_string(m) for m in messages],
74+
command=["check"],
75+
)
76+
self._raise_on_trace_message(output)
77+
78+
status_messages = output.connection_status_messages
79+
if not status_messages or status_messages[-1].connectionStatus is None:
80+
return False, "Connection check did not return a status message"
7681

77-
if connection_status:
78-
return connection_status.status == Status.SUCCEEDED, connection_status.message
79-
return False, "Connection check failed"
82+
connection_status = status_messages[-1].connectionStatus
83+
return (
84+
connection_status.status == Status.SUCCEEDED,
85+
connection_status.message,
86+
)
8087

8188
def discover(
8289
self,
@@ -86,66 +93,29 @@ def discover(
8693
Discover the catalog from the source.
8794
"""
8895
spec = self._source.spec(self._logger)
89-
messages = AirbyteEntrypoint(source=self._source).discover(spec, config)
90-
messages_by_type = self._get_messages_by_type(messages)
91-
self._raise_on_trace_message(messages_by_type)
92-
return self._get_catalog(messages_by_type)
93-
94-
def _get_messages_by_type(
95-
self,
96-
messages: Iterable[AirbyteMessage],
97-
) -> Dict[str, List[AirbyteMessage]]:
98-
"""
99-
Group messages by type.
100-
"""
101-
grouped: Dict[str, List[AirbyteMessage]] = {}
102-
for message in messages:
103-
msg_type = message.type
104-
if msg_type not in grouped:
105-
grouped[msg_type] = []
106-
grouped[msg_type].append(message)
107-
return grouped
108-
109-
def _get_connection_status(
110-
self,
111-
messages_by_type: Mapping[str, List[AirbyteMessage]],
112-
) -> Optional[AirbyteConnectionStatus]:
113-
"""
114-
Get the connection status from the messages.
115-
"""
116-
messages = messages_by_type.get(AirbyteMessageType.CONNECTION_STATUS)
117-
return messages[-1].connectionStatus if messages else None
96+
entrypoint = AirbyteEntrypoint(source=self._source)
97+
messages = entrypoint.discover(spec, config)
98+
output = EntrypointOutput(
99+
messages=[AirbyteEntrypoint.airbyte_message_to_string(m) for m in messages],
100+
command=["discover"],
101+
)
102+
self._raise_on_trace_message(output)
118103

119-
def _get_catalog(
120-
self,
121-
messages_by_type: Mapping[str, List[AirbyteMessage]],
122-
) -> Optional[AirbyteCatalog]:
123-
"""
124-
Get the catalog from the messages.
125-
"""
126-
messages = messages_by_type.get(AirbyteMessageType.CATALOG)
127-
return messages[-1].catalog if messages else None
104+
try:
105+
catalog_message = output.catalog
106+
return catalog_message.catalog
107+
except ValueError:
108+
# No catalog message found
109+
return None
128110

129111
def _raise_on_trace_message(
130112
self,
131-
messages_by_type: Mapping[str, List[AirbyteMessage]],
113+
output: EntrypointOutput,
132114
) -> None:
133115
"""
134116
Raise an exception if a trace message is found.
135117
"""
136-
messages = [
137-
message
138-
for message in messages_by_type.get(AirbyteMessageType.TRACE, [])
139-
if message.trace.type == TraceType.ERROR
140-
]
141-
if messages:
142-
error_message = messages[-1].trace.error.message
143-
self._logger.warning(
144-
"Error response from CDK: %s\n%s",
145-
error_message,
146-
messages[-1].trace.error.stack_trace,
147-
)
148-
raise HTTPException(
149-
status_code=422,
150-
detail=f"AirbyteTraceMessage response from CDK: {error_message}",
151-
)
118+
try:
119+
output.raise_if_errors()
120+
except AirbyteEntrypointException as e:
121+
raise HTTPException(status_code=422, detail=e.message)

0 commit comments

Comments
 (0)