Skip to content

Commit cd308aa

Browse files
committed
Improve HTTPChannel
1 parent 47c7542 commit cd308aa

File tree

4 files changed

+41
-49
lines changed

4 files changed

+41
-49
lines changed

coagent/runtimes/http_runtime.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class HTTPChannel(BaseChannel):
4444
"""An HTTP-based channel.
4545
4646
_publish: POST /publish
47-
_publish_stream: POST /publish_multi
47+
_publish_stream: POST /publish stream=True
4848
subscribe: POST /subscribe
4949
new_reply_topic: POST /reply-topics
5050
"""
@@ -68,17 +68,17 @@ async def _publish(
6868
self,
6969
addr: Address,
7070
msg: RawMessage,
71-
request: bool = False,
7271
stream: bool = False,
72+
request: bool = False,
7373
reply: str = "",
7474
timeout: float = 5.0,
7575
probe: bool = True,
7676
) -> RawMessage | None:
7777
data = dict(
7878
addr=addr.encode(mode="json"),
7979
msg=msg.encode(mode="json"),
80-
request=request,
8180
stream=stream,
81+
request=request,
8282
reply=reply,
8383
timeout=timeout,
8484
probe=probe,
@@ -108,13 +108,14 @@ async def _publish_stream(
108108
data = dict(
109109
addr=addr.encode(mode="json"),
110110
msg=msg.encode(mode="json"),
111+
stream=True,
111112
probe=probe,
112113
)
113114
headers = {"Authorization": self._auth} if self._auth else None
114115

115116
queue: QueueSubscriptionIterator = QueueSubscriptionIterator()
116117
sub: HTTPChannelSubscription = HTTPChannelSubscription(
117-
f"{self._server}/publish_multi", data, headers, queue.receive
118+
f"{self._server}/publish", data, headers, queue.receive
118119
)
119120
await sub.subscribe()
120121

@@ -268,37 +269,22 @@ async def publish(
268269
self,
269270
addr: Address,
270271
msg: RawMessage,
271-
request: bool = False,
272272
stream: bool = False,
273+
request: bool = False,
273274
reply: str = "",
274275
timeout: float = 5.0,
275276
probe: bool = True,
276277
) -> RawMessage | None:
277278
return await self._channel.publish(
278279
addr,
279280
msg,
280-
request=request,
281281
stream=stream,
282+
request=request,
282283
reply=reply,
283284
timeout=timeout,
284285
probe=probe,
285286
)
286287

287-
async def publish_multi(
288-
self,
289-
addr: Address,
290-
msg: RawMessage,
291-
probe: bool = True,
292-
) -> AsyncIterator[RawMessage]:
293-
msgs = await self._channel.publish(
294-
addr,
295-
msg,
296-
stream=True,
297-
probe=probe,
298-
)
299-
async for msg in msgs:
300-
yield msg
301-
302288
async def subscribe(
303289
self,
304290
addr: Address,

coagent/runtimes/local_runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ async def _publish(
6969
self,
7070
addr: Address,
7171
msg: RawMessage,
72-
request: bool = False,
7372
stream: bool = False,
73+
request: bool = False,
7474
reply: str = "",
7575
timeout: float = 0.5,
7676
probe: bool = True,

coagent/runtimes/nats_runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ async def _publish(
7777
self,
7878
addr: Address,
7979
msg: RawMessage,
80-
request: bool = False,
8180
stream: bool = False,
81+
request: bool = False,
8282
reply: str = "",
8383
timeout: float = 0.5,
8484
probe: bool = True,

examples/ping-pong/http_runtime_server.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,40 @@ async def shutdown():
2727

2828
async def publish(request):
2929
data: dict = await request.json()
30+
31+
addr: Address = Address.model_validate(data["addr"])
32+
msg: RawMessage = RawMessage.model_validate(data["msg"])
33+
stream: bool = data.get("stream", False)
34+
request: bool = data.get("request", False)
35+
reply: str = data.get("reply", "")
36+
timeout: float = data.get("timeout", 0.5)
37+
probe: bool = data.get("probe", True)
38+
39+
# Streaming
40+
if stream:
41+
msgs: AsyncIterator[RawMessage] = await backend.publish(
42+
addr=addr, msg=msg, stream=stream, probe=probe
43+
)
44+
45+
async def event_stream() -> AsyncIterator[str]:
46+
try:
47+
async for raw in msgs:
48+
yield dict(data=raw.encode_json())
49+
except BaseError as exc:
50+
yield dict(event="error", data=exc.encode_json())
51+
52+
return EventSourceResponse(event_stream())
53+
54+
# Non-streaming
3055
try:
3156
resp: RawMessage | None = await backend.publish(
32-
addr=Address.model_validate(data["addr"]),
33-
msg=RawMessage.model_validate(data["msg"]),
34-
request=data.get("request", False),
35-
stream=data.get("stream", False),
36-
reply=data.get("reply", ""),
37-
timeout=data.get("timeout", 0.5),
38-
probe=data.get("probe", True),
57+
addr=addr,
58+
msg=msg,
59+
stream=stream,
60+
request=request,
61+
reply=reply,
62+
timeout=timeout,
63+
probe=probe,
3964
)
4065
except BaseError as exc:
4166
return JSONResponse(exc.encode(mode="json"), status_code=404)
@@ -46,24 +71,6 @@ async def publish(request):
4671
return JSONResponse(resp.encode(mode="json"))
4772

4873

49-
async def publish_multi(request):
50-
data: dict = await request.json()
51-
msgs = backend.publish_multi(
52-
addr=Address.model_validate(data["addr"]),
53-
msg=RawMessage.model_validate(data["msg"]),
54-
probe=data.get("probe", True),
55-
)
56-
57-
async def event_stream() -> AsyncIterator[str]:
58-
try:
59-
async for raw in msgs:
60-
yield dict(data=raw.encode_json())
61-
except BaseError as exc:
62-
yield dict(event="error", data=exc.encode_json())
63-
64-
return EventSourceResponse(event_stream())
65-
66-
6774
async def subscribe(request):
6875
data: dict = await request.json()
6976
msgs: AsyncIterator[RawMessage] = backend.subscribe(
@@ -85,7 +92,6 @@ async def new_reply_topic(request):
8592

8693
routes = [
8794
Route("/publish", publish, methods=["POST"]),
88-
Route("/publish_multi", publish_multi, methods=["POST"]),
8995
Route("/subscribe", subscribe, methods=["POST"]),
9096
Route("/reply-topics", new_reply_topic, methods=["POST"]),
9197
]

0 commit comments

Comments
 (0)