Skip to content

Commit 3b6b0db

Browse files
committed
Add AgentSpec
1 parent 6bb2750 commit 3b6b0db

File tree

3 files changed

+41
-11
lines changed

3 files changed

+41
-11
lines changed

coagent/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .types import (
1212
Address,
1313
Agent,
14+
AgentSpec,
1415
Constructor,
1516
Channel,
1617
MessageHeader,

coagent/core/runtime.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from .messages import StopIteration, Error
99
from .factory import Factory, DeleteAgent
1010
from .types import (
11+
AgentSpec,
1112
Channel,
12-
Constructor,
1313
Runtime,
1414
Address,
1515
RawMessage,
@@ -35,19 +35,21 @@ async def stop(self) -> None:
3535
await self.deregister()
3636
await self._channel.close()
3737

38-
async def register(
39-
self, name: str, constructor: Constructor, description: str = ""
40-
) -> None:
38+
async def register_spec(self, spec: AgentSpec) -> None:
39+
spec._runtime = self
40+
4141
if self._discovery:
42-
await self._discovery.register(name, constructor, description)
42+
await self._discovery.register(
43+
spec.name, spec.constructor, spec.description
44+
)
4345

44-
if name in self._factories:
45-
raise ValueError(f"Agent type {name} already registered")
46+
if spec.name in self._factories:
47+
raise ValueError(f"Agent type {spec.name} already registered")
4648

47-
factory = Factory(name, constructor)
49+
factory = Factory(spec.name, spec.constructor)
4850
# We MUST set the channel and address manually.
49-
factory.init(self._channel, Address(name=name))
50-
self._factories[name] = factory
51+
factory.init(self._channel, Address(name=spec.name))
52+
self._factories[spec.name] = factory
5153

5254
await factory.start()
5355

coagent/core/types.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

33
import abc
4+
from dataclasses import dataclass
45
import enum
56
from typing import Any, AsyncIterator, Awaitable, Callable, Type
7+
import uuid
68

79
from pydantic import BaseModel, Field
810

@@ -238,6 +240,26 @@ async def new_reply_topic(self) -> str:
238240
pass
239241

240242

243+
@dataclass
244+
class AgentSpec:
245+
"""The specification of an agent."""
246+
247+
name: str
248+
constructor: Constructor
249+
description: str = ""
250+
251+
_runtime: Runtime | None = None
252+
253+
async def run_stream(self, msg: RawMessage) -> AsyncIterator[RawMessage]:
254+
if self._runtime is None:
255+
raise ValueError(f"AgentSpec {self.name} is not registered to a runtime.")
256+
257+
addr = Address(name=self.name, id=uuid.uuid4().hex)
258+
result = self._runtime.channel.publish_multi(addr, msg)
259+
async for chunk in result:
260+
yield chunk
261+
262+
241263
class Runtime(abc.ABC):
242264
async def __aenter__(self):
243265
await self.start()
@@ -254,10 +276,15 @@ async def start(self) -> None:
254276
async def stop(self) -> None:
255277
pass
256278

257-
@abc.abstractmethod
258279
async def register(
259280
self, name: str, constructor: Constructor, description: str = ""
260281
) -> None:
282+
await self.register_spec(
283+
AgentSpec(name=name, constructor=constructor, description=description)
284+
)
285+
286+
@abc.abstractmethod
287+
async def register_spec(self, spec: AgentSpec) -> None:
261288
pass
262289

263290
@abc.abstractmethod

0 commit comments

Comments
 (0)