11import logging
2- from typing import Any , Dict , Iterable , List , Mapping , Optional , Tuple
2+ from typing import Any , List , Mapping , Optional , Tuple
33
44from airbyte_protocol_dataclasses .models import (
55 AirbyteCatalog ,
6- AirbyteConnectionStatus ,
7- AirbyteMessage ,
86 Status ,
9- TraceType ,
107)
11- from airbyte_protocol_dataclasses .models import Type as AirbyteMessageType
128from fastapi import HTTPException
139
1410from airbyte_cdk .connector_builder .models import StreamRead
1814 AirbyteStateMessage ,
1915 ConfiguredAirbyteCatalog ,
2016)
17+ from airbyte_cdk .models .airbyte_protocol_serializers import AirbyteMessageSerializer
2118from airbyte_cdk .sources .declarative .manifest_declarative_source import (
2219 ManifestDeclarativeSource ,
2320)
21+ from airbyte_cdk .test .entrypoint_wrapper import AirbyteEntrypointException , EntrypointOutput
2422
2523
2624class ManifestCommandProcessor :
@@ -63,20 +61,29 @@ def test_read(
6361 def check_connection (
6462 self ,
6563 config : Mapping [str , Any ],
66- ) -> Tuple [bool , str ]:
64+ ) -> Tuple [bool , Optional [ str ] ]:
6765 """
6866 Check the connection to the source.
6967 """
7068
7169 spec = self ._source .spec (self ._logger )
72- messages = AirbyteEntrypoint (source = self ._source ).check (spec , config )
73- messages_by_type = self ._get_messages_by_type (messages )
74- self ._raise_on_trace_message (messages_by_type )
75- connection_status = self ._get_connection_status (messages_by_type )
70+ entrypoint = AirbyteEntrypoint (source = self ._source )
71+ messages = entrypoint .check (spec , config )
72+ output = EntrypointOutput (
73+ messages = [AirbyteEntrypoint .airbyte_message_to_string (m ) for m in messages ],
74+ command = ["check" ],
75+ )
76+ self ._raise_on_trace_message (output )
77+
78+ status_messages = output .connection_status_messages
79+ if not status_messages or status_messages [- 1 ].connectionStatus is None :
80+ return False , "Connection check did not return a status message"
7681
77- if connection_status :
78- return connection_status .status == Status .SUCCEEDED , connection_status .message
79- return False , "Connection check failed"
82+ connection_status = status_messages [- 1 ].connectionStatus
83+ return (
84+ connection_status .status == Status .SUCCEEDED ,
85+ connection_status .message ,
86+ )
8087
8188 def discover (
8289 self ,
@@ -86,66 +93,29 @@ def discover(
8693 Discover the catalog from the source.
8794 """
8895 spec = self ._source .spec (self ._logger )
89- messages = AirbyteEntrypoint (source = self ._source ).discover (spec , config )
90- messages_by_type = self ._get_messages_by_type (messages )
91- self ._raise_on_trace_message (messages_by_type )
92- return self ._get_catalog (messages_by_type )
93-
94- def _get_messages_by_type (
95- self ,
96- messages : Iterable [AirbyteMessage ],
97- ) -> Dict [str , List [AirbyteMessage ]]:
98- """
99- Group messages by type.
100- """
101- grouped : Dict [str , List [AirbyteMessage ]] = {}
102- for message in messages :
103- msg_type = message .type
104- if msg_type not in grouped :
105- grouped [msg_type ] = []
106- grouped [msg_type ].append (message )
107- return grouped
108-
109- def _get_connection_status (
110- self ,
111- messages_by_type : Mapping [str , List [AirbyteMessage ]],
112- ) -> Optional [AirbyteConnectionStatus ]:
113- """
114- Get the connection status from the messages.
115- """
116- messages = messages_by_type .get (AirbyteMessageType .CONNECTION_STATUS )
117- return messages [- 1 ].connectionStatus if messages else None
96+ entrypoint = AirbyteEntrypoint (source = self ._source )
97+ messages = entrypoint .discover (spec , config )
98+ output = EntrypointOutput (
99+ messages = [AirbyteEntrypoint .airbyte_message_to_string (m ) for m in messages ],
100+ command = ["discover" ],
101+ )
102+ self ._raise_on_trace_message (output )
118103
119- def _get_catalog (
120- self ,
121- messages_by_type : Mapping [str , List [AirbyteMessage ]],
122- ) -> Optional [AirbyteCatalog ]:
123- """
124- Get the catalog from the messages.
125- """
126- messages = messages_by_type .get (AirbyteMessageType .CATALOG )
127- return messages [- 1 ].catalog if messages else None
104+ try :
105+ catalog_message = output .catalog
106+ return catalog_message .catalog
107+ except ValueError :
108+ # No catalog message found
109+ return None
128110
129111 def _raise_on_trace_message (
130112 self ,
131- messages_by_type : Mapping [ str , List [ AirbyteMessage ]] ,
113+ output : EntrypointOutput ,
132114 ) -> None :
133115 """
134116 Raise an exception if a trace message is found.
135117 """
136- messages = [
137- message
138- for message in messages_by_type .get (AirbyteMessageType .TRACE , [])
139- if message .trace .type == TraceType .ERROR
140- ]
141- if messages :
142- error_message = messages [- 1 ].trace .error .message
143- self ._logger .warning (
144- "Error response from CDK: %s\n %s" ,
145- error_message ,
146- messages [- 1 ].trace .error .stack_trace ,
147- )
148- raise HTTPException (
149- status_code = 422 ,
150- detail = f"AirbyteTraceMessage response from CDK: { error_message } " ,
151- )
118+ try :
119+ output .raise_if_errors ()
120+ except AirbyteEntrypointException as e :
121+ raise HTTPException (status_code = 422 , detail = e .message )
0 commit comments