diff --git a/apps/beeai-sdk-py/src/beeai_sdk/platform/provider.py b/apps/beeai-sdk-py/src/beeai_sdk/platform/provider.py index 78bbed0ce..7e8452915 100644 --- a/apps/beeai-sdk-py/src/beeai_sdk/platform/provider.py +++ b/apps/beeai-sdk-py/src/beeai_sdk/platform/provider.py @@ -202,9 +202,11 @@ async def list_variables(self: "Provider | str", *, client: PlatformClient | Non return result.raise_for_status().json()["variables"] @staticmethod - async def list(*, origin: str | None = None, client: PlatformClient | None = None) -> list["Provider"]: + async def list( + *, origin: str | None = None, user_owned: bool | None = None, client: PlatformClient | None = None + ) -> list["Provider"]: async with client or get_platform_client() as client: - params = {"origin": origin} if origin else None + params = filter_dict({"origin": origin, "user_owned": user_owned}) return pydantic.TypeAdapter(list[Provider]).validate_python( ( await client.get( diff --git a/apps/beeai-sdk-py/src/beeai_sdk/platform/provider_build.py b/apps/beeai-sdk-py/src/beeai_sdk/platform/provider_build.py index 26d1cf2e0..d9b2c2e71 100644 --- a/apps/beeai-sdk-py/src/beeai_sdk/platform/provider_build.py +++ b/apps/beeai-sdk-py/src/beeai_sdk/platform/provider_build.py @@ -165,6 +165,7 @@ async def list( limit: int | None = None, order: Literal["asc"] | Literal["desc"] | None = "asc", order_by: Literal["created_at"] | Literal["updated_at"] | None = None, + user_owned: bool | None = None, client: PlatformClient | None = None, ) -> PaginatedResult[ProviderBuild]: # `self` has a weird type so that you can call both `instance.list_history()` or `ProviderBuild.list_history("123")` @@ -174,7 +175,13 @@ async def list( await platform_client.get( url="/api/v1/provider_builds", params=filter_dict( - {"page_token": page_token, "limit": limit, "order": order, "order_by": order_by} + { + "page_token": page_token, + "limit": limit, + "order": order, + "order_by": order_by, + "user_owned": user_owned, + } ), ) ) diff --git a/apps/beeai-server/src/beeai_server/api/routes/provider_builds.py b/apps/beeai-server/src/beeai_server/api/routes/provider_builds.py index 545ece025..478a07d72 100644 --- a/apps/beeai-server/src/beeai_server/api/routes/provider_builds.py +++ b/apps/beeai-server/src/beeai_server/api/routes/provider_builds.py @@ -63,7 +63,8 @@ async def list_provider_builds( return await provider_build_service.list_builds( pagination=query, status=query.status, - user=user.user if query.user_owned else None, + user=user.user, + user_owned=query.user_owned, ) @router.get("/{id}/logs") diff --git a/apps/beeai-server/src/beeai_server/api/routes/providers.py b/apps/beeai-server/src/beeai_server/api/routes/providers.py index 9a45abbab..1432ee45d 100644 --- a/apps/beeai-server/src/beeai_server/api/routes/providers.py +++ b/apps/beeai-server/src/beeai_server/api/routes/providers.py @@ -76,11 +76,11 @@ async def list_providers( provider_service: ProviderServiceDependency, request: Request, user: Annotated[AuthorizedUser, Depends(RequiresPermissions(providers={"read"}), use_cache=False)], - user_owned: Annotated[bool, Query()] = False, + user_owned: Annotated[bool | None, Query()] = None, origin: Annotated[str | None, Query()] = None, ) -> PaginatedResult[EntityModel[ProviderWithState]]: providers = [] - for provider in await provider_service.list_providers(user=user.user if user_owned else None, origin=origin): + for provider in await provider_service.list_providers(user=user.user, user_owned=user_owned, origin=origin): new_provider = provider.model_copy( update={ "agent_card": create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request) diff --git a/apps/beeai-server/src/beeai_server/api/schema/provider_build.py b/apps/beeai-server/src/beeai_server/api/schema/provider_build.py index a6845af03..ec06b9f14 100644 --- a/apps/beeai-server/src/beeai_server/api/schema/provider_build.py +++ b/apps/beeai-server/src/beeai_server/api/schema/provider_build.py @@ -16,4 +16,4 @@ class CreateProviderBuildRequest(BaseModel): class ProviderBuildListQuery(PaginationQuery): status: BuildState | None = None - user_owned: bool = False + user_owned: bool | None = None diff --git a/apps/beeai-server/src/beeai_server/domain/repositories/provider.py b/apps/beeai-server/src/beeai_server/domain/repositories/provider.py index fac3f16b8..9a5d5f6bb 100644 --- a/apps/beeai-server/src/beeai_server/domain/repositories/provider.py +++ b/apps/beeai-server/src/beeai_server/domain/repositories/provider.py @@ -11,7 +11,12 @@ @runtime_checkable class IProviderRepository(Protocol): async def list( - self, *, type: ProviderType | None = None, user_id: UUID | None = None, origin: str | None = None + self, + *, + type: ProviderType | None = None, + user_id: UUID | None = None, + exclude_user_id: UUID | None = None, + origin: str | None = None, ) -> AsyncIterator[Provider]: yield ... # type: ignore diff --git a/apps/beeai-server/src/beeai_server/domain/repositories/provider_build.py b/apps/beeai-server/src/beeai_server/domain/repositories/provider_build.py index ac352825c..1d2a71606 100644 --- a/apps/beeai-server/src/beeai_server/domain/repositories/provider_build.py +++ b/apps/beeai-server/src/beeai_server/domain/repositories/provider_build.py @@ -25,6 +25,7 @@ async def list_paginated( order_by: str = "created_at", status: BuildState | None = None, user_id: UUID | None = None, + exclude_user_id: UUID | None = None, ) -> PaginatedResult[ProviderBuild]: ... async def create(self, *, provider_build: ProviderBuild) -> None: ... diff --git a/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/provider.py b/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/provider.py index 3a8c62961..9ad51b83b 100644 --- a/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/provider.py +++ b/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/provider.py @@ -118,11 +118,18 @@ async def delete(self, *, provider_id: UUID, user_id: UUID | None = None) -> int return result.rowcount async def list( - self, *, type: ProviderType | None = None, user_id: UUID | None = None, origin: str | None = None + self, + *, + type: ProviderType | None = None, + user_id: UUID | None = None, + exclude_user_id: UUID | None = None, + origin: str | None = None, ) -> AsyncIterator[Provider]: query = providers_table.select() if user_id is not None: query = query.where(providers_table.c.created_by == user_id) + if exclude_user_id is not None: + query = query.where(providers_table.c.created_by != exclude_user_id) if origin is not None: query = query.where(providers_table.c.origin == origin) if type is not None: diff --git a/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/provider_build.py b/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/provider_build.py index 934bbaf84..f42400f22 100644 --- a/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/provider_build.py +++ b/apps/beeai-server/src/beeai_server/infrastructure/persistence/repositories/provider_build.py @@ -122,10 +122,13 @@ async def list_paginated( order_by: str = "created_at", status: BuildState | None = None, user_id: UUID | None = None, + exclude_user_id: UUID | None = None, ) -> PaginatedResult[ProviderBuild]: query = provider_builds_table.select() if user_id: query = query.where(provider_builds_table.c.created_by == user_id) + if exclude_user_id: + query = query.where(provider_builds_table.c.created_by != exclude_user_id) if status is not None: query = query.where(provider_builds_table.c.status == status) diff --git a/apps/beeai-server/src/beeai_server/service_layer/services/provider_build.py b/apps/beeai-server/src/beeai_server/service_layer/services/provider_build.py index 1d479421b..087cd0d5c 100644 --- a/apps/beeai-server/src/beeai_server/service_layer/services/provider_build.py +++ b/apps/beeai-server/src/beeai_server/service_layer/services/provider_build.py @@ -119,12 +119,23 @@ async def get_build(self, provider_build_id: UUID) -> ProviderBuild: return await uow.provider_builds.get(provider_build_id=provider_build_id) async def list_builds( - self, pagination: PaginationQuery, status: BuildState | None = None, user: User | None = None + self, + pagination: PaginationQuery, + status: BuildState | None = None, + user: User | None = None, + user_owned: bool | None = None, ) -> PaginatedResult[ProviderBuild]: - user_id = user.id if user else None + # user_owned: True -> show user owned entities + # user_owned: False -> show all but user owned entities + # user_owned: None -> show all entities + + if user_owned is not None and user is None: + raise ValueError("user_owned cannot be specified without a user") + async with self._uow() as uow: return await uow.provider_builds.list_paginated( - user_id=user_id, + user_id=user.id if user_owned is True else None, + exclude_user_id=user.id if user_owned is False else None, limit=pagination.limit, page_token=pagination.page_token, order=pagination.order, diff --git a/apps/beeai-server/src/beeai_server/service_layer/services/providers.py b/apps/beeai-server/src/beeai_server/service_layer/services/providers.py index e94f5b99c..a738b2ac7 100644 --- a/apps/beeai-server/src/beeai_server/service_layer/services/providers.py +++ b/apps/beeai-server/src/beeai_server/service_layer/services/providers.py @@ -259,11 +259,26 @@ async def remove_orphaned_providers(self): existing_providers = [p.id async for p in uow.providers.list()] await self._deployment_manager.remove_orphaned_providers(existing_providers=existing_providers) - async def list_providers(self, user: User | None = None, origin: str | None = None) -> list[ProviderWithState]: - user_id = user.id if user else None + async def list_providers( + self, user: User | None = None, user_owned: bool | None = None, origin: str | None = None + ) -> list[ProviderWithState]: + # user_owned: True -> show user owned entities + # user_owned: False -> show all but user owned entities + # user_owned: None -> show all entities + + if user_owned is not None and user is None: + raise ValueError("user_owned cannot be specified without a user") + async with self._uow() as uow: return await self._get_providers_with_state( - providers=[p async for p in uow.providers.list(user_id=user_id, origin=origin)] + providers=[ + p + async for p in uow.providers.list( + user_id=user.id if user_owned is True else None, + exclude_user_id=user.id if user_owned is False else None, + origin=origin, + ) + ] ) async def get_provider( diff --git a/apps/beeai-server/tests/e2e/routes/test_providers.py b/apps/beeai-server/tests/e2e/routes/test_providers.py index 1eb256049..fd17cf9e5 100644 --- a/apps/beeai-server/tests/e2e/routes/test_providers.py +++ b/apps/beeai-server/tests/e2e/routes/test_providers.py @@ -44,6 +44,21 @@ async def test_provider_crud(subtests, test_configuration): provider = await provider.patch(variables={}) assert await provider.list_variables() == {} + with subtests.test("test user_owned filtering"): + # Test user_owned=True (should see exactly 1 provider - admin's) + admin_providers = await Provider.list(user_owned=True) + assert len(admin_providers) == 1 + assert admin_providers[0].id == provider.id + + # Test user_owned=False (should see 0 providers - no other users' providers) + others_providers = await Provider.list(user_owned=False) + assert len(others_providers) == 0 + + # Test user_owned=None (should see exactly 1 provider - all providers) + all_providers = await Provider.list(user_owned=None) + assert len(all_providers) == 1 + assert all_providers[0].id == provider.id + with subtests.test("delete provider"): await provider.delete() with pytest.raises(HTTPError, match="404 Not Found"):