Skip to content

Commit e27257a

Browse files
committed
Fix CoS
1 parent cd308aa commit e27257a

File tree

2 files changed

+36
-21
lines changed

2 files changed

+36
-21
lines changed

coagent/cos/app.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,6 @@ def starlette(self) -> Starlette:
4444
self.runtime.publish,
4545
methods=["POST"],
4646
),
47-
Route(
48-
"/runtime/channel/publish_multi",
49-
self.runtime.publish_multi,
50-
methods=["POST"],
51-
),
5247
]
5348
return Starlette(
5449
debug=True,

coagent/cos/runtime.py

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -128,19 +128,44 @@ async def event_stream() -> AsyncIterator[str]:
128128

129129
async def publish(self, request: Request):
130130
data: dict = await request.json()
131-
addr = Address.decode(data["addr"])
132-
msg = RawMessage.decode(data["msg"])
133131

134-
try:
135-
await self._update_message_header_extensions(msg, request)
132+
addr: Address = Address.decode(data["addr"])
133+
msg: RawMessage = RawMessage.decode(data["msg"])
134+
stream: bool = data.get("stream", False)
135+
probe: bool = data.get("probe", True)
136136

137-
resp: RawMessage | None = await self._runtime.channel.publish(
138-
addr=addr,
139-
msg=msg,
137+
await self._update_message_header_extensions(msg, request)
138+
139+
if stream:
140+
return await self._publish_stream(addr, msg, probe=probe)
141+
else:
142+
return await self._publish(
143+
addr,
144+
msg,
140145
request=data.get("request", False),
141146
reply=data.get("reply", ""),
142147
timeout=data.get("timeout", 0.5),
143-
probe=data.get("probe", True),
148+
probe=probe,
149+
)
150+
151+
async def _publish(
152+
self,
153+
addr: Address,
154+
msg: RawMessage,
155+
request: bool,
156+
reply: str,
157+
timeout: float,
158+
probe: bool,
159+
):
160+
try:
161+
resp: RawMessage | None = await self._runtime.channel.publish(
162+
addr=addr,
163+
msg=msg,
164+
stream=False,
165+
request=request,
166+
reply=reply,
167+
timeout=timeout,
168+
probe=probe,
144169
)
145170
except BaseError as exc:
146171
return JSONResponse(exc.encode(mode="json"), status_code=404)
@@ -155,17 +180,12 @@ async def publish(self, request: Request):
155180
else:
156181
return JSONResponse(resp.encode(mode="json"))
157182

158-
async def publish_multi(self, request: Request):
159-
data: dict = await request.json()
160-
msg = RawMessage.decode(data["msg"])
161-
await self._update_message_header_extensions(msg, request)
162-
163-
addr = Address.decode(data["addr"])
164-
msgs = await self._runtime.channel.publish(
183+
async def _publish_stream(self, addr: Address, msg: RawMessage, probe: bool):
184+
msgs: AsyncIterator[RawMessage] = await self._runtime.channel.publish(
165185
addr=addr,
166186
msg=msg,
167187
stream=True,
168-
probe=data.get("probe", True),
188+
probe=probe,
169189
)
170190

171191
async def event_stream() -> AsyncIterator[str]:

0 commit comments

Comments
 (0)