Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions apps/beeai-sdk-py/src/beeai_sdk/platform/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,11 @@ async def list_variables(self: "Provider | str", *, client: PlatformClient | Non
return result.raise_for_status().json()["variables"]

@staticmethod
async def list(*, origin: str | None = None, client: PlatformClient | None = None) -> list["Provider"]:
async def list(
*, origin: str | None = None, user_owned: bool | None = None, client: PlatformClient | None = None
) -> list["Provider"]:
async with client or get_platform_client() as client:
params = {"origin": origin} if origin else None
params = filter_dict({"origin": origin, "user_owned": user_owned})
return pydantic.TypeAdapter(list[Provider]).validate_python(
(
await client.get(
Expand Down
9 changes: 8 additions & 1 deletion apps/beeai-sdk-py/src/beeai_sdk/platform/provider_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ async def list(
limit: int | None = None,
order: Literal["asc"] | Literal["desc"] | None = "asc",
order_by: Literal["created_at"] | Literal["updated_at"] | None = None,
user_owned: bool | None = None,
client: PlatformClient | None = None,
) -> PaginatedResult[ProviderBuild]:
# `self` has a weird type so that you can call both `instance.list_history()` or `ProviderBuild.list_history("123")`
Expand All @@ -174,7 +175,13 @@ async def list(
await platform_client.get(
url="/api/v1/provider_builds",
params=filter_dict(
{"page_token": page_token, "limit": limit, "order": order, "order_by": order_by}
{
"page_token": page_token,
"limit": limit,
"order": order,
"order_by": order_by,
"user_owned": user_owned,
}
),
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ async def list_provider_builds(
return await provider_build_service.list_builds(
pagination=query,
status=query.status,
user=user.user if query.user_owned else None,
user=user.user,
user_owned=query.user_owned,
)

@router.get("/{id}/logs")
Expand Down
4 changes: 2 additions & 2 deletions apps/beeai-server/src/beeai_server/api/routes/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ async def list_providers(
provider_service: ProviderServiceDependency,
request: Request,
user: Annotated[AuthorizedUser, Depends(RequiresPermissions(providers={"read"}), use_cache=False)],
user_owned: Annotated[bool, Query()] = False,
user_owned: Annotated[bool | None, Query()] = None,
origin: Annotated[str | None, Query()] = None,
) -> PaginatedResult[EntityModel[ProviderWithState]]:
providers = []
for provider in await provider_service.list_providers(user=user.user if user_owned else None, origin=origin):
for provider in await provider_service.list_providers(user=user.user, user_owned=user_owned, origin=origin):
new_provider = provider.model_copy(
update={
"agent_card": create_proxy_agent_card(provider.agent_card, provider_id=provider.id, request=request)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ class CreateProviderBuildRequest(BaseModel):

class ProviderBuildListQuery(PaginationQuery):
status: BuildState | None = None
user_owned: bool = False
user_owned: bool | None = None
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
@runtime_checkable
class IProviderRepository(Protocol):
async def list(
self, *, type: ProviderType | None = None, user_id: UUID | None = None, origin: str | None = None
self,
*,
type: ProviderType | None = None,
user_id: UUID | None = None,
exclude_user_id: UUID | None = None,
origin: str | None = None,
) -> AsyncIterator[Provider]:
yield ... # type: ignore

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ async def list_paginated(
order_by: str = "created_at",
status: BuildState | None = None,
user_id: UUID | None = None,
exclude_user_id: UUID | None = None,
) -> PaginatedResult[ProviderBuild]: ...

async def create(self, *, provider_build: ProviderBuild) -> None: ...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,18 @@ async def delete(self, *, provider_id: UUID, user_id: UUID | None = None) -> int
return result.rowcount

async def list(
self, *, type: ProviderType | None = None, user_id: UUID | None = None, origin: str | None = None
self,
*,
type: ProviderType | None = None,
user_id: UUID | None = None,
exclude_user_id: UUID | None = None,
origin: str | None = None,
) -> AsyncIterator[Provider]:
query = providers_table.select()
if user_id is not None:
query = query.where(providers_table.c.created_by == user_id)
if exclude_user_id is not None:
query = query.where(providers_table.c.created_by != exclude_user_id)
if origin is not None:
query = query.where(providers_table.c.origin == origin)
if type is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,13 @@ async def list_paginated(
order_by: str = "created_at",
status: BuildState | None = None,
user_id: UUID | None = None,
exclude_user_id: UUID | None = None,
) -> PaginatedResult[ProviderBuild]:
query = provider_builds_table.select()
if user_id:
query = query.where(provider_builds_table.c.created_by == user_id)
if exclude_user_id:
query = query.where(provider_builds_table.c.created_by != exclude_user_id)
if status is not None:
query = query.where(provider_builds_table.c.status == status)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,23 @@ async def get_build(self, provider_build_id: UUID) -> ProviderBuild:
return await uow.provider_builds.get(provider_build_id=provider_build_id)

async def list_builds(
self, pagination: PaginationQuery, status: BuildState | None = None, user: User | None = None
self,
pagination: PaginationQuery,
status: BuildState | None = None,
user: User | None = None,
user_owned: bool | None = None,
) -> PaginatedResult[ProviderBuild]:
user_id = user.id if user else None
# user_owned: True -> show user owned entities
# user_owned: False -> show all but user owned entities
# user_owned: None -> show all entities

if user_owned is not None and user is None:
raise ValueError("user_owned cannot be specified without a user")

async with self._uow() as uow:
return await uow.provider_builds.list_paginated(
user_id=user_id,
user_id=user.id if user_owned is True else None,
exclude_user_id=user.id if user_owned is False else None,
limit=pagination.limit,
page_token=pagination.page_token,
order=pagination.order,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,26 @@ async def remove_orphaned_providers(self):
existing_providers = [p.id async for p in uow.providers.list()]
await self._deployment_manager.remove_orphaned_providers(existing_providers=existing_providers)

async def list_providers(self, user: User | None = None, origin: str | None = None) -> list[ProviderWithState]:
user_id = user.id if user else None
async def list_providers(
self, user: User | None = None, user_owned: bool | None = None, origin: str | None = None
) -> list[ProviderWithState]:
# user_owned: True -> show user owned entities
# user_owned: False -> show all but user owned entities
# user_owned: None -> show all entities

if user_owned is not None and user is None:
raise ValueError("user_owned cannot be specified without a user")

async with self._uow() as uow:
return await self._get_providers_with_state(
providers=[p async for p in uow.providers.list(user_id=user_id, origin=origin)]
providers=[
p
async for p in uow.providers.list(
user_id=user.id if user_owned is True else None,
exclude_user_id=user.id if user_owned is False else None,
origin=origin,
)
]
)

async def get_provider(
Expand Down
15 changes: 15 additions & 0 deletions apps/beeai-server/tests/e2e/routes/test_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ async def test_provider_crud(subtests, test_configuration):
provider = await provider.patch(variables={})
assert await provider.list_variables() == {}

with subtests.test("test user_owned filtering"):
# Test user_owned=True (should see exactly 1 provider - admin's)
admin_providers = await Provider.list(user_owned=True)
assert len(admin_providers) == 1
assert admin_providers[0].id == provider.id

# Test user_owned=False (should see 0 providers - no other users' providers)
others_providers = await Provider.list(user_owned=False)
assert len(others_providers) == 0

# Test user_owned=None (should see exactly 1 provider - all providers)
all_providers = await Provider.list(user_owned=None)
assert len(all_providers) == 1
assert all_providers[0].id == provider.id

with subtests.test("delete provider"):
await provider.delete()
with pytest.raises(HTTPError, match="404 Not Found"):
Expand Down