Skip to content

Commit 58faca5

Browse files
committed
Improve AgentSpec
1 parent 7320d61 commit 58faca5

File tree

2 files changed

+16
-10
lines changed

2 files changed

+16
-10
lines changed

coagent/core/runtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ async def stop(self) -> None:
3636
await self._channel.close()
3737

3838
async def register_spec(self, spec: AgentSpec) -> None:
39-
spec._runtime = self
39+
spec.register(self)
4040

4141
if self._discovery:
4242
await self._discovery.register(

coagent/core/types.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import abc
4-
from dataclasses import dataclass
4+
import dataclasses
55
import enum
66
from typing import Any, AsyncIterator, Awaitable, Callable, Type
77
import uuid
@@ -246,34 +246,40 @@ async def new_reply_topic(self) -> str:
246246
pass
247247

248248

249-
@dataclass
249+
@dataclasses.dataclass
250250
class AgentSpec:
251251
"""The specification of an agent."""
252252

253253
name: str
254254
constructor: Constructor
255255
description: str = ""
256256

257-
_runtime: Runtime | None = None
257+
__runtime: Runtime | None = dataclasses.field(default=None, init=False)
258+
259+
def register(self, runtime: Runtime) -> None:
260+
"""Register the agent specification to a runtime."""
261+
self.__runtime = runtime
258262

259263
async def run(self, msg: RawMessage, timeout: float = 0.5) -> RawMessage:
260-
self._assert_runtime()
264+
"""Create an agent and run it with the given message."""
265+
self.__assert_runtime()
261266

262267
addr = Address(name=self.name, id=uuid.uuid4().hex)
263-
return await self._runtime.channel.publish(
268+
return await self.__runtime.channel.publish(
264269
addr, msg, request=True, timeout=timeout
265270
)
266271

267272
async def run_stream(self, msg: RawMessage) -> AsyncIterator[RawMessage]:
268-
self._assert_runtime()
273+
"""Create an agent and run it with the given message."""
274+
self.__assert_runtime()
269275

270276
addr = Address(name=self.name, id=uuid.uuid4().hex)
271-
result = self._runtime.channel.publish_multi(addr, msg)
277+
result = self.__runtime.channel.publish_multi(addr, msg)
272278
async for chunk in result:
273279
yield chunk
274280

275-
def _assert_runtime(self) -> None:
276-
if self._runtime is None:
281+
def __assert_runtime(self) -> None:
282+
if self.__runtime is None:
277283
raise ValueError(f"AgentSpec {self.name} is not registered to a runtime.")
278284

279285

0 commit comments

Comments
 (0)