44
55from pydantic import BaseModel , ValidationError
66
7- from .exceptions import MessageDecodeError , InternalError
7+ from .exceptions import MessageDecodeError , InternalError , StreamError
88from .logger import logger
99from .messages import (
1010 Cancel ,
1717 SetReplyAgent ,
1818 StopIteration ,
1919)
20- from .types import Address , Agent , Channel , RawMessage , State , Subscription
20+ from .types import Address , Agent , Channel , RawMessage , Reply , State , Subscription
2121
2222
2323class Context :
@@ -122,8 +122,8 @@ def __init__(self, timeout: float = 60):
122122 # this would result in a lot of messages.
123123 self ._lock : asyncio .Lock = asyncio .Lock ()
124124
125- # Normally reply_address is set by an orchestration agent by sending a `SetReplyAgent` message.
126- self .reply_address : Address | None = None
125+ # Normally `reply` is set by an orchestration agent by sending a `SetReplyAgent` message.
126+ self .reply : Reply | None = None
127127
128128 handlers , message_types = self .__collect_handlers ()
129129 # A list of handlers that are registered to handle messages.
@@ -269,7 +269,7 @@ async def _handle_data(self) -> None:
269269 await self .stopped ()
270270
271271 case SetReplyAgent ():
272- self .reply_address = msg .address
272+ self .reply = msg .reply_info
273273
274274 case ProbeAgent () | Empty ():
275275 # Do not handle probes and empty messages.
@@ -282,18 +282,41 @@ async def _handle_data_custom(self, msg: Message, ctx: Context) -> None:
282282 """Handle user-defined DATA messages."""
283283 h : Handler = self .__get_handler (msg )
284284 result = h (self , msg , ctx )
285+ if not is_async_iterator (result ):
286+ result = await result or Empty ()
287+ await self .__send_reply (msg .reply , result )
288+
289+ async def __send_reply (
290+ self , in_msg_reply : Reply , result : Message | AsyncIterator [Message ]
291+ ) -> bool :
292+ reply = self .reply or in_msg_reply
293+ if not reply :
294+ return False
295+
296+ # Reply to the sender if asked.
297+ await self .send_reply (reply , result )
298+ return True
285299
286- async def pub (x : Message ):
287- await self .__send_reply (msg .reply , x )
300+ async def send_reply (
301+ self ,
302+ to : Reply ,
303+ result : Message | AsyncIterator [Message ],
304+ ) -> None :
305+ async def pub (msg : Message ):
306+ await self .channel .publish (to .address , msg .encode ())
288307
289308 async def pub_exc (exc : BaseException ):
290309 err = InternalError .from_exception (exc )
291310 await pub (err .encode_message ())
292311
293- if is_async_iterator ( result ) :
312+ if to . stream :
294313 try :
295- async for x in result :
296- await pub (x )
314+ if is_async_iterator (result ):
315+ async for msg in result :
316+ await pub (msg )
317+ else :
318+ msg = result
319+ await pub (msg )
297320 except asyncio .CancelledError as exc :
298321 await pub_exc (exc )
299322 raise
@@ -304,23 +327,26 @@ async def pub_exc(exc: BaseException):
304327 await pub (StopIteration ())
305328 else :
306329 try :
307- x = await result or Empty ()
308- await pub (x )
330+ if is_async_iterator (result ):
331+ accumulated : RawMessage | None = None
332+ async for msg in result :
333+ if not accumulated :
334+ accumulated = msg
335+ else :
336+ try :
337+ accumulated += msg
338+ except TypeError :
339+ await pub_exc (StreamError ("Streaming mode is required" ))
340+ await pub (accumulated )
341+ else :
342+ msg = result
343+ await pub (msg )
309344 except asyncio .CancelledError as exc :
310345 await pub_exc (exc )
311346 raise
312347 except Exception as exc :
313348 await pub_exc (exc )
314349
315- async def __send_reply (self , in_msg_reply : Address , out_msg : Message ) -> bool :
316- reply_address = self .reply_address or in_msg_reply
317- if not reply_address :
318- return False
319-
320- # Reply to the sending agent if asked.
321- await self .channel .publish (reply_address , out_msg .encode ())
322- return True
323-
324350 def __get_handler (self , msg : Message ) -> Handler | None :
325351 msg_type : Type [Any ] = type (msg )
326352
0 commit comments