diff --git a/apps/beeai-cli/src/beeai_cli/commands/agent.py b/apps/beeai-cli/src/beeai_cli/commands/agent.py index 1b2439843..844fc7a99 100644 --- a/apps/beeai-cli/src/beeai_cli/commands/agent.py +++ b/apps/beeai-cli/src/beeai_cli/commands/agent.py @@ -791,7 +791,11 @@ async def run_agent( await ensure_llm_provider() provider = select_provider(search_path, providers=providers) - context = await Context.create(metadata={"provider_id": provider.id, "agent_name": provider.agent_card.name}) + context = await Context.create( + provider_id=provider.id, + # TODO: remove metadata after UI migration + metadata={"provider_id": provider.id, "agent_name": provider.agent_card.name}, + ) context_token = await context.generate_token( grant_global_permissions=Permissions(llm={"*"}, embeddings={"*"}, a2a_proxy={"*"}), grant_context_permissions=ContextPermissions(files={"*"}, vector_stores={"*"}, context_data={"*"}), diff --git a/apps/beeai-sdk-py/src/beeai_sdk/platform/context.py b/apps/beeai-sdk-py/src/beeai_sdk/platform/context.py index e80915b1d..d64b75bf6 100644 --- a/apps/beeai-sdk-py/src/beeai_sdk/platform/context.py +++ b/apps/beeai-sdk-py/src/beeai_sdk/platform/context.py @@ -63,17 +63,26 @@ class Context(pydantic.BaseModel): updated_at: pydantic.AwareDatetime last_active_at: pydantic.AwareDatetime created_by: str + provider_id: str | None = None metadata: Metadata | None = None @staticmethod async def create( *, metadata: Metadata | None = None, + provider_id: str | None = None, client: PlatformClient | None = None, ) -> Context: async with client or get_platform_client() as client: return pydantic.TypeAdapter(Context).validate_python( - (await client.post(url="/api/v1/contexts", json={"metadata": metadata})).raise_for_status().json() + ( + await client.post( + url="/api/v1/contexts", + json=filter_dict({"metadata": metadata, "provider_id": provider_id}), + ) + ) + .raise_for_status() + .json() ) @staticmethod @@ -85,6 +94,7 @@ async def list( order: Literal["asc"] | Literal["desc"] | None = None, order_by: Literal["created_at"] | Literal["updated_at"] | None = None, include_empty: bool = True, + provider_id: str | None = None, ) -> PaginatedResult[Context]: # `self` has a weird type so that you can call both `instance.get()` to update an instance, or `File.get("123")` to obtain a new instance async with client or get_platform_client() as client: @@ -99,6 +109,7 @@ async def list( "order": order, "order_by": order_by, "include_empty": include_empty, + "provider_id": provider_id, } ), ) diff --git a/apps/beeai-server/src/beeai_server/api/routes/contexts.py b/apps/beeai-server/src/beeai_server/api/routes/contexts.py index 4ae63b4c7..126e4c0c4 100644 --- a/apps/beeai-server/src/beeai_server/api/routes/contexts.py +++ b/apps/beeai-server/src/beeai_server/api/routes/contexts.py @@ -40,7 +40,9 @@ async def create_context( context_service: ContextServiceDependency, user: Annotated[AuthorizedUser, Depends(RequiresPermissions(contexts={"write"}))], ) -> EntityModel[Context]: - return EntityModel(await context_service.create(user=user.user, metadata=request.metadata)) + return EntityModel( + await context_service.create(user=user.user, metadata=request.metadata, provider_id=request.provider_id) + ) @router.get("") @@ -49,7 +51,9 @@ async def list_context( user: Annotated[AuthorizedUser, Depends(RequiresPermissions(contexts={"read"}))], query: Annotated[ContextListQuery, Query()], ) -> PaginatedResult[Context]: - return await context_service.list(user=user.user, pagination=query, include_empty=query.include_empty) + return await context_service.list( + user=user.user, pagination=query, include_empty=query.include_empty, provider_id=query.provider_id + ) @router.get("/{context_id}") diff --git a/apps/beeai-server/src/beeai_server/api/schema/contexts.py b/apps/beeai-server/src/beeai_server/api/schema/contexts.py index fb77c3d0c..53a66ee15 100644 --- a/apps/beeai-server/src/beeai_server/api/schema/contexts.py +++ b/apps/beeai-server/src/beeai_server/api/schema/contexts.py @@ -1,6 +1,7 @@ # Copyright 2025 © BeeAI a Series of LF Projects, LLC # SPDX-License-Identifier: Apache-2.0 from typing import Literal +from uuid import UUID from pydantic import AwareDatetime, BaseModel, Field, RootModel @@ -14,6 +15,7 @@ class ContextCreateRequest(BaseModel): """Request schema for context creation.""" metadata: Metadata | None = None + provider_id: UUID | None = None class ContextUpdateRequest(BaseModel): @@ -30,6 +32,7 @@ class ContextPatchMetadataRequest(BaseModel): class ContextListQuery(PaginationQuery): include_empty: bool = True + provider_id: UUID | None = None class ContextPermissionsGrant(BaseModel): diff --git a/apps/beeai-server/src/beeai_server/domain/models/context.py b/apps/beeai-server/src/beeai_server/domain/models/context.py index 8d2a65361..3440023bf 100644 --- a/apps/beeai-server/src/beeai_server/domain/models/context.py +++ b/apps/beeai-server/src/beeai_server/domain/models/context.py @@ -33,6 +33,7 @@ class Context(BaseModel): updated_at: AwareDatetime = Field(default_factory=utc_now) last_active_at: AwareDatetime = Field(default_factory=utc_now) created_by: UUID + provider_id: UUID | None = None metadata: Metadata | None = None diff --git a/apps/beeai-server/src/beeai_server/domain/repositories/context.py b/apps/beeai-server/src/beeai_server/domain/repositories/context.py index 53a4415f5..3ce845b7b 100644 --- a/apps/beeai-server/src/beeai_server/domain/repositories/context.py +++ b/apps/beeai-server/src/beeai_server/domain/repositories/context.py @@ -20,6 +20,7 @@ async def list_paginated( self, *, user_id: UUID | None = None, + provider_id: UUID | None = None, limit: int = 20, page_token: UUID | None = None, order: str = "desc", diff --git a/apps/beeai-server/src/beeai_server/infrastructure/persistence/migrations/alembic/versions/28725d931ca5_.py b/apps/beeai-server/src/beeai_server/infrastructure/persistence/migrations/alembic/versions/28725d931ca5_.py new file mode 100644 index 000000000..fbe11c567 --- /dev/null +++ b/apps/beeai-server/src/beeai_server/infrastructure/persistence/migrations/alembic/versions/28725d931ca5_.py @@ -0,0 +1,37 @@ +# Copyright 2025 © BeeAI a Series of LF Projects, LLC +# SPDX-License-Identifier: Apache-2.0 + +"""add optional provider_id relationship to context + +Revision ID: 28725d931ca5 +Revises: c69e8ad9a82a +Create Date: 2025-10-22 13:40:10.921827 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "28725d931ca5" +down_revision: str | None = "c69e8ad9a82a" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("contexts", sa.Column("provider_id", sa.UUID(), nullable=True)) + op.create_foreign_key(None, "contexts", "providers", ["provider_id"], ["id"], ondelete="CASCADE") + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade schema.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("contexts_provider_id_fkey", "contexts", type_="foreignkey") + op.drop_column("contexts", "provider_id") + # ### end Alembic commands ### diff --git a/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/context.py b/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/context.py index 10ec77af9..c6390e8f5 100644 --- a/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/context.py +++ b/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/context.py @@ -38,6 +38,7 @@ Column("updated_at", DateTime(timezone=True), nullable=False), Column("last_active_at", DateTime(timezone=True), nullable=True), Column("created_by", ForeignKey("users.id", ondelete="CASCADE"), nullable=False), + Column("provider_id", ForeignKey("providers.id", ondelete="CASCADE"), nullable=True), Column("metadata", JSON, nullable=True), ) @@ -74,6 +75,7 @@ async def list_paginated( self, *, user_id: UUID | None = None, + provider_id: UUID | None = None, limit: int = 20, page_token: UUID | None = None, order: str = "desc", @@ -83,6 +85,8 @@ async def list_paginated( query = contexts_table.select() if user_id is not None: query = query.where(contexts_table.c.created_by == user_id) + if provider_id is not None: + query = query.where(contexts_table.c.provider_id == provider_id) if not include_empty: # Use EXISTS subquery to find contexts that have at least one history record subquery = select(context_history_table.c.context_id).where( @@ -113,6 +117,7 @@ async def create(self, *, context: Context) -> None: updated_at=context.updated_at, last_active_at=context.last_active_at, created_by=context.created_by, + provider_id=context.provider_id, metadata=context.metadata, ) await self._connection.execute(query) @@ -219,5 +224,6 @@ def _row_to_context(self, row: Row) -> Context: updated_at=row.updated_at, last_active_at=row.last_active_at, created_by=row.created_by, + provider_id=row.provider_id, metadata=row.metadata, ) diff --git a/apps/beeai-server/src/beeai_server/service_layer/services/contexts.py b/apps/beeai-server/src/beeai_server/service_layer/services/contexts.py index 1d1c833b0..85dc97bfc 100644 --- a/apps/beeai-server/src/beeai_server/service_layer/services/contexts.py +++ b/apps/beeai-server/src/beeai_server/service_layer/services/contexts.py @@ -33,8 +33,8 @@ def __init__(self, uow: IUnitOfWorkFactory, configuration: Configuration, object self._configuration = configuration self._expire_resources_after = timedelta(days=configuration.context.resource_expire_after_days) - async def create(self, *, user: User, metadata: Metadata) -> Context: - context = Context(created_by=user.id, metadata=metadata) + async def create(self, *, user: User, metadata: Metadata, provider_id: UUID | None = None) -> Context: + context = Context(created_by=user.id, metadata=metadata, provider_id=provider_id) async with self._uow() as uow: await uow.contexts.create(context=context) await uow.commit() @@ -45,11 +45,12 @@ async def get(self, *, context_id: UUID, user: User) -> Context: return await uow.contexts.get(context_id=context_id, user_id=user.id) async def list( - self, *, user: User, pagination: PaginationQuery, include_empty: bool = True + self, *, user: User, pagination: PaginationQuery, include_empty: bool = True, provider_id: UUID | None = None ) -> PaginatedResult[Context]: async with self._uow() as uow: return await uow.contexts.list_paginated( user_id=user.id, + provider_id=provider_id, limit=pagination.limit, page_token=pagination.page_token, order=pagination.order, diff --git a/apps/beeai-server/tests/e2e/routes/test_contexts.py b/apps/beeai-server/tests/e2e/routes/test_contexts.py index 7ec0596e2..18c74fb3b 100644 --- a/apps/beeai-server/tests/e2e/routes/test_contexts.py +++ b/apps/beeai-server/tests/e2e/routes/test_contexts.py @@ -184,3 +184,88 @@ async def test_context_update_and_patch(subtests): with subtests.test("exceed metadata size"), pytest.raises(HTTPStatusError): await context.patch_metadata(metadata={str(i): str(i) for i in range(15)}) + + +@pytest.mark.usefixtures("clean_up", "setup_platform_client") +async def test_context_provider_filtering(subtests): + """Test creating contexts with provider_id and filtering by provider_id.""" + from a2a.types import AgentCapabilities, AgentCard + from beeai_sdk.platform import Provider + + provider1 = None + provider2 = None + context_with_provider1 = None + context_with_provider2 = None + context_without_provider = None + + with subtests.test("create dummy providers"): + # Create first dummy provider with network URL + agent_card1 = AgentCard( + name="Test Provider 1", + description="First test provider", + url="http://localhost:9001/", + version="1.0.0", + default_input_modes=["text"], + default_output_modes=["text"], + capabilities=AgentCapabilities(), + skills=[], + ) + provider1 = await Provider.create(location="http://localhost:9001", agent_card=agent_card1) + assert provider1.id is not None + + # Create second dummy provider with network URL + agent_card2 = AgentCard( + name="Test Provider 2", + description="Second test provider", + url="http://localhost:9002/", + version="1.0.0", + default_input_modes=["text"], + default_output_modes=["text"], + capabilities=AgentCapabilities(), + skills=[], + ) + provider2 = await Provider.create(location="http://localhost:9002", agent_card=agent_card2) + assert provider2.id is not None + assert provider1.id != provider2.id + + with subtests.test("create contexts with and without provider_id"): + # Create context associated with provider1 + context_with_provider1 = await Context.create(metadata={"name": "context_provider1"}, provider_id=provider1.id) + assert context_with_provider1.provider_id == provider1.id + + # Create context associated with provider2 + context_with_provider2 = await Context.create(metadata={"name": "context_provider2"}, provider_id=provider2.id) + assert context_with_provider2.provider_id == provider2.id + + # Create context without provider + context_without_provider = await Context.create(metadata={"name": "context_no_provider"}) + assert context_without_provider.provider_id is None + + with subtests.test("list all contexts without filter"): + all_contexts = await Context.list() + assert len(all_contexts.items) == 3 + context_ids = [ctx.id for ctx in all_contexts.items] + assert context_with_provider1.id in context_ids + assert context_with_provider2.id in context_ids + assert context_without_provider.id in context_ids + + with subtests.test("filter contexts by provider1"): + provider1_contexts = await Context.list(provider_id=provider1.id) + assert len(provider1_contexts.items) == 1 + assert provider1_contexts.items[0].id == context_with_provider1.id + assert provider1_contexts.items[0].provider_id == provider1.id + + with subtests.test("filter contexts by provider2"): + provider2_contexts = await Context.list(provider_id=provider2.id) + assert len(provider2_contexts.items) == 1 + assert provider2_contexts.items[0].id == context_with_provider2.id + assert provider2_contexts.items[0].provider_id == provider2.id + + with subtests.test("filter by non-existent provider returns empty list"): + nonexistent_provider_id = str(uuid.uuid4()) + no_contexts = await Context.list(provider_id=nonexistent_provider_id) + assert len(no_contexts.items) == 0 + + with subtests.test("get context includes provider_id"): + fetched_context = await Context.get(context_with_provider1.id) + assert fetched_context.provider_id == provider1.id