|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import abc |
4 | | -from dataclasses import dataclass |
| 4 | +import dataclasses |
5 | 5 | import enum |
6 | 6 | from typing import Any, AsyncIterator, Awaitable, Callable, Type |
7 | 7 | import uuid |
@@ -246,34 +246,40 @@ async def new_reply_topic(self) -> str: |
246 | 246 | pass |
247 | 247 |
|
248 | 248 |
|
249 | | -@dataclass |
| 249 | +@dataclasses.dataclass |
250 | 250 | class AgentSpec: |
251 | 251 | """The specification of an agent.""" |
252 | 252 |
|
253 | 253 | name: str |
254 | 254 | constructor: Constructor |
255 | 255 | description: str = "" |
256 | 256 |
|
257 | | - _runtime: Runtime | None = None |
| 257 | + __runtime: Runtime | None = dataclasses.field(default=None, init=False) |
| 258 | + |
| 259 | + def register(self, runtime: Runtime) -> None: |
| 260 | + """Register the agent specification to a runtime.""" |
| 261 | + self.__runtime = runtime |
258 | 262 |
|
259 | 263 | async def run(self, msg: RawMessage, timeout: float = 0.5) -> RawMessage: |
260 | | - self._assert_runtime() |
| 264 | + """Create an agent and run it with the given message.""" |
| 265 | + self.__assert_runtime() |
261 | 266 |
|
262 | 267 | addr = Address(name=self.name, id=uuid.uuid4().hex) |
263 | | - return await self._runtime.channel.publish( |
| 268 | + return await self.__runtime.channel.publish( |
264 | 269 | addr, msg, request=True, timeout=timeout |
265 | 270 | ) |
266 | 271 |
|
267 | 272 | async def run_stream(self, msg: RawMessage) -> AsyncIterator[RawMessage]: |
268 | | - self._assert_runtime() |
| 273 | + """Create an agent and run it with the given message.""" |
| 274 | + self.__assert_runtime() |
269 | 275 |
|
270 | 276 | addr = Address(name=self.name, id=uuid.uuid4().hex) |
271 | | - result = self._runtime.channel.publish_multi(addr, msg) |
| 277 | + result = self.__runtime.channel.publish_multi(addr, msg) |
272 | 278 | async for chunk in result: |
273 | 279 | yield chunk |
274 | 280 |
|
275 | | - def _assert_runtime(self) -> None: |
276 | | - if self._runtime is None: |
| 281 | + def __assert_runtime(self) -> None: |
| 282 | + if self.__runtime is None: |
277 | 283 | raise ValueError(f"AgentSpec {self.name} is not registered to a runtime.") |
278 | 284 |
|
279 | 285 |
|
|
0 commit comments