Skip to content

Commit bc15097

Browse files
authored
feat(server): add optional provider_id reference to contexts (#1370)
Signed-off-by: Radek Ježek <[email protected]>
1 parent 33dd4d7 commit bc15097

File tree

10 files changed

+160
-7
lines changed

10 files changed

+160
-7
lines changed

apps/beeai-cli/src/beeai_cli/commands/agent.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,11 @@ async def run_agent(
791791
await ensure_llm_provider()
792792
provider = select_provider(search_path, providers=providers)
793793

794-
context = await Context.create(metadata={"provider_id": provider.id, "agent_name": provider.agent_card.name})
794+
context = await Context.create(
795+
provider_id=provider.id,
796+
# TODO: remove metadata after UI migration
797+
metadata={"provider_id": provider.id, "agent_name": provider.agent_card.name},
798+
)
795799
context_token = await context.generate_token(
796800
grant_global_permissions=Permissions(llm={"*"}, embeddings={"*"}, a2a_proxy={"*"}),
797801
grant_context_permissions=ContextPermissions(files={"*"}, vector_stores={"*"}, context_data={"*"}),

apps/beeai-sdk-py/src/beeai_sdk/platform/context.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,26 @@ class Context(pydantic.BaseModel):
6363
updated_at: pydantic.AwareDatetime
6464
last_active_at: pydantic.AwareDatetime
6565
created_by: str
66+
provider_id: str | None = None
6667
metadata: Metadata | None = None
6768

6869
@staticmethod
6970
async def create(
7071
*,
7172
metadata: Metadata | None = None,
73+
provider_id: str | None = None,
7274
client: PlatformClient | None = None,
7375
) -> Context:
7476
async with client or get_platform_client() as client:
7577
return pydantic.TypeAdapter(Context).validate_python(
76-
(await client.post(url="/api/v1/contexts", json={"metadata": metadata})).raise_for_status().json()
78+
(
79+
await client.post(
80+
url="/api/v1/contexts",
81+
json=filter_dict({"metadata": metadata, "provider_id": provider_id}),
82+
)
83+
)
84+
.raise_for_status()
85+
.json()
7786
)
7887

7988
@staticmethod
@@ -85,6 +94,7 @@ async def list(
8594
order: Literal["asc"] | Literal["desc"] | None = None,
8695
order_by: Literal["created_at"] | Literal["updated_at"] | None = None,
8796
include_empty: bool = True,
97+
provider_id: str | None = None,
8898
) -> PaginatedResult[Context]:
8999
# `self` has a weird type so that you can call both `instance.get()` to update an instance, or `File.get("123")` to obtain a new instance
90100
async with client or get_platform_client() as client:
@@ -99,6 +109,7 @@ async def list(
99109
"order": order,
100110
"order_by": order_by,
101111
"include_empty": include_empty,
112+
"provider_id": provider_id,
102113
}
103114
),
104115
)

apps/beeai-server/src/beeai_server/api/routes/contexts.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,9 @@ async def create_context(
4040
context_service: ContextServiceDependency,
4141
user: Annotated[AuthorizedUser, Depends(RequiresPermissions(contexts={"write"}))],
4242
) -> EntityModel[Context]:
43-
return EntityModel(await context_service.create(user=user.user, metadata=request.metadata))
43+
return EntityModel(
44+
await context_service.create(user=user.user, metadata=request.metadata, provider_id=request.provider_id)
45+
)
4446

4547

4648
@router.get("")
@@ -49,7 +51,9 @@ async def list_context(
4951
user: Annotated[AuthorizedUser, Depends(RequiresPermissions(contexts={"read"}))],
5052
query: Annotated[ContextListQuery, Query()],
5153
) -> PaginatedResult[Context]:
52-
return await context_service.list(user=user.user, pagination=query, include_empty=query.include_empty)
54+
return await context_service.list(
55+
user=user.user, pagination=query, include_empty=query.include_empty, provider_id=query.provider_id
56+
)
5357

5458

5559
@router.get("/{context_id}")

apps/beeai-server/src/beeai_server/api/schema/contexts.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
22
# SPDX-License-Identifier: Apache-2.0
33
from typing import Literal
4+
from uuid import UUID
45

56
from pydantic import AwareDatetime, BaseModel, Field, RootModel
67

@@ -14,6 +15,7 @@ class ContextCreateRequest(BaseModel):
1415
"""Request schema for context creation."""
1516

1617
metadata: Metadata | None = None
18+
provider_id: UUID | None = None
1719

1820

1921
class ContextUpdateRequest(BaseModel):
@@ -30,6 +32,7 @@ class ContextPatchMetadataRequest(BaseModel):
3032

3133
class ContextListQuery(PaginationQuery):
3234
include_empty: bool = True
35+
provider_id: UUID | None = None
3336

3437

3538
class ContextPermissionsGrant(BaseModel):

apps/beeai-server/src/beeai_server/domain/models/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class Context(BaseModel):
3333
updated_at: AwareDatetime = Field(default_factory=utc_now)
3434
last_active_at: AwareDatetime = Field(default_factory=utc_now)
3535
created_by: UUID
36+
provider_id: UUID | None = None
3637
metadata: Metadata | None = None
3738

3839

apps/beeai-server/src/beeai_server/domain/repositories/context.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ async def list_paginated(
2020
self,
2121
*,
2222
user_id: UUID | None = None,
23+
provider_id: UUID | None = None,
2324
limit: int = 20,
2425
page_token: UUID | None = None,
2526
order: str = "desc",
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""add optional provider_id relationship to context
5+
6+
Revision ID: 28725d931ca5
7+
Revises: c69e8ad9a82a
8+
Create Date: 2025-10-22 13:40:10.921827
9+
10+
"""
11+
12+
from collections.abc import Sequence
13+
14+
import sqlalchemy as sa
15+
from alembic import op
16+
17+
# revision identifiers, used by Alembic.
18+
revision: str = "28725d931ca5"
19+
down_revision: str | None = "c69e8ad9a82a"
20+
branch_labels: str | Sequence[str] | None = None
21+
depends_on: str | Sequence[str] | None = None
22+
23+
24+
def upgrade() -> None:
25+
"""Upgrade schema."""
26+
# ### commands auto generated by Alembic - please adjust! ###
27+
op.add_column("contexts", sa.Column("provider_id", sa.UUID(), nullable=True))
28+
op.create_foreign_key(None, "contexts", "providers", ["provider_id"], ["id"], ondelete="CASCADE")
29+
# ### end Alembic commands ###
30+
31+
32+
def downgrade() -> None:
33+
"""Downgrade schema."""
34+
# ### commands auto generated by Alembic - please adjust! ###
35+
op.drop_constraint("contexts_provider_id_fkey", "contexts", type_="foreignkey")
36+
op.drop_column("contexts", "provider_id")
37+
# ### end Alembic commands ###

apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/context.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
Column("updated_at", DateTime(timezone=True), nullable=False),
3939
Column("last_active_at", DateTime(timezone=True), nullable=True),
4040
Column("created_by", ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
41+
Column("provider_id", ForeignKey("providers.id", ondelete="CASCADE"), nullable=True),
4142
Column("metadata", JSON, nullable=True),
4243
)
4344

@@ -74,6 +75,7 @@ async def list_paginated(
7475
self,
7576
*,
7677
user_id: UUID | None = None,
78+
provider_id: UUID | None = None,
7779
limit: int = 20,
7880
page_token: UUID | None = None,
7981
order: str = "desc",
@@ -83,6 +85,8 @@ async def list_paginated(
8385
query = contexts_table.select()
8486
if user_id is not None:
8587
query = query.where(contexts_table.c.created_by == user_id)
88+
if provider_id is not None:
89+
query = query.where(contexts_table.c.provider_id == provider_id)
8690
if not include_empty:
8791
# Use EXISTS subquery to find contexts that have at least one history record
8892
subquery = select(context_history_table.c.context_id).where(
@@ -113,6 +117,7 @@ async def create(self, *, context: Context) -> None:
113117
updated_at=context.updated_at,
114118
last_active_at=context.last_active_at,
115119
created_by=context.created_by,
120+
provider_id=context.provider_id,
116121
metadata=context.metadata,
117122
)
118123
await self._connection.execute(query)
@@ -219,5 +224,6 @@ def _row_to_context(self, row: Row) -> Context:
219224
updated_at=row.updated_at,
220225
last_active_at=row.last_active_at,
221226
created_by=row.created_by,
227+
provider_id=row.provider_id,
222228
metadata=row.metadata,
223229
)

apps/beeai-server/src/beeai_server/service_layer/services/contexts.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ def __init__(self, uow: IUnitOfWorkFactory, configuration: Configuration, object
3333
self._configuration = configuration
3434
self._expire_resources_after = timedelta(days=configuration.context.resource_expire_after_days)
3535

36-
async def create(self, *, user: User, metadata: Metadata) -> Context:
37-
context = Context(created_by=user.id, metadata=metadata)
36+
async def create(self, *, user: User, metadata: Metadata, provider_id: UUID | None = None) -> Context:
37+
context = Context(created_by=user.id, metadata=metadata, provider_id=provider_id)
3838
async with self._uow() as uow:
3939
await uow.contexts.create(context=context)
4040
await uow.commit()
@@ -45,11 +45,12 @@ async def get(self, *, context_id: UUID, user: User) -> Context:
4545
return await uow.contexts.get(context_id=context_id, user_id=user.id)
4646

4747
async def list(
48-
self, *, user: User, pagination: PaginationQuery, include_empty: bool = True
48+
self, *, user: User, pagination: PaginationQuery, include_empty: bool = True, provider_id: UUID | None = None
4949
) -> PaginatedResult[Context]:
5050
async with self._uow() as uow:
5151
return await uow.contexts.list_paginated(
5252
user_id=user.id,
53+
provider_id=provider_id,
5354
limit=pagination.limit,
5455
page_token=pagination.page_token,
5556
order=pagination.order,

apps/beeai-server/tests/e2e/routes/test_contexts.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,88 @@ async def test_context_update_and_patch(subtests):
184184

185185
with subtests.test("exceed metadata size"), pytest.raises(HTTPStatusError):
186186
await context.patch_metadata(metadata={str(i): str(i) for i in range(15)})
187+
188+
189+
@pytest.mark.usefixtures("clean_up", "setup_platform_client")
190+
async def test_context_provider_filtering(subtests):
191+
"""Test creating contexts with provider_id and filtering by provider_id."""
192+
from a2a.types import AgentCapabilities, AgentCard
193+
from beeai_sdk.platform import Provider
194+
195+
provider1 = None
196+
provider2 = None
197+
context_with_provider1 = None
198+
context_with_provider2 = None
199+
context_without_provider = None
200+
201+
with subtests.test("create dummy providers"):
202+
# Create first dummy provider with network URL
203+
agent_card1 = AgentCard(
204+
name="Test Provider 1",
205+
description="First test provider",
206+
url="http://localhost:9001/",
207+
version="1.0.0",
208+
default_input_modes=["text"],
209+
default_output_modes=["text"],
210+
capabilities=AgentCapabilities(),
211+
skills=[],
212+
)
213+
provider1 = await Provider.create(location="http://localhost:9001", agent_card=agent_card1)
214+
assert provider1.id is not None
215+
216+
# Create second dummy provider with network URL
217+
agent_card2 = AgentCard(
218+
name="Test Provider 2",
219+
description="Second test provider",
220+
url="http://localhost:9002/",
221+
version="1.0.0",
222+
default_input_modes=["text"],
223+
default_output_modes=["text"],
224+
capabilities=AgentCapabilities(),
225+
skills=[],
226+
)
227+
provider2 = await Provider.create(location="http://localhost:9002", agent_card=agent_card2)
228+
assert provider2.id is not None
229+
assert provider1.id != provider2.id
230+
231+
with subtests.test("create contexts with and without provider_id"):
232+
# Create context associated with provider1
233+
context_with_provider1 = await Context.create(metadata={"name": "context_provider1"}, provider_id=provider1.id)
234+
assert context_with_provider1.provider_id == provider1.id
235+
236+
# Create context associated with provider2
237+
context_with_provider2 = await Context.create(metadata={"name": "context_provider2"}, provider_id=provider2.id)
238+
assert context_with_provider2.provider_id == provider2.id
239+
240+
# Create context without provider
241+
context_without_provider = await Context.create(metadata={"name": "context_no_provider"})
242+
assert context_without_provider.provider_id is None
243+
244+
with subtests.test("list all contexts without filter"):
245+
all_contexts = await Context.list()
246+
assert len(all_contexts.items) == 3
247+
context_ids = [ctx.id for ctx in all_contexts.items]
248+
assert context_with_provider1.id in context_ids
249+
assert context_with_provider2.id in context_ids
250+
assert context_without_provider.id in context_ids
251+
252+
with subtests.test("filter contexts by provider1"):
253+
provider1_contexts = await Context.list(provider_id=provider1.id)
254+
assert len(provider1_contexts.items) == 1
255+
assert provider1_contexts.items[0].id == context_with_provider1.id
256+
assert provider1_contexts.items[0].provider_id == provider1.id
257+
258+
with subtests.test("filter contexts by provider2"):
259+
provider2_contexts = await Context.list(provider_id=provider2.id)
260+
assert len(provider2_contexts.items) == 1
261+
assert provider2_contexts.items[0].id == context_with_provider2.id
262+
assert provider2_contexts.items[0].provider_id == provider2.id
263+
264+
with subtests.test("filter by non-existent provider returns empty list"):
265+
nonexistent_provider_id = str(uuid.uuid4())
266+
no_contexts = await Context.list(provider_id=nonexistent_provider_id)
267+
assert len(no_contexts.items) == 0
268+
269+
with subtests.test("get context includes provider_id"):
270+
fetched_context = await Context.get(context_with_provider1.id)
271+
assert fetched_context.provider_id == provider1.id

0 commit comments

Comments
 (0)