Skip to content

Commit ea917b2

Browse files
committed
Improve CosRuntime
1 parent e93b7fc commit ea917b2

File tree

1 file changed

+45
-5
lines changed

1 file changed

+45
-5
lines changed

coagent/cos/runtime.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,14 @@
55
from starlette.responses import Response, JSONResponse
66
from sse_starlette.sse import EventSourceResponse
77

8-
from coagent.core import Address, Channel, RawMessage, logger
8+
from coagent.core import (
9+
Address,
10+
Channel,
11+
DiscoveryQuery,
12+
DiscoveryReply,
13+
RawMessage,
14+
logger,
15+
)
916
from coagent.core.exceptions import BaseError
1017
from coagent.core.factory import DeleteAgent
1118
from coagent.core.types import Runtime
@@ -25,6 +32,27 @@ async def start(self):
2532
async def stop(self):
2633
await self._runtime.stop()
2734

35+
async def discover(self, request: Request):
36+
namespace: str = request.query_params.get("namespace", "")
37+
recursive: bool = request.query_params.get("recursive", "") == "true"
38+
inclusive: bool = request.query_params.get("inclusive", "") == "true"
39+
detailed: bool = request.query_params.get("detailed", "") == "true"
40+
41+
result: RawMessage = await self._runtime.channel.publish(
42+
Address(name="discovery"),
43+
DiscoveryQuery(
44+
namespace=namespace,
45+
recursive=recursive,
46+
inclusive=inclusive,
47+
detailed=detailed,
48+
).encode(),
49+
request=True,
50+
probe=False,
51+
)
52+
reply: DiscoveryReply = DiscoveryReply.decode(result)
53+
54+
return JSONResponse(reply.model_dump(mode="json"))
55+
2856
async def register(self, request: Request):
2957
data: dict = await request.json()
3058
name: str = data["name"]
@@ -99,9 +127,12 @@ async def event_stream() -> AsyncIterator[str]:
99127
async def publish(self, request: Request):
100128
data: dict = await request.json()
101129
try:
130+
msg = RawMessage.decode(data["msg"])
131+
await self._update_message_header_extensions(msg, request)
132+
102133
resp: RawMessage | None = await self._runtime.channel.publish(
103-
addr=Address.model_validate(data["addr"]),
104-
msg=RawMessage.model_validate(data["msg"]),
134+
addr=Address.decode(data["addr"]),
135+
msg=msg,
105136
request=data.get("request", False),
106137
reply=data.get("reply", ""),
107138
timeout=data.get("timeout", 0.5),
@@ -117,9 +148,12 @@ async def publish(self, request: Request):
117148

118149
async def publish_multi(self, request: Request):
119150
data: dict = await request.json()
151+
msg = RawMessage.decode(data["msg"])
152+
await self._update_message_header_extensions(msg, request)
153+
120154
msgs = self._runtime.channel.publish_multi(
121-
addr=Address.model_validate(data["addr"]),
122-
msg=RawMessage.model_validate(data["msg"]),
155+
addr=Address.decode(data["addr"]),
156+
msg=msg,
123157
probe=data.get("probe", True),
124158
)
125159

@@ -131,3 +165,9 @@ async def event_stream() -> AsyncIterator[str]:
131165
yield dict(event="error", data=exc.encode_json())
132166

133167
return EventSourceResponse(event_stream())
168+
169+
async def _update_message_header_extensions(
170+
self, msg: RawMessage, request: Request
171+
) -> None:
172+
"""Update the message header extensions according to the data from the request."""
173+
pass

0 commit comments

Comments
 (0)