Skip to content

Commit e7ab680

Browse files
committed
platform extension wip
Signed-off-by: Radek Ježek <radek.jezek@ibm.com>
1 parent c0ccc0a commit e7ab680

File tree

46 files changed

+548
-151
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+548
-151
lines changed

agents/community/gpt-researcher/gpt_researcher_agent/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from beeai_sdk.a2a.extensions import TrajectoryExtensionServer, TrajectoryExtensionSpec, AgentDetail
1313
from beeai_sdk.a2a.types import RunYield
1414
from beeai_sdk.server import Server
15-
from beeai_sdk.server.context import Context
15+
from beeai_sdk.server.context import RunContext
1616

1717
server = Server()
1818

@@ -59,7 +59,7 @@
5959
],
6060
)
6161
async def gpt_researcher(
62-
message: Message, context: Context, trajectory: Annotated[TrajectoryExtensionServer, TrajectoryExtensionSpec()]
62+
message: Message, context: RunContext, trajectory: Annotated[TrajectoryExtensionServer, TrajectoryExtensionSpec()]
6363
) -> AsyncGenerator[RunYield, None]:
6464
"""
6565
The agent conducts in-depth local and web research using a language model to generate comprehensive reports with

agents/official/beeai-framework/chat/src/chat/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
)
4242
from beeai_sdk.a2a.types import AgentMessage
4343
from beeai_sdk.server import Server
44-
from beeai_sdk.server.context import Context
44+
from beeai_sdk.server.context import RunContext
4545
from chat.helpers.citations import extract_citations
4646
from chat.helpers.trajectory import TrajectoryContent
4747
from openinference.instrumentation.beeai import BeeAIInstrumentor
@@ -136,7 +136,7 @@
136136
)
137137
async def chat(
138138
message: Message,
139-
context: Context,
139+
context: RunContext,
140140
trajectory: Annotated[TrajectoryExtensionServer, TrajectoryExtensionSpec()],
141141
citation: Annotated[CitationExtensionServer, CitationExtensionSpec()],
142142
):

agents/official/beeai-framework/chat/uv.lock

Lines changed: 49 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

apps/beeai-sdk/examples/dependencies.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,22 +10,30 @@
1010
from beeai_sdk.a2a.extensions import LLMServiceExtensionServer, LLMServiceExtensionSpec
1111
from beeai_sdk.a2a.extensions.ui.trajectory import TrajectoryExtensionServer, TrajectoryExtensionSpec
1212
from beeai_sdk.a2a.types import RunYield
13+
from beeai_sdk.platform import File
1314
from beeai_sdk.server import Server
14-
from beeai_sdk.server.context import Context
15+
from beeai_sdk.server.context import RunContext
16+
from src.beeai_sdk.a2a.extensions.services.platform import (
17+
PlatformApiExtensionServer,
18+
PlatformApiExtensionSpec,
19+
)
1520

1621
server = Server()
1722

1823

1924
@server.agent()
2025
async def dependent_agent(
2126
message: Message,
22-
context: Context,
27+
context: RunContext,
2328
trajectory: Annotated[TrajectoryExtensionServer, TrajectoryExtensionSpec()],
2429
# does not typecheck, does not ruff check
2530
llm: Annotated[LLMServiceExtensionServer, LLMServiceExtensionSpec.single_demand()],
31+
_: Annotated[PlatformApiExtensionServer, PlatformApiExtensionSpec()],
2632
) -> AsyncGenerator[RunYield, Message]:
2733
"""Awaits a user message"""
2834

35+
await File.create(filename="my_file.txt", content=b"hello world", content_type="text/plain")
36+
2937
yield trajectory.trajectory_metadata(title="context_param", content=str(context))
3038
yield trajectory.trajectory_metadata(title="message_param", content=str(message.model_dump()))
3139
yield trajectory.message(trajectory_title="llm_param", trajectory_content=str(llm.data))

apps/beeai-sdk/examples/mcp_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,15 @@
1010
from beeai_sdk.a2a.extensions.services.mcp import MCPServiceExtensionServer, MCPServiceExtensionSpec
1111
from beeai_sdk.a2a.types import RunYield
1212
from beeai_sdk.server import Server
13-
from beeai_sdk.server.context import Context
13+
from beeai_sdk.server.context import RunContext
1414

1515
server = Server()
1616

1717

1818
@server.agent()
1919
async def mcp_agent(
2020
message: Message,
21-
context: Context,
21+
context: RunContext,
2222
mcp: Annotated[
2323
MCPServiceExtensionServer,
2424
MCPServiceExtensionSpec.single_demand(),

apps/beeai-sdk/pyproject.toml

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ dependencies = [
2020
"tenacity>=9.1.2",
2121
"janus>=2.0.0",
2222
"uvloop>=0.21.0",
23-
"httpx", # version determined by a2a-sdk
23+
"httpx", # version determined by a2a-sdk
2424
"mcp>=1.12.3",
2525
]
2626

@@ -42,18 +42,18 @@ build-backend = "uv_build"
4242
line-length = 120
4343
target-version = "py311"
4444
lint.select = [
45-
"E", # pycodestyle errors
46-
"W", # pycodestyle warnings
47-
"F", # pyflakes
48-
"UP", # pyupgrade
49-
"I", # isort
50-
"B", # bugbear
51-
"N", # pep8-naming
52-
"C4", # Comprehensions
53-
"Q", # Quotes
54-
"SIM", # Simplify
55-
"RUF", # Ruff
56-
"TID", # tidy-imports
45+
"E", # pycodestyle errors
46+
"W", # pycodestyle warnings
47+
"F", # pyflakes
48+
"UP", # pyupgrade
49+
"I", # isort
50+
"B", # bugbear
51+
"N", # pep8-naming
52+
"C4", # Comprehensions
53+
"Q", # Quotes
54+
"SIM", # Simplify
55+
"RUF", # Ruff
56+
"TID", # tidy-imports
5757
"ASYNC", # async
5858
# TODO: add "DTZ", # DatetimeZ
5959
# TODO: add "ANN", # annotations
@@ -71,3 +71,5 @@ addopts = "-v"
7171

7272
[tool.pyright]
7373
ignore = ["tests/**"]
74+
venvPath = "."
75+
venv = ".venv"

apps/beeai-sdk/src/beeai_sdk/a2a/extensions/base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
if typing.TYPE_CHECKING:
19-
from beeai_sdk.server.context import Context
19+
from beeai_sdk.server.context import RunContext
2020

2121

2222
def _get_generic_args(cls: type, base_class: type) -> tuple[typing.Any, ...]:
@@ -141,14 +141,18 @@ def parse_client_metadata(self, message: a2a.types.Message) -> MetadataFromClien
141141
else pydantic.TypeAdapter(self.MetadataFromClient).validate_python(message.metadata[self.spec.URI])
142142
)
143143

144-
def handle_incoming_message(self, message: a2a.types.Message, context: Context):
144+
def handle_incoming_message(self, message: a2a.types.Message, context: RunContext):
145145
if self._metadata_from_client is None:
146146
self._metadata_from_client = self.parse_client_metadata(message)
147147

148-
def __call__(self, message: a2a.types.Message, context: Context) -> typing.Self:
148+
def __call__(self, message: a2a.types.Message, context: RunContext) -> typing.Self:
149149
self.handle_incoming_message(message, context)
150150
return self
151151

152+
async def initialize(self):
153+
"""Called when entering the agent context after the first message was parsed (__call__ was already called)"""
154+
pass
155+
152156

153157
class BaseExtensionClient(abc.ABC, typing.Generic[ExtensionSpecT, MetadataFromServerT]):
154158
MetadataFromServer: type[MetadataFromServerT]
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from beeai_sdk.a2a.extensions.base import BaseExtensionSpec
5+
6+
7+
class ExtensionError(Exception):
8+
extension: BaseExtensionSpec
9+
10+
def __init__(self, spec: BaseExtensionSpec, message: str):
11+
super().__init__(f"Exception in extension '{spec.URI}': \n{message}")
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from __future__ import annotations
5+
6+
import asyncio
7+
import os
8+
from collections.abc import AsyncIterator
9+
from contextlib import asynccontextmanager
10+
from types import NoneType
11+
12+
import a2a.types
13+
import pydantic
14+
from pydantic.networks import HttpUrl
15+
16+
from beeai_sdk.a2a.extensions.base import (
17+
BaseExtensionClient,
18+
BaseExtensionServer,
19+
BaseExtensionSpec,
20+
)
21+
from beeai_sdk.a2a.extensions.exceptions import ExtensionError
22+
from beeai_sdk.platform import use_platform_client
23+
from beeai_sdk.platform.client import PlatformClient
24+
25+
26+
class PlatformApiExtensionMetadata(pydantic.BaseModel):
27+
base_url: HttpUrl | None = None
28+
auth_token: pydantic.Secret[str]
29+
expires_at: pydantic.AwareDatetime | None = None
30+
31+
32+
class PlatformApiExtension(pydantic.BaseModel):
33+
"""
34+
Request authentication token and url to be able to access the beeai platform API
35+
"""
36+
37+
38+
class PlatformApiExtensionParams(pydantic.BaseModel):
39+
auto_use: bool = True
40+
41+
42+
class PlatformApiExtensionSpec(BaseExtensionSpec[PlatformApiExtensionParams]):
43+
URI: str = "https://a2a-extensions.beeai.dev/services/platform_api/v1"
44+
45+
def __init__(self, params: PlatformApiExtensionParams | None = None) -> None:
46+
super().__init__(params or PlatformApiExtensionParams())
47+
48+
49+
class PlatformApiExtensionServer(BaseExtensionServer[PlatformApiExtensionSpec, PlatformApiExtensionMetadata]):
50+
context_id: str | None = None
51+
_should_exit: asyncio.Event = asyncio.Event()
52+
_wait_for_startup: asyncio.Event = asyncio.Event()
53+
_lifecycle_task: asyncio.Task | None = None
54+
_lifecycle_task_exception: Exception | None = None
55+
_initialized: bool = False
56+
57+
def parse_client_metadata(self, message: a2a.types.Message) -> PlatformApiExtensionMetadata | None:
58+
self.context_id = message.context_id
59+
# we assume that the context id is the same ID as the platform context id
60+
# if different IDs are passed, api requests to platform using this token will fail
61+
return super().parse_client_metadata(message)
62+
63+
async def managed_platform_client_lifespan(self):
64+
try:
65+
async with self.use_client():
66+
self._wait_for_startup.set()
67+
await self._should_exit.wait()
68+
except Exception as e:
69+
self._lifecycle_task_exception = e
70+
raise
71+
finally:
72+
self._wait_for_startup.set()
73+
74+
async def initialize(self):
75+
"""Called when entering the agent context after the first message was parsed (__call__ was already called)"""
76+
if self._initialized:
77+
raise ExtensionError(self.spec, "Platform extension was already initialized")
78+
79+
if self.spec.params.auto_use:
80+
self._lifecycle_task = asyncio.create_task(self.managed_platform_client_lifespan())
81+
self._lifecycle_task.add_done_callback(lambda *_args: None)
82+
await self._wait_for_startup.wait()
83+
if self._lifecycle_task_exception:
84+
raise self._lifecycle_task_exception
85+
self._initialized = True
86+
87+
def __del__(self):
88+
self._should_exit.set()
89+
90+
@asynccontextmanager
91+
async def use_client(self) -> AsyncIterator[PlatformClient]:
92+
if not self.data:
93+
raise ExtensionError(self.spec, "Platform extension metadata was not provided")
94+
base_url = str(self.data.base_url or os.getenv("PLATFORM_URL", "http://127.0.0.1:8333"))
95+
auth_token = self.data.auth_token.get_secret_value()
96+
async with use_platform_client(context_id=self.context_id, base_url=base_url, auth_token=auth_token) as client:
97+
yield client
98+
99+
100+
class PlatformApiExtensionClient(BaseExtensionClient[PlatformApiExtensionSpec, NoneType]):
101+
def api_auth_metadata(
102+
self,
103+
*,
104+
auth_token: pydantic.Secret[str] | str,
105+
expires_at: pydantic.AwareDatetime | None = None,
106+
base_url: HttpUrl | None = None,
107+
) -> dict[str, dict[str, str]]:
108+
return {
109+
self.spec.URI: {
110+
**PlatformApiExtensionMetadata(
111+
base_url=base_url,
112+
auth_token=pydantic.Secret("replaced below"),
113+
expires_at=expires_at,
114+
).model_dump(mode="json"),
115+
"auth_token": auth_token if isinstance(auth_token, str) else auth_token.get_secret_value(),
116+
}
117+
}

0 commit comments

Comments
 (0)