55from starlette .responses import Response , JSONResponse
66from sse_starlette .sse import EventSourceResponse
77
8- from coagent .core import Address , Channel , RawMessage , logger
8+ from coagent .core import (
9+ Address ,
10+ Channel ,
11+ DiscoveryQuery ,
12+ DiscoveryReply ,
13+ RawMessage ,
14+ logger ,
15+ )
916from coagent .core .exceptions import BaseError
1017from coagent .core .factory import DeleteAgent
1118from coagent .core .types import Runtime
@@ -25,6 +32,27 @@ async def start(self):
2532 async def stop (self ):
2633 await self ._runtime .stop ()
2734
35+ async def discover (self , request : Request ):
36+ namespace : str = request .query_params .get ("namespace" , "" )
37+ recursive : bool = request .query_params .get ("recursive" , "" ) == "true"
38+ inclusive : bool = request .query_params .get ("inclusive" , "" ) == "true"
39+ detailed : bool = request .query_params .get ("detailed" , "" ) == "true"
40+
41+ result : RawMessage = await self ._runtime .channel .publish (
42+ Address (name = "discovery" ),
43+ DiscoveryQuery (
44+ namespace = namespace ,
45+ recursive = recursive ,
46+ inclusive = inclusive ,
47+ detailed = detailed ,
48+ ).encode (),
49+ request = True ,
50+ probe = False ,
51+ )
52+ reply : DiscoveryReply = DiscoveryReply .decode (result )
53+
54+ return JSONResponse (reply .model_dump (mode = "json" ))
55+
2856 async def register (self , request : Request ):
2957 data : dict = await request .json ()
3058 name : str = data ["name" ]
@@ -99,9 +127,12 @@ async def event_stream() -> AsyncIterator[str]:
99127 async def publish (self , request : Request ):
100128 data : dict = await request .json ()
101129 try :
130+ msg = RawMessage .decode (data ["msg" ])
131+ await self ._update_message_header_extensions (msg , request )
132+
102133 resp : RawMessage | None = await self ._runtime .channel .publish (
103- addr = Address .model_validate (data ["addr" ]),
104- msg = RawMessage . model_validate ( data [ " msg" ]) ,
134+ addr = Address .decode (data ["addr" ]),
135+ msg = msg ,
105136 request = data .get ("request" , False ),
106137 reply = data .get ("reply" , "" ),
107138 timeout = data .get ("timeout" , 0.5 ),
@@ -117,9 +148,12 @@ async def publish(self, request: Request):
117148
118149 async def publish_multi (self , request : Request ):
119150 data : dict = await request .json ()
151+ msg = RawMessage .decode (data ["msg" ])
152+ await self ._update_message_header_extensions (msg , request )
153+
120154 msgs = self ._runtime .channel .publish_multi (
121- addr = Address .model_validate (data ["addr" ]),
122- msg = RawMessage . model_validate ( data [ " msg" ]) ,
155+ addr = Address .decode (data ["addr" ]),
156+ msg = msg ,
123157 probe = data .get ("probe" , True ),
124158 )
125159
@@ -131,3 +165,9 @@ async def event_stream() -> AsyncIterator[str]:
131165 yield dict (event = "error" , data = exc .encode_json ())
132166
133167 return EventSourceResponse (event_stream ())
168+
169+ async def _update_message_header_extensions (
170+ self , msg : RawMessage , request : Request
171+ ) -> None :
172+ """Update the message header extensions according to the data from the request."""
173+ pass
0 commit comments