Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion apps/beeai-cli/src/beeai_cli/commands/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,11 @@ async def run_agent(
await ensure_llm_provider()
provider = select_provider(search_path, providers=providers)

context = await Context.create(metadata={"provider_id": provider.id, "agent_name": provider.agent_card.name})
context = await Context.create(
provider_id=provider.id,
# TODO: remove metadata after UI migration
metadata={"provider_id": provider.id, "agent_name": provider.agent_card.name},
)
context_token = await context.generate_token(
grant_global_permissions=Permissions(llm={"*"}, embeddings={"*"}, a2a_proxy={"*"}),
grant_context_permissions=ContextPermissions(files={"*"}, vector_stores={"*"}, context_data={"*"}),
Expand Down
13 changes: 12 additions & 1 deletion apps/beeai-sdk-py/src/beeai_sdk/platform/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,26 @@ class Context(pydantic.BaseModel):
updated_at: pydantic.AwareDatetime
last_active_at: pydantic.AwareDatetime
created_by: str
provider_id: str | None = None
metadata: Metadata | None = None

@staticmethod
async def create(
*,
metadata: Metadata | None = None,
provider_id: str | None = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Consider using Optional[str] instead of str | None for better readability and consistency with other type hints in this file.

Suggested change
provider_id: str | None = None,
provider_id: typing.Optional[str] = None,

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tell me, where else in this file is Optional type hint?

client: PlatformClient | None = None,
) -> Context:
async with client or get_platform_client() as client:
return pydantic.TypeAdapter(Context).validate_python(
(await client.post(url="/api/v1/contexts", json={"metadata": metadata})).raise_for_status().json()
(
await client.post(
url="/api/v1/contexts",
json=filter_dict({"metadata": metadata, "provider_id": provider_id}),
)
)
.raise_for_status()
.json()
)

@staticmethod
Expand All @@ -85,6 +94,7 @@ async def list(
order: Literal["asc"] | Literal["desc"] | None = None,
order_by: Literal["created_at"] | Literal["updated_at"] | None = None,
include_empty: bool = True,
provider_id: str | None = None,
) -> PaginatedResult[Context]:
# `self` has a weird type so that you can call both `instance.get()` to update an instance, or `File.get("123")` to obtain a new instance
async with client or get_platform_client() as client:
Expand All @@ -99,6 +109,7 @@ async def list(
"order": order,
"order_by": order_by,
"include_empty": include_empty,
"provider_id": provider_id,
}
),
)
Expand Down
8 changes: 6 additions & 2 deletions apps/beeai-server/src/beeai_server/api/routes/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ async def create_context(
context_service: ContextServiceDependency,
user: Annotated[AuthorizedUser, Depends(RequiresPermissions(contexts={"write"}))],
) -> EntityModel[Context]:
return EntityModel(await context_service.create(user=user.user, metadata=request.metadata))
return EntityModel(
await context_service.create(user=user.user, metadata=request.metadata, provider_id=request.provider_id)
)


@router.get("")
Expand All @@ -49,7 +51,9 @@ async def list_context(
user: Annotated[AuthorizedUser, Depends(RequiresPermissions(contexts={"read"}))],
query: Annotated[ContextListQuery, Query()],
) -> PaginatedResult[Context]:
return await context_service.list(user=user.user, pagination=query, include_empty=query.include_empty)
return await context_service.list(
user=user.user, pagination=query, include_empty=query.include_empty, provider_id=query.provider_id
)


@router.get("/{context_id}")
Expand Down
3 changes: 3 additions & 0 deletions apps/beeai-server/src/beeai_server/api/schema/contexts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0
from typing import Literal
from uuid import UUID

from pydantic import AwareDatetime, BaseModel, Field, RootModel

Expand All @@ -14,6 +15,7 @@ class ContextCreateRequest(BaseModel):
"""Request schema for context creation."""

metadata: Metadata | None = None
provider_id: UUID | None = None


class ContextUpdateRequest(BaseModel):
Expand All @@ -30,6 +32,7 @@ class ContextPatchMetadataRequest(BaseModel):

class ContextListQuery(PaginationQuery):
include_empty: bool = True
provider_id: UUID | None = None


class ContextPermissionsGrant(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class Context(BaseModel):
updated_at: AwareDatetime = Field(default_factory=utc_now)
last_active_at: AwareDatetime = Field(default_factory=utc_now)
created_by: UUID
provider_id: UUID | None = None
metadata: Metadata | None = None


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ async def list_paginated(
self,
*,
user_id: UUID | None = None,
provider_id: UUID | None = None,
limit: int = 20,
page_token: UUID | None = None,
order: str = "desc",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
# SPDX-License-Identifier: Apache-2.0

"""add optional provider_id relationship to context

Revision ID: 28725d931ca5
Revises: c69e8ad9a82a
Create Date: 2025-10-22 13:40:10.921827

"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "28725d931ca5"
down_revision: str | None = "c69e8ad9a82a"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("contexts", sa.Column("provider_id", sa.UUID(), nullable=True))
op.create_foreign_key(None, "contexts", "providers", ["provider_id"], ["id"], ondelete="CASCADE")
# ### end Alembic commands ###


def downgrade() -> None:
"""Downgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint("contexts_provider_id_fkey", "contexts", type_="foreignkey")
op.drop_column("contexts", "provider_id")
# ### end Alembic commands ###
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
Column("updated_at", DateTime(timezone=True), nullable=False),
Column("last_active_at", DateTime(timezone=True), nullable=True),
Column("created_by", ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
Column("provider_id", ForeignKey("providers.id", ondelete="CASCADE"), nullable=True),
Column("metadata", JSON, nullable=True),
)

Expand Down Expand Up @@ -74,6 +75,7 @@ async def list_paginated(
self,
*,
user_id: UUID | None = None,
provider_id: UUID | None = None,
limit: int = 20,
page_token: UUID | None = None,
order: str = "desc",
Expand All @@ -83,6 +85,8 @@ async def list_paginated(
query = contexts_table.select()
if user_id is not None:
query = query.where(contexts_table.c.created_by == user_id)
if provider_id is not None:
query = query.where(contexts_table.c.provider_id == provider_id)
if not include_empty:
# Use EXISTS subquery to find contexts that have at least one history record
subquery = select(context_history_table.c.context_id).where(
Expand Down Expand Up @@ -113,6 +117,7 @@ async def create(self, *, context: Context) -> None:
updated_at=context.updated_at,
last_active_at=context.last_active_at,
created_by=context.created_by,
provider_id=context.provider_id,
metadata=context.metadata,
)
await self._connection.execute(query)
Expand Down Expand Up @@ -219,5 +224,6 @@ def _row_to_context(self, row: Row) -> Context:
updated_at=row.updated_at,
last_active_at=row.last_active_at,
created_by=row.created_by,
provider_id=row.provider_id,
metadata=row.metadata,
)
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(self, uow: IUnitOfWorkFactory, configuration: Configuration, object
self._configuration = configuration
self._expire_resources_after = timedelta(days=configuration.context.resource_expire_after_days)

async def create(self, *, user: User, metadata: Metadata) -> Context:
context = Context(created_by=user.id, metadata=metadata)
async def create(self, *, user: User, metadata: Metadata, provider_id: UUID | None = None) -> Context:
context = Context(created_by=user.id, metadata=metadata, provider_id=provider_id)
async with self._uow() as uow:
await uow.contexts.create(context=context)
await uow.commit()
Expand All @@ -45,11 +45,12 @@ async def get(self, *, context_id: UUID, user: User) -> Context:
return await uow.contexts.get(context_id=context_id, user_id=user.id)

async def list(
self, *, user: User, pagination: PaginationQuery, include_empty: bool = True
self, *, user: User, pagination: PaginationQuery, include_empty: bool = True, provider_id: UUID | None = None
) -> PaginatedResult[Context]:
async with self._uow() as uow:
return await uow.contexts.list_paginated(
user_id=user.id,
provider_id=provider_id,
limit=pagination.limit,
page_token=pagination.page_token,
order=pagination.order,
Expand Down
85 changes: 85 additions & 0 deletions apps/beeai-server/tests/e2e/routes/test_contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,88 @@ async def test_context_update_and_patch(subtests):

with subtests.test("exceed metadata size"), pytest.raises(HTTPStatusError):
await context.patch_metadata(metadata={str(i): str(i) for i in range(15)})


@pytest.mark.usefixtures("clean_up", "setup_platform_client")
async def test_context_provider_filtering(subtests):
"""Test creating contexts with provider_id and filtering by provider_id."""
from a2a.types import AgentCapabilities, AgentCard
from beeai_sdk.platform import Provider

provider1 = None
provider2 = None
context_with_provider1 = None
context_with_provider2 = None
context_without_provider = None

with subtests.test("create dummy providers"):
# Create first dummy provider with network URL
agent_card1 = AgentCard(
name="Test Provider 1",
description="First test provider",
url="http://localhost:9001/",
version="1.0.0",
default_input_modes=["text"],
default_output_modes=["text"],
capabilities=AgentCapabilities(),
skills=[],
)
provider1 = await Provider.create(location="http://localhost:9001", agent_card=agent_card1)
assert provider1.id is not None

# Create second dummy provider with network URL
agent_card2 = AgentCard(
name="Test Provider 2",
description="Second test provider",
url="http://localhost:9002/",
version="1.0.0",
default_input_modes=["text"],
default_output_modes=["text"],
capabilities=AgentCapabilities(),
skills=[],
)
provider2 = await Provider.create(location="http://localhost:9002", agent_card=agent_card2)
assert provider2.id is not None
assert provider1.id != provider2.id

with subtests.test("create contexts with and without provider_id"):
# Create context associated with provider1
context_with_provider1 = await Context.create(metadata={"name": "context_provider1"}, provider_id=provider1.id)
assert context_with_provider1.provider_id == provider1.id

# Create context associated with provider2
context_with_provider2 = await Context.create(metadata={"name": "context_provider2"}, provider_id=provider2.id)
assert context_with_provider2.provider_id == provider2.id

# Create context without provider
context_without_provider = await Context.create(metadata={"name": "context_no_provider"})
assert context_without_provider.provider_id is None

with subtests.test("list all contexts without filter"):
all_contexts = await Context.list()
assert len(all_contexts.items) == 3
context_ids = [ctx.id for ctx in all_contexts.items]
assert context_with_provider1.id in context_ids
assert context_with_provider2.id in context_ids
assert context_without_provider.id in context_ids

with subtests.test("filter contexts by provider1"):
provider1_contexts = await Context.list(provider_id=provider1.id)
assert len(provider1_contexts.items) == 1
assert provider1_contexts.items[0].id == context_with_provider1.id
assert provider1_contexts.items[0].provider_id == provider1.id

with subtests.test("filter contexts by provider2"):
provider2_contexts = await Context.list(provider_id=provider2.id)
assert len(provider2_contexts.items) == 1
assert provider2_contexts.items[0].id == context_with_provider2.id
assert provider2_contexts.items[0].provider_id == provider2.id

with subtests.test("filter by non-existent provider returns empty list"):
nonexistent_provider_id = str(uuid.uuid4())
no_contexts = await Context.list(provider_id=nonexistent_provider_id)
assert len(no_contexts.items) == 0

with subtests.test("get context includes provider_id"):
fetched_context = await Context.get(context_with_provider1.id)
assert fetched_context.provider_id == provider1.id