1919import re
2020import tempfile
2121import traceback
22+ from collections import deque
23+ from collections .abc import Generator , Mapping
2224from io import StringIO
2325from pathlib import Path
24- from typing import Any , List , Mapping , Optional , Union
26+ from typing import Any , List , Literal , Optional , Union , final , overload
2527
2628import orjson
2729from pydantic import ValidationError as V2ValidationError
4345 TraceType ,
4446 Type ,
4547)
48+ from airbyte_cdk .models .airbyte_protocol import AirbyteMessage , AirbyteStreamState
4649from airbyte_cdk .sources import Source
4750from airbyte_cdk .test .models .scenario import ExpectedOutcome
4851
4952
5053class EntrypointOutput :
51- def __init__ (self , messages : List [str ], uncaught_exception : Optional [BaseException ] = None ):
52- try :
53- self ._messages = [self ._parse_message (message ) for message in messages ]
54- except V2ValidationError as exception :
55- raise ValueError ("All messages are expected to be AirbyteMessage" ) from exception
54+ """A class to encapsulate the output of an Airbyte connector's execution.
55+
56+ This class can be initialized with a list of messages or a file containing messages.
57+ It provides methods to access different types of messages produced during the execution
58+ of an Airbyte connector, including both successful messages and error messages.
59+
60+ When working with records and state messages, it provides both a list and an iterator
61+ implementation. Lists are easier to work with, but generators are better suited to handle
62+ large volumes of messages without overflowing the available memory.
63+ """
64+
65+ def __init__ (
66+ self ,
67+ messages : list [str ] | None = None ,
68+ uncaught_exception : Optional [BaseException ] = None ,
69+ * ,
70+ message_file : Path | None = None ,
71+ ) -> None :
72+ if messages is None and message_file is None :
73+ raise ValueError ("Either messages or message_file must be provided" )
74+ if messages is not None and message_file is not None :
75+ raise ValueError ("Only one of messages or message_file can be provided" )
76+
77+ self ._messages : list [AirbyteMessage ] | None = []
78+ self ._message_file : Path | None = message_file
79+ if messages :
80+ try :
81+ self ._messages = [self ._parse_message (message ) for message in messages ]
82+ except V2ValidationError as exception :
83+ raise ValueError ("All messages are expected to be AirbyteMessage" ) from exception
5684
5785 if uncaught_exception :
86+ if self ._messages is None :
87+ self ._messages = []
88+
5889 self ._messages .append (
5990 assemble_uncaught_exception (
6091 type (uncaught_exception ), uncaught_exception
@@ -72,13 +103,40 @@ def _parse_message(message: str) -> AirbyteMessage:
72103 )
73104
74105 @property
75- def records_and_state_messages (self ) -> List [AirbyteMessage ]:
76- return self ._get_message_by_types ([Type .RECORD , Type .STATE ])
106+ def records_and_state_messages (
107+ self ,
108+ ) -> list [AirbyteMessage ]:
109+ return self ._get_message_by_types (
110+ message_types = [Type .RECORD , Type .STATE ],
111+ safe_iterator = False ,
112+ )
113+
114+ def records_and_state_messages_iterator (
115+ self ,
116+ ) -> Generator [AirbyteMessage , None , None ]:
117+ """Returns a generator that yields record and state messages one by one.
118+
119+ Use this instead of `records_and_state_messages` when the volume of messages could be large
120+ enough to overload available memory.
121+ """
122+ return self ._get_message_by_types (
123+ message_types = [Type .RECORD , Type .STATE ],
124+ safe_iterator = True ,
125+ )
77126
78127 @property
79128 def records (self ) -> List [AirbyteMessage ]:
80129 return self ._get_message_by_types ([Type .RECORD ])
81130
131+ @property
132+ def records_iterator (self ) -> Generator [AirbyteMessage , None , None ]:
133+ """Returns a generator that yields record messages one by one.
134+
135+ Use this instead of `records` when the volume of records could be large
136+ enough to overload available memory.
137+ """
138+ return self ._get_message_by_types ([Type .RECORD ], safe_iterator = True )
139+
82140 @property
83141 def state_messages (self ) -> List [AirbyteMessage ]:
84142 return self ._get_message_by_types ([Type .STATE ])
@@ -92,11 +150,21 @@ def connection_status_messages(self) -> List[AirbyteMessage]:
92150 return self ._get_message_by_types ([Type .CONNECTION_STATUS ])
93151
94152 @property
95- def most_recent_state (self ) -> Any :
96- state_messages = self ._get_message_by_types ([Type .STATE ])
97- if not state_messages :
98- raise ValueError ("Can't provide most recent state as there are no state messages" )
99- return state_messages [- 1 ].state .stream # type: ignore[union-attr] # state has `stream`
153+ def most_recent_state (self ) -> AirbyteStreamState | None :
154+ state_message_iterator = self ._get_message_by_types (
155+ [Type .STATE ],
156+ safe_iterator = True ,
157+ )
158+ # Use a deque with maxlen=1 to efficiently get the last state message
159+ double_ended_queue = deque (state_message_iterator , maxlen = 1 )
160+ try :
161+ final_state_message : AirbyteMessage = double_ended_queue .pop ()
162+ except IndexError :
163+ raise ValueError (
164+ "Can't provide most recent state as there are no state messages."
165+ ) from None
166+
167+ return final_state_message .state .stream # type: ignore[union-attr] # state has `stream`
100168
101169 @property
102170 def logs (self ) -> List [AirbyteMessage ]:
@@ -131,13 +199,80 @@ def get_stream_statuses(self, stream_name: str) -> List[AirbyteStreamStatus]:
131199 )
132200 return list (status_messages )
133201
134- def _get_message_by_types (self , message_types : List [Type ]) -> List [AirbyteMessage ]:
135- return [message for message in self ._messages if message .type in message_types ]
202+ def _read_all_messages (self ) -> Generator [AirbyteMessage , None , None ]:
203+ """Creates a generator which yields messages one by one.
204+
205+ This will iterate over all messages in the output file (if provided) or the messages
206+ provided during initialization. File results are provided first, followed by any
207+ messages that were passed in directly.
208+ """
209+ if self ._message_file :
210+ try :
211+ with open (self ._message_file , "r" , encoding = "utf-8" ) as file :
212+ for line in file :
213+ if not line .strip ():
214+ # Skip empty lines
215+ continue
216+
217+ yield self ._parse_message (line .strip ())
218+ except FileNotFoundError :
219+ raise ValueError (f"Message file { self ._message_file } not found" )
220+
221+ if self ._messages is not None :
222+ yield from self ._messages
223+
224+ # Overloads to provide proper type hints for different usages of `_get_message_by_types`.
225+
226+ @overload
227+ def _get_message_by_types (
228+ self ,
229+ message_types : list [Type ],
230+ ) -> list [AirbyteMessage ]: ...
231+
232+ @overload
233+ def _get_message_by_types (
234+ self ,
235+ message_types : list [Type ],
236+ * ,
237+ safe_iterator : Literal [False ],
238+ ) -> list [AirbyteMessage ]: ...
239+
240+ @overload
241+ def _get_message_by_types (
242+ self ,
243+ message_types : list [Type ],
244+ * ,
245+ safe_iterator : Literal [True ],
246+ ) -> Generator [AirbyteMessage , None , None ]: ...
247+
248+ def _get_message_by_types (
249+ self ,
250+ message_types : list [Type ],
251+ * ,
252+ safe_iterator : bool = True ,
253+ ) -> list [AirbyteMessage ] | Generator [AirbyteMessage , None , None ]:
254+ """Get messages of specific types.
255+
256+ If `safe_iterator` is True, returns a generator that yields messages one by one.
257+ If `safe_iterator` is False, returns a list of messages.
258+
259+ Use `safe_iterator=True` when the volume of messages could overload the available
260+ memory.
261+ """
262+ message_generator = self ._read_all_messages ()
263+
264+ if safe_iterator :
265+ return (message for message in message_generator if message .type in message_types )
266+
267+ return [message for message in message_generator if message .type in message_types ]
136268
137269 def _get_trace_message_by_trace_type (self , trace_type : TraceType ) -> List [AirbyteMessage ]:
138270 return [
139271 message
140- for message in self ._get_message_by_types ([Type .TRACE ])
272+ for message in self ._get_message_by_types (
273+ [Type .TRACE ],
274+ safe_iterator = True ,
275+ )
141276 if message .trace .type == trace_type # type: ignore[union-attr] # trace has `type`
142277 ]
143278
0 commit comments