From ef064c644698c850d9f24dba482224a32c5c6275 Mon Sep 17 00:00:00 2001 From: Ralf Grubenmann Date: Thu, 8 Aug 2024 16:40:46 +0200 Subject: [PATCH 1/5] refactor: add custom ULID type for sqlalchemy --- components/renku_data_services/project/db.py | 48 +++++++++++-------- components/renku_data_services/project/orm.py | 3 +- components/renku_data_services/session/orm.py | 2 +- components/renku_data_services/utils/py.typed | 0 .../renku_data_services/utils/sqlalchemy.py | 25 ++++++++++ 5 files changed, 55 insertions(+), 23 deletions(-) create mode 100644 components/renku_data_services/utils/py.typed create mode 100644 components/renku_data_services/utils/sqlalchemy.py diff --git a/components/renku_data_services/project/db.py b/components/renku_data_services/project/db.py index 42e8ffee0..9613375ea 100644 --- a/components/renku_data_services/project/db.py +++ b/components/renku_data_services/project/db.py @@ -10,6 +10,7 @@ from sqlalchemy import Select, delete, func, select, update from sqlalchemy.ext.asyncio import AsyncSession +from ulid import ULID import renku_data_services.base_models as base_models from renku_data_services import errors @@ -73,7 +74,7 @@ async def get_projects( total_elements = results[1].scalar() or 0 return [p.dump() for p in projects_orm], total_elements - async def get_project(self, user: base_models.APIUser, project_id: str) -> models.Project: + async def get_project(self, user: base_models.APIUser, project_id: ULID) -> models.Project: """Get one project from the database.""" authorized = await self.authz.has_permission(user, ResourceType.project, project_id, Scope.READ) if not authorized: @@ -82,7 +83,7 @@ async def get_project(self, user: base_models.APIUser, project_id: str) -> model ) async with self.session_maker() as session: - stmt = select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id) + stmt = select(schemas.ProjectORM).where(schemas.ProjectORM.id == str(project_id)) result = await session.execute(stmt) project_orm = result.scalars().first() @@ -110,7 +111,10 @@ async def get_project_by_namespace_slug( raise errors.MissingResourceError(message=not_found_msg) authorized = await self.authz.has_permission( - user=user, resource_type=ResourceType.project, resource_id=project_orm.id, scope=Scope.READ + user=user, + resource_type=ResourceType.project, + resource_id=project_orm.id, + scope=Scope.READ, ) if not authorized: raise errors.MissingResourceError(message=not_found_msg) @@ -167,7 +171,7 @@ async def insert_project( creation_date=datetime.now(UTC).replace(microsecond=0), keywords=project.keywords, ) - project_slug = schemas.ProjectSlug(slug, project_id=project_orm.id, namespace_id=ns.id) + project_slug = schemas.ProjectSlug(slug, project_id=str(project_orm.id), namespace_id=ns.id) session.add(project_slug) session.add(project_orm) @@ -182,19 +186,20 @@ async def insert_project( async def update_project( self, user: base_models.APIUser, - project_id: str, + project_id: ULID, payload: dict[str, Any], etag: str | None = None, *, session: AsyncSession | None = None, ) -> models.ProjectUpdate: """Update a project entry.""" + project_id_str: str = str(project_id) if not session: raise errors.ProgrammingError(message="A database session is required") - result = await session.scalars(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id)) + result = await session.scalars(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id_str)) project = result.one_or_none() if project is None: - raise errors.MissingResourceError(message=f"The project with id '{project_id}' cannot be found") + raise errors.MissingResourceError(message=f"The project with id '{project_id_str}' cannot be found") old_project = project.dump() required_scope = Scope.WRITE @@ -210,7 +215,7 @@ async def update_project( authorized = await self.authz.has_permission(user, ResourceType.project, project_id, required_scope) if not authorized: raise errors.MissingResourceError( - message=f"Project with id '{project_id}' does not exist or you do not have access to it." + message=f"Project with id '{project_id_str}' does not exist or you do not have access to it." ) current_etag = project.dump().etag @@ -219,11 +224,11 @@ async def update_project( if "repositories" in payload: payload["repositories"] = [ - schemas.ProjectRepositoryORM(url=r, project_id=project_id, project=project) + schemas.ProjectRepositoryORM(url=r, project_id=project_id_str, project=project) for r in payload["repositories"] ] # Trigger update for ``updated_at`` column - await session.execute(update(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id).values()) + await session.execute(update(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id_str).values()) if "keywords" in payload and not payload["keywords"]: payload["keywords"] = None @@ -263,7 +268,7 @@ async def update_project( @Authz.authz_change(AuthzOperation.delete, ResourceType.project) @dispatch_message(avro_schema_v2.ProjectRemoved) async def delete_project( - self, user: base_models.APIUser, project_id: str, *, session: AsyncSession | None = None + self, user: base_models.APIUser, project_id: ULID, *, session: AsyncSession | None = None ) -> models.Project | None: """Delete a project.""" if not session: @@ -274,16 +279,17 @@ async def delete_project( message=f"Project with id '{project_id}' does not exist or you do not have access to it." ) - result = await session.execute(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id)) + project_id_str = str(project_id) + result = await session.execute(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id_str)) project = result.scalar_one_or_none() if project is None: return None - await session.execute(delete(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id)) + await session.execute(delete(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id_str)) await session.execute( - delete(storage_schemas.CloudStorageORM).where(storage_schemas.CloudStorageORM.project_id == project_id) + delete(storage_schemas.CloudStorageORM).where(storage_schemas.CloudStorageORM.project_id == project_id_str) ) return project.dump() @@ -303,15 +309,15 @@ def _filter_by_namespace_slug(statement: Select[tuple[_T]], namespace: str) -> S def _project_exists( - f: Callable[Concatenate[ProjectMemberRepository, base_models.APIUser, str, _P], Awaitable[_T]], -) -> Callable[Concatenate[ProjectMemberRepository, base_models.APIUser, str, _P], Awaitable[_T]]: + f: Callable[Concatenate[ProjectMemberRepository, base_models.APIUser, ULID, _P], Awaitable[_T]], +) -> Callable[Concatenate[ProjectMemberRepository, base_models.APIUser, ULID, _P], Awaitable[_T]]: """Checks if the project exists when adding or modifying project members.""" @functools.wraps(f) async def decorated_func( self: ProjectMemberRepository, user: base_models.APIUser, - project_id: str, + project_id: ULID, *args: _P.args, **kwargs: _P.kwargs, ) -> _T: @@ -321,7 +327,7 @@ async def decorated_func( message="The decorator that checks if a project exists requires a database session in the " f"keyword arguments, but instead it got {type(session)}" ) - stmt = select(schemas.ProjectORM.id).where(schemas.ProjectORM.id == project_id) + stmt = select(schemas.ProjectORM.id).where(schemas.ProjectORM.id == str(project_id)) res = await session.scalar(stmt) if not res: raise errors.MissingResourceError( @@ -350,7 +356,7 @@ def __init__( @with_db_transaction @_project_exists async def get_members( - self, user: base_models.APIUser, project_id: str, *, session: AsyncSession | None = None + self, user: base_models.APIUser, project_id: ULID, *, session: AsyncSession | None = None ) -> list[Member]: """Get all members of a project.""" members = await self.authz.members(user, ResourceType.project, project_id) @@ -363,7 +369,7 @@ async def get_members( async def update_members( self, user: base_models.APIUser, - project_id: str, + project_id: ULID, members: list[Member], *, session: AsyncSession | None = None, @@ -392,7 +398,7 @@ async def update_members( @_project_exists @dispatch_message(events.ProjectMembershipChanged) async def delete_members( - self, user: base_models.APIUser, project_id: str, user_ids: list[str], *, session: AsyncSession | None = None + self, user: base_models.APIUser, project_id: ULID, user_ids: list[str], *, session: AsyncSession | None = None ) -> list[MembershipChange]: """Delete members from a project.""" if len(user_ids) == 0: diff --git a/components/renku_data_services/project/orm.py b/components/renku_data_services/project/orm.py index db6e60c50..18d2b0708 100644 --- a/components/renku_data_services/project/orm.py +++ b/components/renku_data_services/project/orm.py @@ -13,6 +13,7 @@ from renku_data_services.namespace.orm import NamespaceORM from renku_data_services.project import models from renku_data_services.project.apispec import Visibility +from renku_data_services.utils.sqlalchemy import ULIDType metadata_obj = MetaData(schema="projects") # Has to match alembic ini section name @@ -27,7 +28,7 @@ class ProjectORM(BaseORM): """A Renku native project.""" __tablename__ = "projects" - id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False) + id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False) name: Mapped[str] = mapped_column("name", String(99)) visibility: Mapped[Visibility] created_by_id: Mapped[str] = mapped_column("created_by_id", String()) diff --git a/components/renku_data_services/session/orm.py b/components/renku_data_services/session/orm.py index 697acaefd..0620b4309 100644 --- a/components/renku_data_services/session/orm.py +++ b/components/renku_data_services/session/orm.py @@ -104,7 +104,7 @@ class SessionLauncherORM(BaseORM): project: Mapped[ProjectORM] = relationship(init=False) environment: Mapped[EnvironmentORM | None] = relationship(init=False) - project_id: Mapped[str] = mapped_column( + project_id: Mapped[ULID] = mapped_column( "project_id", ForeignKey(ProjectORM.id, ondelete="CASCADE"), default=None, index=True ) """Id of the project this session belongs to.""" diff --git a/components/renku_data_services/utils/py.typed b/components/renku_data_services/utils/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/components/renku_data_services/utils/sqlalchemy.py b/components/renku_data_services/utils/sqlalchemy.py new file mode 100644 index 000000000..ff4804bf3 --- /dev/null +++ b/components/renku_data_services/utils/sqlalchemy.py @@ -0,0 +1,25 @@ +"""Utilities for SQLAlchemy.""" + +from typing import cast + +from sqlalchemy import Dialect, types +from ulid import ULID + + +class ULIDType(types.TypeDecorator): + """Wrapper type for ULID <--> str conversion.""" + + impl = types.String + cache_ok = True + + def process_bind_param(self, value: ULID | None, dialect: Dialect) -> str | None: + """Transform value for storing in the database.""" + if value is None: + return None + return str(value) + + def process_result_value(self, value: str | None, dialect: Dialect) -> ULID | None: + """Transform string from database into ULID.""" + if value is None: + return None + return cast(ULID, ULID.from_str(value)) From 1bde9d1c1fb2fcafbfb6dd195defdc337033df7d Mon Sep 17 00:00:00 2001 From: Ralf Grubenmann Date: Fri, 9 Aug 2024 14:51:14 +0200 Subject: [PATCH 2/5] address comments --- .../renku_data_services/project/blueprints.py | 13 +++++++------ components/renku_data_services/project/db.py | 14 +++++++------- components/renku_data_services/project/orm.py | 2 +- components/renku_data_services/session/db.py | 8 +++++--- components/renku_data_services/session/orm.py | 4 ++-- components/renku_data_services/utils/sqlalchemy.py | 2 +- 6 files changed, 23 insertions(+), 20 deletions(-) diff --git a/components/renku_data_services/project/blueprints.py b/components/renku_data_services/project/blueprints.py index a9261ff00..2837eacaa 100644 --- a/components/renku_data_services/project/blueprints.py +++ b/components/renku_data_services/project/blueprints.py @@ -6,6 +6,7 @@ from sanic import HTTPResponse, Request, json from sanic.response import JSONResponse from sanic_ext import validate +from ulid import ULID import renku_data_services.base_models as base_models from renku_data_services.authz.models import Member, Role, Visibility @@ -110,7 +111,7 @@ def get_one(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) @validate_path_project_id async def _get_one(request: Request, user: base_models.APIUser, project_id: str) -> JSONResponse | HTTPResponse: - project = await self.project_repo.get_project(user=user, project_id=project_id) + project = await self.project_repo.get_project(user=user, project_id=ULID.from_str(project_id)) etag = request.headers.get("If-None-Match") if project.etag is not None and project.etag == etag: @@ -176,7 +177,7 @@ def delete(self) -> BlueprintFactoryResponse: @only_authenticated @validate_path_project_id async def _delete(_: Request, user: base_models.APIUser, project_id: str) -> HTTPResponse: - await self.project_repo.delete_project(user=user, project_id=project_id) + await self.project_repo.delete_project(user=user, project_id=ULID.from_str(project_id)) return HTTPResponse(status=204) return "/projects/", ["DELETE"], _delete @@ -195,7 +196,7 @@ async def _patch( body_dict = body.model_dump(exclude_none=True) project_update = await self.project_repo.update_project( - user=user, project_id=project_id, etag=etag, payload=body_dict + user=user, project_id=ULID.from_str(project_id), etag=etag, payload=body_dict ) if not isinstance(project_update, project_models.ProjectUpdate): raise errors.ProgrammingError( @@ -229,7 +230,7 @@ def get_all_members(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) @validate_path_project_id async def _get_all_members(_: Request, user: base_models.APIUser, project_id: str) -> JSONResponse: - members = await self.project_member_repo.get_members(user, project_id) + members = await self.project_member_repo.get_members(user, ULID.from_str(project_id)) users = [] @@ -261,7 +262,7 @@ def update_members(self) -> BlueprintFactoryResponse: async def _update_members(request: Request, user: base_models.APIUser, project_id: str) -> HTTPResponse: body_dump = apispec.ProjectMemberListPatchRequest.model_validate(request.json) members = [Member(Role(i.role.value), i.id, project_id) for i in body_dump.root] - await self.project_member_repo.update_members(user, project_id, members) + await self.project_member_repo.update_members(user, ULID.from_str(project_id), members) return HTTPResponse(status=200) return "/projects//members", ["PATCH"], _update_members @@ -275,7 +276,7 @@ def delete_member(self) -> BlueprintFactoryResponse: async def _delete_member( _: Request, user: base_models.APIUser, project_id: str, member_id: str ) -> HTTPResponse: - await self.project_member_repo.delete_members(user, project_id, [member_id]) + await self.project_member_repo.delete_members(user, ULID.from_str(project_id), [member_id]) return HTTPResponse(status=204) return "/projects//members/", ["DELETE"], _delete_member diff --git a/components/renku_data_services/project/db.py b/components/renku_data_services/project/db.py index 9613375ea..f5055aa7f 100644 --- a/components/renku_data_services/project/db.py +++ b/components/renku_data_services/project/db.py @@ -76,7 +76,7 @@ async def get_projects( async def get_project(self, user: base_models.APIUser, project_id: ULID) -> models.Project: """Get one project from the database.""" - authorized = await self.authz.has_permission(user, ResourceType.project, project_id, Scope.READ) + authorized = await self.authz.has_permission(user, ResourceType.project, str(project_id), Scope.READ) if not authorized: raise errors.MissingResourceError( message=f"Project with id '{project_id}' does not exist or you do not have access to it." @@ -113,7 +113,7 @@ async def get_project_by_namespace_slug( authorized = await self.authz.has_permission( user=user, resource_type=ResourceType.project, - resource_id=project_orm.id, + resource_id=str(project_orm.id), scope=Scope.READ, ) if not authorized: @@ -212,7 +212,7 @@ async def update_project( if "namespace" in payload and payload["namespace"] != old_project.namespace: # NOTE: changing the namespace requires the user to be owner which means they should have DELETE permission required_scope = Scope.DELETE - authorized = await self.authz.has_permission(user, ResourceType.project, project_id, required_scope) + authorized = await self.authz.has_permission(user, ResourceType.project, str(project_id), required_scope) if not authorized: raise errors.MissingResourceError( message=f"Project with id '{project_id_str}' does not exist or you do not have access to it." @@ -273,7 +273,7 @@ async def delete_project( """Delete a project.""" if not session: raise errors.ProgrammingError(message="A database session is required") - authorized = await self.authz.has_permission(user, ResourceType.project, project_id, Scope.DELETE) + authorized = await self.authz.has_permission(user, ResourceType.project, str(project_id), Scope.DELETE) if not authorized: raise errors.MissingResourceError( message=f"Project with id '{project_id}' does not exist or you do not have access to it." @@ -359,7 +359,7 @@ async def get_members( self, user: base_models.APIUser, project_id: ULID, *, session: AsyncSession | None = None ) -> list[Member]: """Get all members of a project.""" - members = await self.authz.members(user, ResourceType.project, project_id) + members = await self.authz.members(user, ResourceType.project, str(project_id)) members = [member for member in members if member.user_id and member.user_id != "*"] return members @@ -391,7 +391,7 @@ async def update_members( f"{requested_member_ids_set.difference(existing_member_ids)} cannot be found" ) - output = await self.authz.upsert_project_members(user, ResourceType.project, project_id, members) + output = await self.authz.upsert_project_members(user, ResourceType.project, str(project_id), members) return output @with_db_transaction @@ -404,5 +404,5 @@ async def delete_members( if len(user_ids) == 0: raise errors.ValidationError(message="Please request at least 1 member to be removed from the project") - members = await self.authz.remove_project_members(user, ResourceType.project, project_id, user_ids) + members = await self.authz.remove_project_members(user, ResourceType.project, str(project_id), user_ids) return members diff --git a/components/renku_data_services/project/orm.py b/components/renku_data_services/project/orm.py index 18d2b0708..5dcdd22b8 100644 --- a/components/renku_data_services/project/orm.py +++ b/components/renku_data_services/project/orm.py @@ -54,7 +54,7 @@ class ProjectORM(BaseORM): def dump(self) -> models.Project: """Create a project model from the ProjectORM.""" return models.Project( - id=self.id, + id=str(self.id), name=self.name, slug=self.slug.slug, namespace=self.slug.namespace.dump(), diff --git a/components/renku_data_services/session/db.py b/components/renku_data_services/session/db.py index 0d4d4078f..ca2677806 100644 --- a/components/renku_data_services/session/db.py +++ b/components/renku_data_services/session/db.py @@ -157,7 +157,9 @@ async def get_launcher(self, user: base_models.APIUser, launcher_id: ULID) -> mo launcher = res.one_or_none() authorized = ( - await self.project_authz.has_permission(user, ResourceType.project, launcher.project_id, Scope.READ) + await self.project_authz.has_permission( + user, ResourceType.project, str(launcher.project_id), Scope.READ + ) if launcher is not None else False ) @@ -259,7 +261,7 @@ async def update_launcher( authorized = await self.project_authz.has_permission( user, ResourceType.project, - launcher.project_id, + str(launcher.project_id), Scope.WRITE, ) if not authorized: @@ -336,7 +338,7 @@ async def delete_launcher(self, user: base_models.APIUser, launcher_id: ULID) -> authorized = await self.project_authz.has_permission( user, ResourceType.project, - launcher.project_id, + str(launcher.project_id), Scope.WRITE, ) if not authorized: diff --git a/components/renku_data_services/session/orm.py b/components/renku_data_services/session/orm.py index 0620b4309..2d5cf7dea 100644 --- a/components/renku_data_services/session/orm.py +++ b/components/renku_data_services/session/orm.py @@ -133,7 +133,7 @@ def load(cls, launcher: models.SessionLauncher) -> "SessionLauncherORM": description=launcher.description, environment_kind=launcher.environment_kind, container_image=launcher.container_image, - project_id=launcher.project_id, + project_id=ULID.from_str(launcher.project_id), environment_id=launcher.environment_id, resource_class_id=launcher.resource_class_id, default_url=launcher.default_url, @@ -143,7 +143,7 @@ def dump(self) -> models.SessionLauncher: """Create a session launcher model from the SessionLauncherORM.""" return models.SessionLauncher( id=self.id, - project_id=self.project_id, + project_id=str(self.project_id), name=self.name, created_by=models.Member(id=self.created_by_id), creation_date=self.creation_date, diff --git a/components/renku_data_services/utils/sqlalchemy.py b/components/renku_data_services/utils/sqlalchemy.py index ff4804bf3..f1cd59c9b 100644 --- a/components/renku_data_services/utils/sqlalchemy.py +++ b/components/renku_data_services/utils/sqlalchemy.py @@ -22,4 +22,4 @@ def process_result_value(self, value: str | None, dialect: Dialect) -> ULID | No """Transform string from database into ULID.""" if value is None: return None - return cast(ULID, ULID.from_str(value)) + return cast(ULID, ULID.from_str(value)) # cast because mypy doesn't understand ULID type annotations From c5999f0db68fa462d7eef1ce905688092b0b4d76 Mon Sep 17 00:00:00 2001 From: Ralf Grubenmann Date: Tue, 13 Aug 2024 15:27:28 +0200 Subject: [PATCH 3/5] fix ulid regex --- bases/renku_data_services/data_api/app.py | 2 +- .../renku_data_services/connected_services/api.spec.yaml | 2 +- .../renku_data_services/connected_services/apispec.py | 4 ++-- components/renku_data_services/crc/apispec.py | 2 +- components/renku_data_services/namespace/api.spec.yaml | 2 +- components/renku_data_services/namespace/apispec.py | 6 +++--- components/renku_data_services/notebooks/apispec.py | 2 +- components/renku_data_services/platform/apispec.py | 2 +- components/renku_data_services/project/api.spec.yaml | 2 +- components/renku_data_services/project/apispec.py | 4 ++-- components/renku_data_services/repositories/api.spec.yaml | 2 +- components/renku_data_services/repositories/apispec.py | 4 ++-- components/renku_data_services/secrets/api.spec.yaml | 2 +- components/renku_data_services/secrets/apispec.py | 4 ++-- components/renku_data_services/session/api.spec.yaml | 2 +- components/renku_data_services/storage/apispec.py | 2 +- components/renku_data_services/users/api.spec.yaml | 2 +- components/renku_data_services/users/apispec.py | 4 ++-- 18 files changed, 25 insertions(+), 25 deletions(-) diff --git a/bases/renku_data_services/data_api/app.py b/bases/renku_data_services/data_api/app.py index b2af98684..2765aa542 100644 --- a/bases/renku_data_services/data_api/app.py +++ b/bases/renku_data_services/data_api/app.py @@ -26,7 +26,7 @@ def register_all_handlers(app: Sanic, config: Config) -> Sanic: """Register all handlers on the application.""" - app.router.register_pattern("ulid", ULID.from_str, r"^[0-9A-HJKMNP-TV-Z]{26}$") + app.router.register_pattern("ulid", ULID.from_str, r"^[0-7][0-9A-HJKMNP-TV-Z]{25}$") app.router.register_pattern("renku_slug", str, r"^[a-zA-Z0-9][a-zA-Z0-9\-_.]*$") url_prefix = "/api/data" diff --git a/components/renku_data_services/connected_services/api.spec.yaml b/components/renku_data_services/connected_services/api.spec.yaml index 9eb716627..7699262b8 100644 --- a/components/renku_data_services/connected_services/api.spec.yaml +++ b/components/renku_data_services/connected_services/api.spec.yaml @@ -309,7 +309,7 @@ components: type: string minLength: 26 maxLength: 26 - pattern: "^[A-Z0-9]{26}$" # This is case-insensitive + pattern: "^[0-7][0-9A-HJKMNP-TV-Z]{25}$" # This is case-insensitive ProviderId: description: ID of a OAuth2 provider, e.g. "gitlab.com". type: string diff --git a/components/renku_data_services/connected_services/apispec.py b/components/renku_data_services/connected_services/apispec.py index 884edbf5b..94d5aee7f 100644 --- a/components/renku_data_services/connected_services/apispec.py +++ b/components/renku_data_services/connected_services/apispec.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2024-08-06T05:55:35+00:00 +# timestamp: 2024-08-13T13:29:50+00:00 from __future__ import annotations @@ -153,7 +153,7 @@ class Connection(BaseAPISpec): description="ULID identifier", max_length=26, min_length=26, - pattern="^[A-Z0-9]{26}$", + pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$", ) provider_id: str = Field( ..., diff --git a/components/renku_data_services/crc/apispec.py b/components/renku_data_services/crc/apispec.py index 5971cc232..0ece6cb91 100644 --- a/components/renku_data_services/crc/apispec.py +++ b/components/renku_data_services/crc/apispec.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2024-08-07T07:12:13+00:00 +# timestamp: 2024-08-13T13:29:45+00:00 from __future__ import annotations diff --git a/components/renku_data_services/namespace/api.spec.yaml b/components/renku_data_services/namespace/api.spec.yaml index 2646d6bb4..e7acd3490 100644 --- a/components/renku_data_services/namespace/api.spec.yaml +++ b/components/renku_data_services/namespace/api.spec.yaml @@ -365,7 +365,7 @@ components: type: string minLength: 26 maxLength: 26 - pattern: "^[A-Z0-9]{26}$" # This is case-insensitive + pattern: "^[0-7][0-9A-HJKMNP-TV-Z]{25}$" # This is case-insensitive NamespaceName: description: Renku group or namespace name type: string diff --git a/components/renku_data_services/namespace/apispec.py b/components/renku_data_services/namespace/apispec.py index cb7162856..db6e5b115 100644 --- a/components/renku_data_services/namespace/apispec.py +++ b/components/renku_data_services/namespace/apispec.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2024-08-06T08:46:00+00:00 +# timestamp: 2024-08-13T13:29:48+00:00 from __future__ import annotations @@ -29,7 +29,7 @@ class NamespaceResponse(BaseAPISpec): description="ULID identifier", max_length=26, min_length=26, - pattern="^[A-Z0-9]{26}$", + pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$", ) name: Optional[str] = Field( None, @@ -93,7 +93,7 @@ class GroupResponse(BaseAPISpec): description="ULID identifier", max_length=26, min_length=26, - pattern="^[A-Z0-9]{26}$", + pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$", ) name: str = Field( ..., diff --git a/components/renku_data_services/notebooks/apispec.py b/components/renku_data_services/notebooks/apispec.py index ee95f91dd..e7d17b0c6 100644 --- a/components/renku_data_services/notebooks/apispec.py +++ b/components/renku_data_services/notebooks/apispec.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2024-08-06T05:55:38+00:00 +# timestamp: 2024-08-13T13:29:51+00:00 from __future__ import annotations diff --git a/components/renku_data_services/platform/apispec.py b/components/renku_data_services/platform/apispec.py index fbc6501bb..bb5c5275e 100644 --- a/components/renku_data_services/platform/apispec.py +++ b/components/renku_data_services/platform/apispec.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2024-08-06T05:55:39+00:00 +# timestamp: 2024-08-13T13:29:52+00:00 from __future__ import annotations diff --git a/components/renku_data_services/project/api.spec.yaml b/components/renku_data_services/project/api.spec.yaml index 6f97dcd78..08348b74e 100644 --- a/components/renku_data_services/project/api.spec.yaml +++ b/components/renku_data_services/project/api.spec.yaml @@ -360,7 +360,7 @@ components: type: string minLength: 26 maxLength: 26 - pattern: "^[A-Z0-9]{26}$" # This is case-insensitive + pattern: "^[0-7][0-9A-HJKMNP-TV-Z]{25}$" # This is case-insensitive ProjectName: description: Renku project name type: string diff --git a/components/renku_data_services/project/apispec.py b/components/renku_data_services/project/apispec.py index 7e01ab2b1..9c38c906c 100644 --- a/components/renku_data_services/project/apispec.py +++ b/components/renku_data_services/project/apispec.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2024-08-06T05:55:31+00:00 +# timestamp: 2024-08-13T13:29:47+00:00 from __future__ import annotations @@ -112,7 +112,7 @@ class Project(BaseAPISpec): description="ULID identifier", max_length=26, min_length=26, - pattern="^[A-Z0-9]{26}$", + pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$", ) name: str = Field( ..., diff --git a/components/renku_data_services/repositories/api.spec.yaml b/components/renku_data_services/repositories/api.spec.yaml index 5af8857a4..e02c0960e 100644 --- a/components/renku_data_services/repositories/api.spec.yaml +++ b/components/renku_data_services/repositories/api.spec.yaml @@ -103,7 +103,7 @@ components: type: string minLength: 26 maxLength: 26 - pattern: "^[A-Z0-9]{26}$" # This is case-insensitive + pattern: "^[0-7][0-9A-HJKMNP-TV-Z]{25}$" # This is case-insensitive ProviderId: description: ID of a OAuth2 provider, e.g. "gitlab.com". type: string diff --git a/components/renku_data_services/repositories/apispec.py b/components/renku_data_services/repositories/apispec.py index 13dccc8bf..6ec16da6e 100644 --- a/components/renku_data_services/repositories/apispec.py +++ b/components/renku_data_services/repositories/apispec.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2024-08-06T05:55:37+00:00 +# timestamp: 2024-08-13T13:29:50+00:00 from __future__ import annotations @@ -61,6 +61,6 @@ class RepositoryProviderMatch(BaseAPISpec): description="ULID identifier", max_length=26, min_length=26, - pattern="^[A-Z0-9]{26}$", + pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$", ) repository_metadata: Optional[RepositoryMetadata] = None diff --git a/components/renku_data_services/secrets/api.spec.yaml b/components/renku_data_services/secrets/api.spec.yaml index 44a435f0e..df85bf7ce 100644 --- a/components/renku_data_services/secrets/api.spec.yaml +++ b/components/renku_data_services/secrets/api.spec.yaml @@ -91,7 +91,7 @@ components: type: string minLength: 26 maxLength: 26 - pattern: "^[A-Z0-9]{26}$" # This is case-insensitive + pattern: "^[0-7][0-9A-HJKMNP-TV-Z]{25}$" # This is case-insensitive KeyMapping: description: A mapping between secret_ids and names where names will be used as key values in the created K8s secret. type: object diff --git a/components/renku_data_services/secrets/apispec.py b/components/renku_data_services/secrets/apispec.py index 89beb3611..5cc27a959 100644 --- a/components/renku_data_services/secrets/apispec.py +++ b/components/renku_data_services/secrets/apispec.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2024-08-06T05:55:34+00:00 +# timestamp: 2024-08-13T13:29:49+00:00 from __future__ import annotations @@ -20,7 +20,7 @@ class Ulid(RootModel[str]): description="ULID identifier", max_length=26, min_length=26, - pattern="^[A-Z0-9]{26}$", + pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$", ) diff --git a/components/renku_data_services/session/api.spec.yaml b/components/renku_data_services/session/api.spec.yaml index fcb908e17..529499035 100644 --- a/components/renku_data_services/session/api.spec.yaml +++ b/components/renku_data_services/session/api.spec.yaml @@ -430,7 +430,7 @@ components: type: string minLength: 26 maxLength: 26 - pattern: "^[A-Z0-9]{26}$" # This is case-insensitive + pattern: "^[0-7][0-9A-HJKMNP-TV-Z]{25}$" # This is case-insensitive SessionName: description: Renku session name type: string diff --git a/components/renku_data_services/storage/apispec.py b/components/renku_data_services/storage/apispec.py index fd934faea..d72e04dc9 100644 --- a/components/renku_data_services/storage/apispec.py +++ b/components/renku_data_services/storage/apispec.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2024-08-07T20:26:52+00:00 +# timestamp: 2024-08-13T13:29:46+00:00 from __future__ import annotations diff --git a/components/renku_data_services/users/api.spec.yaml b/components/renku_data_services/users/api.spec.yaml index 4c4dd4af5..65999fe80 100644 --- a/components/renku_data_services/users/api.spec.yaml +++ b/components/renku_data_services/users/api.spec.yaml @@ -430,7 +430,7 @@ components: type: string minLength: 26 maxLength: 26 - pattern: "^[A-Z0-9]{26}$" # This is case-insensitive + pattern: "^[0-7][0-9A-HJKMNP-TV-Z]{25}$" # This is case-insensitive ModificationDate: description: The date and time the secret was created or modified (this is always in UTC) type: string diff --git a/components/renku_data_services/users/apispec.py b/components/renku_data_services/users/apispec.py index 9c1697bc1..8734b6fd3 100644 --- a/components/renku_data_services/users/apispec.py +++ b/components/renku_data_services/users/apispec.py @@ -1,6 +1,6 @@ # generated by datamodel-codegen: # filename: api.spec.yaml -# timestamp: 2024-08-06T11:43:01+00:00 +# timestamp: 2024-08-13T13:29:47+00:00 from __future__ import annotations @@ -149,7 +149,7 @@ class SecretWithId(BaseAPISpec): description="ULID identifier", max_length=26, min_length=26, - pattern="^[A-Z0-9]{26}$", + pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$", ) name: str = Field( ..., From 1505e3d0f6f02076570b0961e82c23dcb036bc83 Mon Sep 17 00:00:00 2001 From: Ralf Grubenmann Date: Fri, 9 Aug 2024 15:56:36 +0200 Subject: [PATCH 4/5] refactor: add sqlalchemy ulid type to all relevant models --- .../background_jobs/core.py | 7 +- components/renku_data_services/authz/authz.py | 100 ++++++++++-------- .../renku_data_services/authz/models.py | 4 +- .../connected_services/apispec_base.py | 9 +- .../connected_services/blueprints.py | 13 +-- .../connected_services/db.py | 11 +- .../connected_services/models.py | 4 +- .../connected_services/orm.py | 3 +- .../message_queue/converters.py | 32 +++--- .../namespace/apispec_base.py | 9 +- .../renku_data_services/namespace/models.py | 8 +- .../renku_data_services/namespace/orm.py | 11 +- .../renku_data_services/project/blueprints.py | 10 +- components/renku_data_services/project/db.py | 16 +-- .../renku_data_services/project/models.py | 4 +- components/renku_data_services/project/orm.py | 6 +- .../repositories/apispec_base.py | 9 +- .../renku_data_services/repositories/db.py | 3 +- .../repositories/models.py | 4 +- .../renku_data_services/secrets/core.py | 3 +- .../renku_data_services/secrets/models.py | 3 +- components/renku_data_services/secrets/orm.py | 3 +- .../session/apispec_base.py | 9 +- components/renku_data_services/session/db.py | 14 +-- .../renku_data_services/session/models.py | 3 +- components/renku_data_services/session/orm.py | 3 +- .../storage/apispec_base.py | 15 ++- .../renku_data_services/storage/blueprints.py | 4 +- components/renku_data_services/storage/db.py | 8 +- .../renku_data_services/storage/models.py | 14 ++- components/renku_data_services/storage/orm.py | 9 +- .../authz/test_authorization.py | 41 +++---- 32 files changed, 245 insertions(+), 147 deletions(-) diff --git a/bases/renku_data_services/background_jobs/core.py b/bases/renku_data_services/background_jobs/core.py index 28614358b..289fbbbee 100644 --- a/bases/renku_data_services/background_jobs/core.py +++ b/bases/renku_data_services/background_jobs/core.py @@ -11,6 +11,7 @@ SubjectFilter, WriteRelationshipsRequest, ) +from ulid import ULID from renku_data_services.authz.authz import Authz, ResourceType, _AuthzConverter, _Relation from renku_data_services.authz.models import Scope @@ -133,7 +134,7 @@ async def fix_mismatched_project_namespace_ids(config: SyncConfig) -> None: relation=rel.relationship.relation, subject=SubjectReference( object=ObjectReference( - object_type=ResourceType.group.value, object_id=correct_group_id + object_type=ResourceType.group.value, object_id=str(correct_group_id) ) ), ), @@ -185,7 +186,7 @@ async def migrate_groups_make_all_public(config: SyncConfig) -> None: all_users = SubjectReference(object=_AuthzConverter.all_users()) all_anon_users = SubjectReference(object=_AuthzConverter.anonymous_users()) for group_id in groups_to_process: - group_res = _AuthzConverter.group(group_id) + group_res = _AuthzConverter.group(ULID.from_str(group_id)) all_users_are_viewers = Relationship( resource=group_res, relation=_Relation.public_viewer.value, @@ -244,7 +245,7 @@ async def migrate_user_namespaces_make_all_public(config: SyncConfig) -> None: all_users = SubjectReference(object=_AuthzConverter.all_users()) all_anon_users = SubjectReference(object=_AuthzConverter.anonymous_users()) for ns_id in namespaces_to_process: - namespace_res = _AuthzConverter.user_namespace(ns_id) + namespace_res = _AuthzConverter.user_namespace(ULID.from_str(ns_id)) all_users_are_viewers = Relationship( resource=namespace_res, relation=_Relation.public_viewer.value, diff --git a/components/renku_data_services/authz/authz.py b/components/renku_data_services/authz/authz.py index 98f719a3e..540dc4c62 100644 --- a/components/renku_data_services/authz/authz.py +++ b/components/renku_data_services/authz/authz.py @@ -5,7 +5,7 @@ from dataclasses import dataclass, field from enum import StrEnum from functools import wraps -from typing import ClassVar, Concatenate, ParamSpec, Protocol, TypeVar +from typing import ClassVar, Concatenate, ParamSpec, Protocol, TypeVar, cast from authzed.api.v1 import AsyncClient from authzed.api.v1.core_pb2 import ObjectReference, Relationship, RelationshipUpdate, SubjectReference, ZedToken @@ -26,6 +26,7 @@ ) from sanic.log import logger from sqlalchemy.ext.asyncio import AsyncSession +from ulid import ULID from renku_data_services import base_models from renku_data_services.authz.config import AuthzConfig @@ -132,8 +133,8 @@ class AuthzOperation(StrEnum): class _AuthzConverter: @staticmethod - def project(id: str) -> ObjectReference: - return ObjectReference(object_type=ResourceType.project.value, object_id=id) + def project(id: ULID) -> ObjectReference: + return ObjectReference(object_type=ResourceType.project.value, object_id=str(id)) @staticmethod def user(id: str | None) -> ObjectReference: @@ -162,25 +163,25 @@ def all_users() -> ObjectReference: return ObjectReference(object_type=ResourceType.user, object_id="*") @staticmethod - def group(id: str) -> ObjectReference: - return ObjectReference(object_type=ResourceType.group, object_id=id) + def group(id: ULID) -> ObjectReference: + return ObjectReference(object_type=ResourceType.group, object_id=str(id)) @staticmethod - def user_namespace(id: str) -> ObjectReference: - return ObjectReference(object_type=ResourceType.user_namespace, object_id=id) + def user_namespace(id: ULID) -> ObjectReference: + return ObjectReference(object_type=ResourceType.user_namespace, object_id=str(id)) @staticmethod - def to_object(resource_type: ResourceType, resource_id: str | int) -> ObjectReference: + def to_object(resource_type: ResourceType, resource_id: str | ULID | int) -> ObjectReference: match (resource_type, resource_id): - case (ResourceType.project, sid) if isinstance(sid, str): + case (ResourceType.project, sid) if isinstance(sid, ULID): return _AuthzConverter.project(sid) case (ResourceType.user, sid) if isinstance(sid, str) or sid is None: return _AuthzConverter.user(sid) case (ResourceType.anonymous_user, _): return _AuthzConverter.anonymous_users() - case (ResourceType.user_namespace, rid) if isinstance(rid, str): + case (ResourceType.user_namespace, rid) if isinstance(rid, ULID): return _AuthzConverter.user_namespace(rid) - case (ResourceType.group, rid) if isinstance(rid, str): + case (ResourceType.group, rid) if isinstance(rid, ULID): return _AuthzConverter.group(rid) raise errors.ProgrammingError( message=f"Unexpected or unknown resource type when checking permissions {resource_type}" @@ -242,23 +243,26 @@ async def decorated_function( return decorator +_ID = TypeVar("_ID", str, ULID) + + def _is_allowed( operation: Scope, ) -> Callable[ - [Callable[Concatenate["Authz", base_models.APIUser, ResourceType, str, _P], Awaitable[_T]]], - Callable[Concatenate["Authz", base_models.APIUser, ResourceType, str, _P], Awaitable[_T]], + [Callable[Concatenate["Authz", base_models.APIUser, ResourceType, _ID, _P], Awaitable[_T]]], + Callable[Concatenate["Authz", base_models.APIUser, ResourceType, _ID, _P], Awaitable[_T]], ]: """A decorator that checks if the operation on a resource is allowed or not.""" def decorator( - f: Callable[Concatenate["Authz", base_models.APIUser, ResourceType, str, _P], Awaitable[_T]], - ) -> Callable[Concatenate["Authz", base_models.APIUser, ResourceType, str, _P], Awaitable[_T]]: + f: Callable[Concatenate["Authz", base_models.APIUser, ResourceType, _ID, _P], Awaitable[_T]], + ) -> Callable[Concatenate["Authz", base_models.APIUser, ResourceType, _ID, _P], Awaitable[_T]]: @wraps(f) async def decorated_function( self: "Authz", user: base_models.APIUser, resource_type: ResourceType, - resource_id: str, + resource_id: _ID, *args: _P.args, **kwargs: _P.kwargs, ) -> _T: @@ -294,7 +298,7 @@ def client(self) -> AsyncClient: return self._client async def _has_permission( - self, user: base_models.APIUser, resource_type: ResourceType, resource_id: str | None, scope: Scope + self, user: base_models.APIUser, resource_type: ResourceType, resource_id: str | ULID | None, scope: Scope ) -> tuple[bool, ZedToken | None]: """Checks whether the provided user has a specific permission on the specific resource.""" if not resource_id: @@ -317,7 +321,7 @@ async def _has_permission( return response.permissionship == CheckPermissionResponse.PERMISSIONSHIP_HAS_PERMISSION, response.checked_at async def has_permission( - self, user: base_models.APIUser, resource_type: ResourceType, resource_id: str, scope: Scope + self, user: base_models.APIUser, resource_type: ResourceType, resource_id: str | ULID, scope: Scope ) -> bool: """Checks whether the provided user has a specific permission on the specific resource.""" res, _ = await self._has_permission(user, resource_type, resource_id, scope) @@ -386,24 +390,25 @@ async def members( self, user: base_models.APIUser, resource_type: ResourceType, - resource_id: str, + resource_id: ULID, role: Role | None = None, *, zed_token: ZedToken | None = None, ) -> list[Member]: """Get all users that are members of a resource, if role is None then all roles are retrieved.""" + resource_id_str = str(resource_id) consistency = Consistency(at_least_as_fresh=zed_token) if zed_token else Consistency(fully_consistent=True) sub_filter = SubjectFilter(subject_type=ResourceType.user.value) rel_filter = RelationshipFilter( resource_type=resource_type, - optional_resource_id=resource_id, + optional_resource_id=resource_id_str, optional_subject_filter=sub_filter, ) if role: relation = _Relation.from_role(role) rel_filter = RelationshipFilter( resource_type=resource_type, - optional_resource_id=resource_id, + optional_resource_id=resource_id_str, optional_relation=relation, optional_subject_filter=sub_filter, ) @@ -500,7 +505,7 @@ async def _get_authz_change( ) authz_change.extend(db_repo.authz._add_user_namespace(res.namespace)) case _: - resource_id: str | None = "unknown" + resource_id: str | ULID | None = "unknown" if isinstance(result, (Project, Namespace, Group)): resource_id = result.id elif isinstance(result, (ProjectUpdate, NamespaceUpdate, GroupUpdate)): @@ -575,7 +580,7 @@ def _add_project(self, project: Project) -> _AuthzChange: object=( _AuthzConverter.user_namespace(project.namespace.id) if project.namespace.kind == NamespaceKind.user - else _AuthzConverter.group(project.namespace.underlying_resource_id) + else _AuthzConverter.group(cast(ULID, project.namespace.underlying_resource_id)) ) ) project_in_platform = Relationship( @@ -619,7 +624,7 @@ async def _remove_project( ) -> _AuthzChange: """Remove the relationships associated with the project.""" consistency = Consistency(at_least_as_fresh=zed_token) if zed_token else Consistency(fully_consistent=True) - rel_filter = RelationshipFilter(resource_type=ResourceType.project.value, optional_resource_id=project.id) + rel_filter = RelationshipFilter(resource_type=ResourceType.project.value, optional_resource_id=str(project.id)) responses: AsyncIterable[ReadRelationshipsResponse] = self.client.ReadRelationships( ReadRelationshipsRequest(consistency=consistency, relationship_filter=rel_filter) ) @@ -640,6 +645,7 @@ async def _update_project_visibility( self, user: base_models.APIUser, project: Project, *, zed_token: ZedToken | None = None ) -> _AuthzChange: """Update the visibility of the project in the authorization database.""" + project_id_str = str(project.id) consistency = Consistency(at_least_as_fresh=zed_token) if zed_token else Consistency(fully_consistent=True) project_res = _AuthzConverter.project(project.id) all_users_sub = SubjectReference(object=_AuthzConverter.all_users()) @@ -668,7 +674,7 @@ async def _update_project_visibility( ) rel_filter = RelationshipFilter( resource_type=ResourceType.project.value, - optional_resource_id=project.id, + optional_resource_id=project_id_str, optional_subject_filter=SubjectFilter( subject_type=ResourceType.user.value, optional_subject_id=all_users_sub.object.object_id ), @@ -678,7 +684,7 @@ async def _update_project_visibility( ) rel_filter = RelationshipFilter( resource_type=ResourceType.project.value, - optional_resource_id=project.id, + optional_resource_id=project_id_str, optional_subject_filter=SubjectFilter( subject_type=ResourceType.anonymous_user.value, optional_subject_id=anon_users_sub.object.object_id, @@ -728,7 +734,7 @@ async def _update_project_namespace( project_res = _AuthzConverter.project(project.id) project_namespace_filter = RelationshipFilter( resource_type=ResourceType.project.value, - optional_resource_id=project.id, + optional_resource_id=str(project.id), optional_relation=_Relation.project_namespace.value, ) current_namespace: ReadRelationshipsResponse | None = await anext( @@ -752,10 +758,14 @@ async def _update_project_namespace( else SubjectReference(object=_AuthzConverter.user_namespace(project.namespace.id)) ) old_namespace_sub = ( - SubjectReference(object=_AuthzConverter.group(current_namespace.relationship.subject.object.object_id)) + SubjectReference( + object=_AuthzConverter.group(ULID.from_str(current_namespace.relationship.subject.object.object_id)) + ) if current_namespace.relationship.subject.object.object_type == ResourceType.group.value else SubjectReference( - object=_AuthzConverter.user_namespace(current_namespace.relationship.subject.object.object_id) + object=_AuthzConverter.user_namespace( + ULID.from_str(current_namespace.relationship.subject.object.object_id) + ) ) ) new_namespace = Relationship( @@ -804,7 +814,7 @@ async def upsert_project_members( self, user: base_models.APIUser, resource_type: ResourceType, - resource_id: str, + resource_id: ULID, members: list[Member], *, zed_token: ZedToken | None = None, @@ -813,13 +823,14 @@ async def upsert_project_members( Returns the list that was updated/inserted. """ + resource_id_str = str(resource_id) consistency = Consistency(at_least_as_fresh=zed_token) if zed_token else Consistency(fully_consistent=True) project_res = _AuthzConverter.project(resource_id) add_members: list[RelationshipUpdate] = [] undo: list[RelationshipUpdate] = [] output: list[MembershipChange] = [] expected_user_roles = {_Relation.viewer.value, _Relation.owner.value, _Relation.editor.value} - existing_owners_rels = await self._get_resource_owners(resource_type, resource_id, consistency) + existing_owners_rels = await self._get_resource_owners(resource_type, resource_id_str, consistency) n_existing_owners = len(existing_owners_rels) for member in members: rel = Relationship( @@ -829,7 +840,7 @@ async def upsert_project_members( ) existing_rel_filter = RelationshipFilter( resource_type=resource_type.value, - optional_resource_id=resource_id, + optional_resource_id=resource_id_str, optional_subject_filter=SubjectFilter( subject_type=ResourceType.user, optional_subject_id=member.user_id ), @@ -922,24 +933,25 @@ async def remove_project_members( self, user: base_models.APIUser, resource_type: ResourceType, - resource_id: str, + resource_id: ULID, user_ids: list[str], *, zed_token: ZedToken | None = None, ) -> list[MembershipChange]: """Remove the specific members from the project, then return the list of members that were removed.""" + resource_id_str = str(resource_id) consistency = Consistency(at_least_as_fresh=zed_token) if zed_token else Consistency(fully_consistent=True) add_members: list[RelationshipUpdate] = [] remove_members: list[RelationshipUpdate] = [] output: list[MembershipChange] = [] - existing_owners_rels = await self._get_resource_owners(resource_type, resource_id, consistency) + existing_owners_rels = await self._get_resource_owners(resource_type, resource_id_str, consistency) existing_owners: set[str] = {rel.relationship.subject.object.object_id for rel in existing_owners_rels} for user_id in user_ids: if user_id == "*": raise errors.ValidationError(message="Cannot remove a project member with ID '*'") existing_rel_filter = RelationshipFilter( resource_type=resource_type.value, - optional_resource_id=resource_id, + optional_resource_id=resource_id_str, optional_subject_filter=SubjectFilter(subject_type=ResourceType.user, optional_subject_id=user_id), ) existing_rels: AsyncIterable[ReadRelationshipsResponse] = self.client.ReadRelationships( @@ -1084,7 +1096,7 @@ async def _remove_group( message="Cannot remove a group in the authorization database if the group has no ID" ) consistency = Consistency(at_least_as_fresh=zed_token) if zed_token else Consistency(fully_consistent=True) - rel_filter = RelationshipFilter(resource_type=ResourceType.group.value, optional_resource_id=group.id) + rel_filter = RelationshipFilter(resource_type=ResourceType.group.value, optional_resource_id=str(group.id)) responses = self.client.ReadRelationships( ReadRelationshipsRequest(consistency=consistency, relationship_filter=rel_filter) ) @@ -1104,7 +1116,7 @@ async def upsert_group_members( self, user: base_models.APIUser, resource_type: ResourceType, - resource_id: str, + resource_id: ULID, members: list[Member], *, zed_token: ZedToken | None = None, @@ -1115,8 +1127,9 @@ async def upsert_group_members( add_members: list[RelationshipUpdate] = [] undo: list[RelationshipUpdate] = [] output: list[MembershipChange] = [] + resource_id_str = str(resource_id) expected_user_roles = {_Relation.viewer.value, _Relation.owner.value, _Relation.editor.value} - existing_owners_rels = await self._get_resource_owners(resource_type, resource_id, consistency) + existing_owners_rels = await self._get_resource_owners(resource_type, resource_id_str, consistency) n_existing_owners = len(existing_owners_rels) for member in members: rel = Relationship( @@ -1126,7 +1139,7 @@ async def upsert_group_members( ) existing_rel_filter = RelationshipFilter( resource_type=resource_type.value, - optional_resource_id=resource_id, + optional_resource_id=resource_id_str, optional_subject_filter=SubjectFilter( subject_type=ResourceType.user, optional_subject_id=member.user_id ), @@ -1222,7 +1235,7 @@ async def remove_group_members( self, user: base_models.APIUser, resource_type: ResourceType, - resource_id: str, + resource_id: ULID, user_ids: list[str], *, zed_token: ZedToken | None = None, @@ -1233,12 +1246,13 @@ async def remove_group_members( remove_members: list[RelationshipUpdate] = [] output: list[MembershipChange] = [] existing_owners_rels: list[ReadRelationshipsResponse] | None = None + resource_id_str = str(resource_id) for user_id in user_ids: if user_id == "*": raise errors.ValidationError(message="Cannot remove a group member with ID '*'") existing_rel_filter = RelationshipFilter( resource_type=resource_type.value, - optional_resource_id=resource_id, + optional_resource_id=resource_id_str, optional_subject_filter=SubjectFilter(subject_type=ResourceType.user, optional_subject_id=user_id), ) existing_rels: AsyncIterable[ReadRelationshipsResponse] = self.client.ReadRelationships( @@ -1249,7 +1263,9 @@ async def remove_group_members( async for existing_rel in existing_rels: if existing_rel.relationship.relation == _Relation.owner.value: if existing_owners_rels is None: - existing_owners_rels = await self._get_resource_owners(resource_type, resource_id, consistency) + existing_owners_rels = await self._get_resource_owners( + resource_type, resource_id_str, consistency + ) if len(existing_owners_rels) == 1: raise errors.ValidationError( message="You are trying to remove the single last owner of the group, " diff --git a/components/renku_data_services/authz/models.py b/components/renku_data_services/authz/models.py index 392db775d..f5055e06f 100644 --- a/components/renku_data_services/authz/models.py +++ b/components/renku_data_services/authz/models.py @@ -3,6 +3,8 @@ from dataclasses import dataclass from enum import Enum +from ulid import ULID + from renku_data_services.errors import errors from renku_data_services.namespace.apispec import GroupRole @@ -56,7 +58,7 @@ class Member: role: Role user_id: str - resource_id: str + resource_id: str | ULID class Change(Enum): diff --git a/components/renku_data_services/connected_services/apispec_base.py b/components/renku_data_services/connected_services/apispec_base.py index 827d18484..77b1a89d8 100644 --- a/components/renku_data_services/connected_services/apispec_base.py +++ b/components/renku_data_services/connected_services/apispec_base.py @@ -1,6 +1,7 @@ """Base models for API specifications.""" -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator +from ulid import ULID class BaseAPISpec(BaseModel): @@ -14,6 +15,12 @@ class Config: # this rust crate does not support lookahead regex syntax but we need it in this component regex_engine = "python-re" + @field_validator("id", mode="before", check_fields=False) + @classmethod + def serialize_id(cls, id: str | ULID) -> str: + """Custom serializer that can handle ULIDs.""" + return str(id) + class AuthorizeParams(BaseAPISpec): """The schema for the query parameters used in the authorize request.""" diff --git a/components/renku_data_services/connected_services/blueprints.py b/components/renku_data_services/connected_services/blueprints.py index 28cfd6a83..961abed2f 100644 --- a/components/renku_data_services/connected_services/blueprints.py +++ b/components/renku_data_services/connected_services/blueprints.py @@ -7,6 +7,7 @@ from sanic.log import logger from sanic.response import JSONResponse from sanic_ext import validate +from ulid import ULID import renku_data_services.base_models as base_models from renku_data_services.base_api.auth import authenticate, only_admins, only_authenticated @@ -150,34 +151,34 @@ def get_one(self) -> BlueprintFactoryResponse: """Get a specific OAuth2 connection.""" @authenticate(self.authenticator) - async def _get_one(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse: + async def _get_one(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse: connection = await self.connected_services_repo.get_oauth2_connection( connection_id=connection_id, user=user ) return validated_json(apispec.Connection, connection) - return "/oauth2/connections/", ["GET"], _get_one + return "/oauth2/connections/", ["GET"], _get_one def get_account(self) -> BlueprintFactoryResponse: """Get the account information for a specific OAuth2 connection.""" @authenticate(self.authenticator) - async def _get_account(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse: + async def _get_account(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse: account = await self.connected_services_repo.get_oauth2_connected_account( connection_id=connection_id, user=user ) return validated_json(apispec.ConnectedAccount, account) - return "/oauth2/connections//account", ["GET"], _get_account + return "/oauth2/connections//account", ["GET"], _get_account def get_token(self) -> BlueprintFactoryResponse: """Get the access token for a specific OAuth2 connection.""" @authenticate(self.authenticator) - async def _get_token(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse: + async def _get_token(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse: token = await self.connected_services_repo.get_oauth2_connection_token( connection_id=connection_id, user=user ) return json(token.dump_for_api()) - return "/oauth2/connections//token", ["GET"], _get_token + return "/oauth2/connections//token", ["GET"], _get_token diff --git a/components/renku_data_services/connected_services/db.py b/components/renku_data_services/connected_services/db.py index dfb4ebcce..46186c7de 100644 --- a/components/renku_data_services/connected_services/db.py +++ b/components/renku_data_services/connected_services/db.py @@ -11,6 +11,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload +from ulid import ULID import renku_data_services.base_models as base_models from renku_data_services import errors @@ -282,7 +283,7 @@ async def get_oauth2_connections( connections = result.all() return [c.dump() for c in connections] - async def get_oauth2_connection(self, connection_id: str, user: base_models.APIUser) -> models.OAuth2Connection: + async def get_oauth2_connection(self, connection_id: ULID, user: base_models.APIUser) -> models.OAuth2Connection: """Get one OAuth2 connection from the database.""" if not user.is_authenticated or user.id is None: raise errors.MissingResourceError( @@ -303,7 +304,7 @@ async def get_oauth2_connection(self, connection_id: str, user: base_models.APIU return connection.dump() async def get_oauth2_connected_account( - self, connection_id: str, user: base_models.APIUser + self, connection_id: ULID, user: base_models.APIUser ) -> models.ConnectedAccount: """Get the account information from a OAuth2 connection.""" async with self.get_async_oauth2_client(connection_id=connection_id, user=user) as (oauth2_client, _, adapter): @@ -316,7 +317,9 @@ async def get_oauth2_connected_account( account = adapter.api_validate_account_response(response) return account - async def get_oauth2_connection_token(self, connection_id: str, user: base_models.APIUser) -> models.OAuth2TokenSet: + async def get_oauth2_connection_token( + self, connection_id: ULID, user: base_models.APIUser + ) -> models.OAuth2TokenSet: """Get the OAuth2 access token from one connection from the database.""" async with self.get_async_oauth2_client(connection_id=connection_id, user=user) as (oauth2_client, _, _): await oauth2_client.ensure_active_token(oauth2_client.token) @@ -325,7 +328,7 @@ async def get_oauth2_connection_token(self, connection_id: str, user: base_model @asynccontextmanager async def get_async_oauth2_client( - self, connection_id: str, user: base_models.APIUser + self, connection_id: ULID, user: base_models.APIUser ) -> AsyncGenerator[tuple[AsyncOAuth2Client, schemas.OAuth2ConnectionORM, ProviderAdapter], None]: """Get the AsyncOAuth2Client for the given connection_id and user.""" if not user.is_authenticated or user.id is None: diff --git a/components/renku_data_services/connected_services/models.py b/components/renku_data_services/connected_services/models.py index b018abcc0..10f67cf0c 100644 --- a/components/renku_data_services/connected_services/models.py +++ b/components/renku_data_services/connected_services/models.py @@ -4,6 +4,8 @@ from datetime import UTC, datetime from typing import Any +from ulid import ULID + from renku_data_services.connected_services.apispec import ConnectionStatus, ProviderKind @@ -28,7 +30,7 @@ class OAuth2Client: class OAuth2Connection: """OAuth2 connection model.""" - id: str + id: ULID provider_id: str status: ConnectionStatus diff --git a/components/renku_data_services/connected_services/orm.py b/components/renku_data_services/connected_services/orm.py index 1b998ac69..6a2762ac5 100644 --- a/components/renku_data_services/connected_services/orm.py +++ b/components/renku_data_services/connected_services/orm.py @@ -11,6 +11,7 @@ from renku_data_services.connected_services import models from renku_data_services.connected_services.apispec import ConnectionStatus, ProviderKind +from renku_data_services.utils.sqlalchemy import ULIDType JSONVariant = JSON().with_variant(JSONB(), "postgresql") @@ -72,7 +73,7 @@ class OAuth2ConnectionORM(BaseORM): """An OAuth2 connection.""" __tablename__ = "oauth2_connections" - id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False) + id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False) user_id: Mapped[str] = mapped_column("user_id", String()) client_id: Mapped[str] = mapped_column(ForeignKey(OAuth2ClientORM.id, ondelete="CASCADE"), index=True) client: Mapped[OAuth2ClientORM] = relationship(init=False, repr=False) diff --git a/components/renku_data_services/message_queue/converters.py b/components/renku_data_services/message_queue/converters.py index c0dcfb608..29c0e0587 100644 --- a/components/renku_data_services/message_queue/converters.py +++ b/components/renku_data_services/message_queue/converters.py @@ -35,13 +35,14 @@ def to_events( raise errors.EventError( message=f"Cannot create an event of type {event_type} for a project which has no ID" ) + project_id_str = str(project.id) match event_type: case v2.ProjectCreated: return [ Event( "project.created", v2.ProjectCreated( - id=project.id, + id=project_id_str, name=project.name, namespace=project.namespace.slug, slug=project.slug, @@ -56,7 +57,7 @@ def to_events( Event( "projectAuth.added", v2.ProjectMemberAdded( - projectId=project.id, + projectId=project_id_str, userId=project.created_by, role=v2.MemberRole.OWNER, ), @@ -67,7 +68,7 @@ def to_events( Event( "project.updated", v2.ProjectUpdated( - id=project.id, + id=project_id_str, name=project.name, namespace=project.namespace.slug, slug=project.slug, @@ -79,7 +80,7 @@ def to_events( ) ] case v2.ProjectRemoved: - return [Event("project.removed", v2.ProjectRemoved(id=project.id))] + return [Event("project.removed", v2.ProjectRemoved(id=project_id_str))] case _: raise errors.EventError(message=f"Trying to convert a project to an unknown event type {event_type}") @@ -145,13 +146,14 @@ class _ProjectAuthzEventConverter: def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event]: output: list[Event] = [] for change in member_changes: + resource_id = str(change.member.resource_id) match change.change: case authz_models.Change.UPDATE: output.append( Event( "projectAuth.updated", v2.ProjectMemberUpdated( - projectId=change.member.resource_id, + projectId=resource_id, userId=change.member.user_id, role=_convert_member_role(change.member.role), ), @@ -162,7 +164,7 @@ def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event Event( "projectAuth.removed", v2.ProjectMemberRemoved( - projectId=change.member.resource_id, + projectId=resource_id, userId=change.member.user_id, ), ) @@ -172,7 +174,7 @@ def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event Event( "projectAuth.added", v2.ProjectMemberAdded( - projectId=change.member.resource_id, + projectId=resource_id, userId=change.member.user_id, role=_convert_member_role(change.member.role), ), @@ -191,13 +193,14 @@ class _GroupAuthzEventConverter: def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event]: output: list[Event] = [] for change in member_changes: + resource_id = str(change.member.resource_id) match change.change: case authz_models.Change.UPDATE: output.append( Event( "memberGroup.updated", v2.ProjectMemberUpdated( - projectId=change.member.resource_id, + projectId=resource_id, userId=change.member.user_id, role=_convert_member_role(change.member.role), ), @@ -208,7 +211,7 @@ def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event Event( "memberGroup.removed", v2.ProjectMemberRemoved( - projectId=change.member.resource_id, + projectId=resource_id, userId=change.member.user_id, ), ) @@ -218,7 +221,7 @@ def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event Event( "memberGroup.added", v2.ProjectMemberAdded( - projectId=change.member.resource_id, + projectId=resource_id, userId=change.member.user_id, role=_convert_member_role(change.member.role), ), @@ -239,32 +242,33 @@ def to_events(group: group_models.Group, event_type: type[AvroModel] | type[even raise errors.ProgrammingError( message="Cannot send group events to the message queue for a group that does not have an ID" ) + group_id = str(group.id) match event_type: case v2.GroupAdded: return [ Event( "group.added", v2.GroupAdded( - id=group.id, name=group.name, description=group.description, namespace=group.slug + id=group_id, name=group.name, description=group.description, namespace=group.slug ), ), Event( "memberGroup.added", v2.GroupMemberAdded( - groupId=group.id, + groupId=group_id, userId=group.created_by, role=v2.MemberRole.OWNER, ), ), ] case v2.GroupRemoved: - return [Event("group.removed", v2.GroupRemoved(id=group.id))] + return [Event("group.removed", v2.GroupRemoved(id=group_id))] case v2.GroupUpdated: return [ Event( "group.updated", v2.GroupUpdated( - id=group.id, name=group.name, description=group.description, namespace=group.slug + id=group_id, name=group.name, description=group.description, namespace=group.slug ), ) ] diff --git a/components/renku_data_services/namespace/apispec_base.py b/components/renku_data_services/namespace/apispec_base.py index c888c3ba5..476c07927 100644 --- a/components/renku_data_services/namespace/apispec_base.py +++ b/components/renku_data_services/namespace/apispec_base.py @@ -1,6 +1,7 @@ """Base models for API specifications.""" -from pydantic import BaseModel +from pydantic import BaseModel, field_validator +from ulid import ULID class BaseAPISpec(BaseModel): @@ -13,3 +14,9 @@ class Config: # NOTE: By default the pydantic library does not use python for regex but a rust crate # this rust crate does not support lookahead regex syntax but we need it in this component regex_engine = "python-re" + + @field_validator("id", mode="before", check_fields=False) + @classmethod + def serialize_id(cls, id: str | ULID) -> str: + """Custom serializer that can handle ULIDs.""" + return str(id) diff --git a/components/renku_data_services/namespace/models.py b/components/renku_data_services/namespace/models.py index 9cb93f337..05a621fe0 100644 --- a/components/renku_data_services/namespace/models.py +++ b/components/renku_data_services/namespace/models.py @@ -4,6 +4,8 @@ from datetime import datetime from enum import Enum +from ulid import ULID + from renku_data_services.authz.models import Role @@ -16,7 +18,7 @@ class Group: created_by: str creation_date: datetime description: str | None = None - id: str | None = None + id: ULID | None = None @dataclass @@ -50,11 +52,11 @@ class NamespaceKind(str, Enum): class Namespace: """A renku namespace.""" - id: str + id: ULID slug: str kind: NamespaceKind created_by: str - underlying_resource_id: str # The user or group ID depending on the Namespace kind + underlying_resource_id: ULID | str # The user or group ID depending on the Namespace kind latest_slug: str | None = None name: str | None = None diff --git a/components/renku_data_services/namespace/orm.py b/components/renku_data_services/namespace/orm.py index 1a7951d6c..741ba5a66 100644 --- a/components/renku_data_services/namespace/orm.py +++ b/components/renku_data_services/namespace/orm.py @@ -13,6 +13,7 @@ from renku_data_services.namespace import models from renku_data_services.users.models import UserInfo, UserWithNamespace from renku_data_services.users.orm import UserORM +from renku_data_services.utils.sqlalchemy import ULIDType class BaseORM(MappedAsDataclass, DeclarativeBase): @@ -27,7 +28,7 @@ class GroupORM(BaseORM): __tablename__ = "groups" - id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False) + id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False) name: Mapped[str] = mapped_column("name", String(99), index=True) created_by: Mapped[str] = mapped_column(ForeignKey(UserORM.keycloak_id), index=True, nullable=False) creation_date: Mapped[datetime] = mapped_column("creation_date", DateTime(timezone=True), server_default=func.now()) @@ -54,9 +55,9 @@ class NamespaceORM(BaseORM): CheckConstraint("(user_id IS NULL) <> (group_id IS NULL)", name="either_group_id_or_user_id_is_set"), ) - id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False) + id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False) slug: Mapped[str] = mapped_column(String(99), index=True, unique=True, nullable=False) - group_id: Mapped[str | None] = mapped_column( + group_id: Mapped[ULID | None] = mapped_column( ForeignKey(GroupORM.id, ondelete="CASCADE", name="namespaces_group_id_fk"), default=None, nullable=True, @@ -119,10 +120,10 @@ class NamespaceOldORM(BaseORM): __tablename__ = "namespaces_old" - id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False) + id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False) slug: Mapped[str] = mapped_column(String(99), index=True, nullable=False) created_at: Mapped[datetime] = mapped_column(nullable=False, index=True, init=False, server_default=func.now()) - latest_slug_id: Mapped[str] = mapped_column( + latest_slug_id: Mapped[ULID] = mapped_column( ForeignKey(NamespaceORM.id, ondelete="CASCADE"), nullable=False, index=True ) latest_slug: Mapped[NamespaceORM] = relationship(lazy="joined", init=False, viewonly=True, repr=False) diff --git a/components/renku_data_services/project/blueprints.py b/components/renku_data_services/project/blueprints.py index 2837eacaa..bdc3a4416 100644 --- a/components/renku_data_services/project/blueprints.py +++ b/components/renku_data_services/project/blueprints.py @@ -50,7 +50,7 @@ async def _get_all( ) return [ dict( - id=p.id, + id=str(p.id), name=p.name, namespace=p.namespace.slug, slug=p.slug, @@ -88,7 +88,7 @@ async def _post(_: Request, user: base_models.APIUser, body: apispec.ProjectPost result = await self.project_repo.insert_project(user, project) return json( dict( - id=result.id, + id=str(result.id), name=result.name, namespace=result.namespace.slug, slug=result.slug, @@ -120,7 +120,7 @@ async def _get_one(request: Request, user: base_models.APIUser, project_id: str) headers = {"ETag": project.etag} if project.etag is not None else None return json( dict( - id=project.id, + id=str(project.id), name=project.name, namespace=project.namespace.slug, slug=project.slug, @@ -153,7 +153,7 @@ async def _get_one_by_namespace_slug( headers = {"ETag": project.etag} if project.etag is not None else None return json( dict( - id=project.id, + id=str(project.id), name=project.name, namespace=project.namespace.slug, slug=project.slug, @@ -207,7 +207,7 @@ async def _patch( updated_project = project_update.new return json( dict( - id=updated_project.id, + id=str(updated_project.id), name=updated_project.name, namespace=updated_project.namespace.slug, slug=updated_project.slug, diff --git a/components/renku_data_services/project/db.py b/components/renku_data_services/project/db.py index f5055aa7f..701473d11 100644 --- a/components/renku_data_services/project/db.py +++ b/components/renku_data_services/project/db.py @@ -76,7 +76,7 @@ async def get_projects( async def get_project(self, user: base_models.APIUser, project_id: ULID) -> models.Project: """Get one project from the database.""" - authorized = await self.authz.has_permission(user, ResourceType.project, str(project_id), Scope.READ) + authorized = await self.authz.has_permission(user, ResourceType.project, project_id, Scope.READ) if not authorized: raise errors.MissingResourceError( message=f"Project with id '{project_id}' does not exist or you do not have access to it." @@ -113,7 +113,7 @@ async def get_project_by_namespace_slug( authorized = await self.authz.has_permission( user=user, resource_type=ResourceType.project, - resource_id=str(project_orm.id), + resource_id=project_orm.id, scope=Scope.READ, ) if not authorized: @@ -171,7 +171,7 @@ async def insert_project( creation_date=datetime.now(UTC).replace(microsecond=0), keywords=project.keywords, ) - project_slug = schemas.ProjectSlug(slug, project_id=str(project_orm.id), namespace_id=ns.id) + project_slug = schemas.ProjectSlug(slug, project_id=project_orm.id, namespace_id=ns.id) session.add(project_slug) session.add(project_orm) @@ -212,7 +212,7 @@ async def update_project( if "namespace" in payload and payload["namespace"] != old_project.namespace: # NOTE: changing the namespace requires the user to be owner which means they should have DELETE permission required_scope = Scope.DELETE - authorized = await self.authz.has_permission(user, ResourceType.project, str(project_id), required_scope) + authorized = await self.authz.has_permission(user, ResourceType.project, project_id, required_scope) if not authorized: raise errors.MissingResourceError( message=f"Project with id '{project_id_str}' does not exist or you do not have access to it." @@ -273,7 +273,7 @@ async def delete_project( """Delete a project.""" if not session: raise errors.ProgrammingError(message="A database session is required") - authorized = await self.authz.has_permission(user, ResourceType.project, str(project_id), Scope.DELETE) + authorized = await self.authz.has_permission(user, ResourceType.project, project_id, Scope.DELETE) if not authorized: raise errors.MissingResourceError( message=f"Project with id '{project_id}' does not exist or you do not have access to it." @@ -359,7 +359,7 @@ async def get_members( self, user: base_models.APIUser, project_id: ULID, *, session: AsyncSession | None = None ) -> list[Member]: """Get all members of a project.""" - members = await self.authz.members(user, ResourceType.project, str(project_id)) + members = await self.authz.members(user, ResourceType.project, project_id) members = [member for member in members if member.user_id and member.user_id != "*"] return members @@ -391,7 +391,7 @@ async def update_members( f"{requested_member_ids_set.difference(existing_member_ids)} cannot be found" ) - output = await self.authz.upsert_project_members(user, ResourceType.project, str(project_id), members) + output = await self.authz.upsert_project_members(user, ResourceType.project, project_id, members) return output @with_db_transaction @@ -404,5 +404,5 @@ async def delete_members( if len(user_ids) == 0: raise errors.ValidationError(message="Please request at least 1 member to be removed from the project") - members = await self.authz.remove_project_members(user, ResourceType.project, str(project_id), user_ids) + members = await self.authz.remove_project_members(user, ResourceType.project, project_id, user_ids) return members diff --git a/components/renku_data_services/project/models.py b/components/renku_data_services/project/models.py index 430e55f73..13c1bb9c1 100644 --- a/components/renku_data_services/project/models.py +++ b/components/renku_data_services/project/models.py @@ -4,6 +4,8 @@ from datetime import UTC, datetime from typing import Optional +from ulid import ULID + from renku_data_services.authz.models import Visibility from renku_data_services.namespace.models import Namespace from renku_data_services.utils.etag import compute_etag_from_timestamp @@ -37,7 +39,7 @@ def etag(self) -> str | None: class Project(BaseProject): """Base Project model.""" - id: str + id: ULID namespace: Namespace diff --git a/components/renku_data_services/project/orm.py b/components/renku_data_services/project/orm.py index 5dcdd22b8..b50232fe3 100644 --- a/components/renku_data_services/project/orm.py +++ b/components/renku_data_services/project/orm.py @@ -54,7 +54,7 @@ class ProjectORM(BaseORM): def dump(self) -> models.Project: """Create a project model from the ProjectORM.""" return models.Project( - id=str(self.id), + id=self.id, name=self.name, slug=self.slug.slug, namespace=self.slug.namespace.dump(), @@ -91,10 +91,10 @@ class ProjectSlug(BaseORM): id: Mapped[int] = mapped_column(primary_key=True, init=False) slug: Mapped[str] = mapped_column(String(99), index=True, nullable=False) - project_id: Mapped[str] = mapped_column( + project_id: Mapped[ULID] = mapped_column( ForeignKey(ProjectORM.id, ondelete="CASCADE", name="project_slugs_project_id_fk"), index=True ) - namespace_id: Mapped[str] = mapped_column( + namespace_id: Mapped[ULID] = mapped_column( ForeignKey(NamespaceORM.id, ondelete="CASCADE", name="project_slugs_namespace_id_fk"), index=True ) namespace: Mapped[NamespaceORM] = relationship(lazy="joined", init=False, repr=False, viewonly=True) diff --git a/components/renku_data_services/repositories/apispec_base.py b/components/renku_data_services/repositories/apispec_base.py index b22840c63..19939e851 100644 --- a/components/renku_data_services/repositories/apispec_base.py +++ b/components/renku_data_services/repositories/apispec_base.py @@ -1,6 +1,7 @@ """Base models for API specifications.""" -from pydantic import BaseModel, Field, HttpUrl +from pydantic import BaseModel, Field, HttpUrl, field_validator +from ulid import ULID class BaseAPISpec(BaseModel): @@ -14,6 +15,12 @@ class Config: # this rust crate does not support lookahead regex syntax but we need it in this component regex_engine = "python-re" + @field_validator("connection_id", mode="before", check_fields=False) + @classmethod + def serialize_connection_id(cls, connection_id: str | ULID) -> str: + """Custom serializer that can handle ULIDs.""" + return str(connection_id) + class RepositoryParams(BaseAPISpec): """The schema for the path parameters used in the repository requests.""" diff --git a/components/renku_data_services/repositories/db.py b/components/renku_data_services/repositories/db.py index 34051ba1b..c44a1a5ce 100644 --- a/components/renku_data_services/repositories/db.py +++ b/components/renku_data_services/repositories/db.py @@ -7,6 +7,7 @@ from httpx import AsyncClient as HttpClient from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from ulid import ULID import renku_data_services.base_models as base_models from renku_data_services import errors @@ -107,7 +108,7 @@ async def _get_repository_anonymously( ) async def _get_repository_authenticated( - self, connection_id: str, repository_url: str, user: base_models.APIUser, etag: str | None + self, connection_id: ULID, repository_url: str, user: base_models.APIUser, etag: str | None ) -> models.RepositoryProviderMatch | Literal["304"]: """Get the metadata about a repository using an OAuth2 connection.""" async with self.connected_services_repo.get_async_oauth2_client(connection_id=connection_id, user=user) as ( diff --git a/components/renku_data_services/repositories/models.py b/components/renku_data_services/repositories/models.py index 2e808b769..66c34853b 100644 --- a/components/renku_data_services/repositories/models.py +++ b/components/renku_data_services/repositories/models.py @@ -2,6 +2,8 @@ from dataclasses import dataclass +from ulid import ULID + @dataclass(frozen=True, eq=True, kw_only=True) class RepositoryPermissions: @@ -31,5 +33,5 @@ class RepositoryProviderMatch: """Repository provider match data.""" provider_id: str - connection_id: str | None + connection_id: ULID | None repository_metadata: RepositoryMetadata | None diff --git a/components/renku_data_services/secrets/core.py b/components/renku_data_services/secrets/core.py index c81a71b6f..ba625b943 100644 --- a/components/renku_data_services/secrets/core.py +++ b/components/renku_data_services/secrets/core.py @@ -2,7 +2,6 @@ import asyncio from base64 import b64encode -from typing import cast from cryptography.hazmat.primitives.asymmetric import rsa from kubernetes import client as k8s_client @@ -64,7 +63,7 @@ async def create_k8s_secret( raise decrypted_value = decrypt_string(decryption_key, user.id, secret.encrypted_value).encode() # type: ignore - key = secret.name if not key_mapping else key_mapping[cast(str, secret.id)] + key = secret.name if not key_mapping else key_mapping[str(secret.id)] decrypted_secrets[key] = b64encode(decrypted_value).decode() except Exception as e: # don't wrap the error, we don't want secrets accidentally leaking. diff --git a/components/renku_data_services/secrets/models.py b/components/renku_data_services/secrets/models.py index 39aefde3a..c1622e6a5 100644 --- a/components/renku_data_services/secrets/models.py +++ b/components/renku_data_services/secrets/models.py @@ -6,6 +6,7 @@ from kubernetes import client as k8s_client from pydantic import BaseModel, Field +from ulid import ULID class SecretKind(Enum): @@ -21,7 +22,7 @@ class Secret(BaseModel): name: str encrypted_value: bytes = Field(repr=False) encrypted_key: bytes = Field(repr=False) - id: str | None = Field(default=None, init=False) + id: ULID | None = Field(default=None, init=False) modification_date: datetime = Field(default_factory=lambda: datetime.now(UTC).replace(microsecond=0), init=False) kind: SecretKind diff --git a/components/renku_data_services/secrets/orm.py b/components/renku_data_services/secrets/orm.py index 25995fdf0..3291a7707 100644 --- a/components/renku_data_services/secrets/orm.py +++ b/components/renku_data_services/secrets/orm.py @@ -9,6 +9,7 @@ from renku_data_services.secrets import models from renku_data_services.users.orm import UserORM +from renku_data_services.utils.sqlalchemy import ULIDType metadata_obj = MetaData(schema="secrets") # Has to match alembic ini section name @@ -37,7 +38,7 @@ class SecretORM(BaseORM): modification_date: Mapped[datetime] = mapped_column( "modification_date", DateTime(timezone=True), default_factory=lambda: datetime.now(UTC).replace(microsecond=0) ) - id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False) + id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False) user_id: Mapped[Optional[str]] = mapped_column( "user_id", ForeignKey(UserORM.keycloak_id, ondelete="CASCADE"), default=None, index=True, nullable=True ) diff --git a/components/renku_data_services/session/apispec_base.py b/components/renku_data_services/session/apispec_base.py index ee344ebf7..a16833290 100644 --- a/components/renku_data_services/session/apispec_base.py +++ b/components/renku_data_services/session/apispec_base.py @@ -1,6 +1,7 @@ """Base models for API specifications.""" -from pydantic import BaseModel +from pydantic import BaseModel, field_validator +from ulid import ULID class BaseAPISpec(BaseModel): @@ -10,3 +11,9 @@ class Config: """Enables orm mode for pydantic.""" from_attributes = True + + @field_validator("id", mode="before", check_fields=False) + @classmethod + def serialize_id(cls, id: str | ULID) -> str: + """Custom serializer that can handle ULIDs.""" + return str(id) diff --git a/components/renku_data_services/session/db.py b/components/renku_data_services/session/db.py index ca2677806..8e488b7f1 100644 --- a/components/renku_data_services/session/db.py +++ b/components/renku_data_services/session/db.py @@ -133,7 +133,9 @@ async def get_launchers(self, user: base_models.APIUser) -> list[models.SessionL async def get_project_launchers(self, user: base_models.APIUser, project_id: str) -> list[models.SessionLauncher]: """Get all session launchers in a project from the database.""" - authorized = await self.project_authz.has_permission(user, ResourceType.project, project_id, Scope.READ) + authorized = await self.project_authz.has_permission( + user, ResourceType.project, ULID.from_str(project_id), Scope.READ + ) if not authorized: raise errors.MissingResourceError( message=f"Project with id '{project_id}' does not exist or you do not have access to it." @@ -157,9 +159,7 @@ async def get_launcher(self, user: base_models.APIUser, launcher_id: ULID) -> mo launcher = res.one_or_none() authorized = ( - await self.project_authz.has_permission( - user, ResourceType.project, str(launcher.project_id), Scope.READ - ) + await self.project_authz.has_permission(user, ResourceType.project, launcher.project_id, Scope.READ) if launcher is not None else False ) @@ -178,7 +178,9 @@ async def insert_launcher( raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.") project_id = new_launcher.project_id - authorized = await self.project_authz.has_permission(user, ResourceType.project, project_id, Scope.WRITE) + authorized = await self.project_authz.has_permission( + user, ResourceType.project, ULID.from_str(project_id), Scope.WRITE + ) if not authorized: raise errors.MissingResourceError( message=f"Project with id '{project_id}' does not exist or you do not have access to it." @@ -338,7 +340,7 @@ async def delete_launcher(self, user: base_models.APIUser, launcher_id: ULID) -> authorized = await self.project_authz.has_permission( user, ResourceType.project, - str(launcher.project_id), + launcher.project_id, Scope.WRITE, ) if not authorized: diff --git a/components/renku_data_services/session/models.py b/components/renku_data_services/session/models.py index 9a660bb1c..422524152 100644 --- a/components/renku_data_services/session/models.py +++ b/components/renku_data_services/session/models.py @@ -4,6 +4,7 @@ from datetime import datetime from pydantic import BaseModel, model_validator +from ulid import ULID from renku_data_services import errors from renku_data_services.session.apispec import EnvironmentKind @@ -33,7 +34,7 @@ class Environment(BaseModel): class SessionLauncher(BaseModel): """Session launcher model.""" - id: str | None + id: ULID | None project_id: str name: str creation_date: datetime diff --git a/components/renku_data_services/session/orm.py b/components/renku_data_services/session/orm.py index 2d5cf7dea..4b61d548c 100644 --- a/components/renku_data_services/session/orm.py +++ b/components/renku_data_services/session/orm.py @@ -11,6 +11,7 @@ from renku_data_services.project.orm import ProjectORM from renku_data_services.session import models from renku_data_services.session.apispec import EnvironmentKind +from renku_data_services.utils.sqlalchemy import ULIDType metadata_obj = MetaData(schema="sessions") # Has to match alembic ini section name @@ -77,7 +78,7 @@ class SessionLauncherORM(BaseORM): __tablename__ = "launchers" - id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False) + id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False) """Id of this session launcher object.""" name: Mapped[str] = mapped_column("name", String(99)) diff --git a/components/renku_data_services/storage/apispec_base.py b/components/renku_data_services/storage/apispec_base.py index ee344ebf7..9654e3b9b 100644 --- a/components/renku_data_services/storage/apispec_base.py +++ b/components/renku_data_services/storage/apispec_base.py @@ -1,6 +1,7 @@ """Base models for API specifications.""" -from pydantic import BaseModel +from pydantic import BaseModel, field_validator +from ulid import ULID class BaseAPISpec(BaseModel): @@ -10,3 +11,15 @@ class Config: """Enables orm mode for pydantic.""" from_attributes = True + + @field_validator("storage_id", mode="before", check_fields=False) + @classmethod + def serialize_storage_id(cls, storage_id: str | ULID) -> str: + """Custom serializer that can handle ULIDs.""" + return str(storage_id) + + @field_validator("secret_id", mode="before", check_fields=False) + @classmethod + def secret_storage_id(cls, secret_id: str | ULID) -> str: + """Custom serializer that can handle ULIDs.""" + return str(secret_id) diff --git a/components/renku_data_services/storage/blueprints.py b/components/renku_data_services/storage/blueprints.py index b82546111..5be28854e 100644 --- a/components/renku_data_services/storage/blueprints.py +++ b/components/renku_data_services/storage/blueprints.py @@ -304,9 +304,11 @@ def upsert_secrets(self) -> BlueprintFactoryResponse: @authenticate(self.authenticator) async def _upsert_secrets(request: Request, user: base_models.APIUser, storage_id: ULID) -> JSONResponse: + # TODO: use @validate once sanic supports validating json lists body = apispec.CloudStorageSecretPostList.model_validate(request.json) + secrets = [models.CloudStorageSecretUpsert.model_validate(s.model_dump()) for s in body.root] result = await self.storage_v2_repo.upsert_storage_secrets( - storage_id=storage_id, user=user, secrets=body.root + storage_id=storage_id, user=user, secrets=secrets ) return json( apispec.CloudStorageSecretGetList.model_validate(result).model_dump(exclude_none=True, mode="json"), 201 diff --git a/components/renku_data_services/storage/db.py b/components/renku_data_services/storage/db.py index 57071b85e..7f0b8ce41 100644 --- a/components/renku_data_services/storage/db.py +++ b/components/renku_data_services/storage/db.py @@ -17,7 +17,7 @@ from renku_data_services.secrets.core import encrypt_user_secret from renku_data_services.secrets.models import SecretKind from renku_data_services.secrets.orm import SecretORM -from renku_data_services.storage import apispec, models +from renku_data_services.storage import models from renku_data_services.storage import orm as schemas from renku_data_services.users.db import UserRepo @@ -165,7 +165,7 @@ async def delete_storage(self, storage_id: ULID, user: base_models.APIUser) -> N await session.delete(storage[0]) async def upsert_storage_secrets( - self, storage_id: ULID, user: base_models.APIUser, secrets: list[apispec.CloudStorageSecretPost] + self, storage_id: ULID, user: base_models.APIUser, secrets: list[models.CloudStorageSecretUpsert] ) -> list[models.CloudStorageSecret]: """Create/update cloud storage secrets.""" # NOTE: Check that user has proper access to the storage @@ -207,7 +207,7 @@ async def upsert_storage_secrets( storage_secret_orm = schemas.CloudStorageSecretsORM( user_id=cast(str, user.id), - storage_id=str(storage_id), + storage_id=storage_id, name=name, secret_id=secret_orm.id, ) @@ -297,6 +297,6 @@ async def filter_projects_by_access_level( scope = authz_models.Scope.WRITE if minimum_access_level == authz_models.Role.OWNER else authz_models.Scope.READ output = [] for id in project_ids: - if await self.project_authz.has_permission(user, ResourceType.project, id, scope): + if await self.project_authz.has_permission(user, ResourceType.project, ULID.from_str(id), scope): output.append(id) return output diff --git a/components/renku_data_services/storage/models.py b/components/renku_data_services/storage/models.py index f85a78424..f53dbe5f7 100644 --- a/components/renku_data_services/storage/models.py +++ b/components/renku_data_services/storage/models.py @@ -5,6 +5,7 @@ from urllib.parse import ParseResult, urlparse from pydantic import BaseModel, Field, PrivateAttr, model_serializer, model_validator +from ulid import ULID from renku_data_services import errors from renku_data_services.storage.rclone import RCloneValidator @@ -59,7 +60,7 @@ class CloudStorage(BaseModel): configuration: RCloneConfig readonly: bool = Field(default=True) - storage_id: str | None = Field(default=None) + storage_id: ULID | None = Field(default=None) source_path: str = Field() """Path inside the cloud storage. @@ -230,9 +231,9 @@ class CloudStorageSecret(BaseModel): """Cloud storage secret model.""" user_id: str = Field() - storage_id: str = Field() + storage_id: ULID = Field() name: str = Field(min_length=1, max_length=99) - secret_id: str = Field() + secret_id: ULID = Field() @classmethod def from_dict(cls, data: dict) -> "CloudStorageSecret": @@ -240,3 +241,10 @@ def from_dict(cls, data: dict) -> "CloudStorageSecret": return cls( user_id=data["user_id"], storage_id=data["storage_id"], name=data["name"], secret_id=data["secret_id"] ) + + +class CloudStorageSecretUpsert(BaseModel): + """Insert/update storage secret data.""" + + name: str = Field() + value: str = Field() diff --git a/components/renku_data_services/storage/orm.py b/components/renku_data_services/storage/orm.py index 7cf1e9e2b..cfdd8f0cb 100644 --- a/components/renku_data_services/storage/orm.py +++ b/components/renku_data_services/storage/orm.py @@ -11,6 +11,7 @@ from renku_data_services.secrets.orm import SecretORM from renku_data_services.storage import models from renku_data_services.users.orm import UserORM +from renku_data_services.utils.sqlalchemy import ULIDType JSONVariant = JSON().with_variant(JSONB(), "postgresql") @@ -49,8 +50,8 @@ class CloudStorageORM(BaseORM): readonly: Mapped[bool] = mapped_column("readonly", Boolean(), default=True) """Whether this storage should be mounted readonly or not """ - storage_id: Mapped[str] = mapped_column( - "storage_id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False + storage_id: Mapped[ULID] = mapped_column( + "storage_id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False ) """Id of this storage.""" @@ -105,13 +106,13 @@ class CloudStorageSecretsORM(BaseORM): "user_id", ForeignKey(UserORM.keycloak_id, ondelete="CASCADE"), primary_key=True ) - storage_id: Mapped[str] = mapped_column( + storage_id: Mapped[ULID] = mapped_column( "storage_id", ForeignKey(CloudStorageORM.storage_id, ondelete="CASCADE"), primary_key=True ) name: Mapped[str] = mapped_column("name", String(), primary_key=True) - secret_id: Mapped[str] = mapped_column("secret_id", ForeignKey(SecretORM.id, ondelete="CASCADE")) + secret_id: Mapped[ULID] = mapped_column("secret_id", ForeignKey(SecretORM.id, ondelete="CASCADE")) secret: Mapped[SecretORM] = relationship(init=False, repr=False, lazy="selectin") @classmethod diff --git a/test/components/renku_data_services/authz/test_authorization.py b/test/components/renku_data_services/authz/test_authorization.py index 621b801cf..830129be4 100644 --- a/test/components/renku_data_services/authz/test_authorization.py +++ b/test/components/renku_data_services/authz/test_authorization.py @@ -47,7 +47,7 @@ async def test_adding_deleting_project(app_config: Config, bootstrap_admins, pub project_owner = regular_user1 assert project_owner.id authz = app_config.authz - project_id = str(ULID()) + project_id = ULID() project = Project( id=project_id, name=project_id, @@ -97,7 +97,7 @@ async def test_granting_access(app_config: Config, bootstrap_admins, public_proj assert project_owner.id assert regular_user2.id authz = app_config.authz - project_id = str(ULID()) + project_id = ULID() project = Project( id=project_id, name=project_id, @@ -142,7 +142,7 @@ async def test_listing_users_with_access(app_config: Config, public_project: boo assert project_owner.id assert regular_user2.id authz = app_config.authz - project1_id = str(ULID()) + project1_id = ULID() project1 = Project( id=project1_id, name=project1_id, @@ -157,7 +157,7 @@ async def test_listing_users_with_access(app_config: Config, public_project: boo visibility=Visibility.PUBLIC if public_project else Visibility.PRIVATE, created_by=project_owner.id, ) - project2_id = str(ULID()) + project2_id = ULID() project2 = Project( id=project2_id, name=project2_id, @@ -187,9 +187,12 @@ async def test_listing_users_with_access(app_config: Config, public_project: boo @pytest.mark.asyncio async def test_listing_projects_with_access(app_config: Config, bootstrap_admins) -> None: authz = app_config.authz - public_project_id = str(ULID()) - private_project_id1 = str(ULID()) - private_project_id2 = str(ULID()) + public_project_id = ULID() + private_project_id1 = ULID() + private_project_id2 = ULID() + public_project_id_str = str(public_project_id) + private_project_id1_str = str(private_project_id1) + private_project_id2_str = str(private_project_id2) project_owner = regular_user1 namespace = Namespace( project_owner.id, @@ -227,22 +230,22 @@ async def test_listing_projects_with_access(app_config: Config, bootstrap_admins for p in [public_project, private_project1, private_project2]: changes = authz._add_project(p) await authz.client.WriteRelationships(changes.apply) - assert {public_project_id, private_project_id1, private_project_id2} == set( + assert {public_project_id_str, private_project_id1_str, private_project_id2_str} == set( await authz.resources_with_permission(project_owner, regular_user1.id, ResourceType.project, Scope.DELETE) ) - assert {public_project_id, private_project_id1, private_project_id2} == set( + assert {public_project_id_str, private_project_id1_str, private_project_id2_str} == set( await authz.resources_with_permission(project_owner, regular_user1.id, ResourceType.project, Scope.WRITE) ) - assert {public_project_id, private_project_id1, private_project_id2} == set( + assert {public_project_id_str, private_project_id1_str, private_project_id2_str} == set( await authz.resources_with_permission(project_owner, regular_user1.id, ResourceType.project, Scope.READ) ) - assert {public_project_id, private_project_id1, private_project_id2} == set( + assert {public_project_id_str, private_project_id1_str, private_project_id2_str} == set( await authz.resources_with_permission(admin_user, admin_user.id, ResourceType.project, Scope.DELETE) ) - assert {public_project_id, private_project_id1, private_project_id2} == set( + assert {public_project_id_str, private_project_id1_str, private_project_id2_str} == set( await authz.resources_with_permission(admin_user, admin_user.id, ResourceType.project, Scope.WRITE) ) - assert {public_project_id, private_project_id1, private_project_id2} == set( + assert {public_project_id_str, private_project_id1_str, private_project_id2_str} == set( await authz.resources_with_permission(admin_user, admin_user.id, ResourceType.project, Scope.READ) ) with pytest.raises(errors.ForbiddenError): @@ -252,10 +255,10 @@ async def test_listing_projects_with_access(app_config: Config, bootstrap_admins await authz.resources_with_permission(regular_user2, project_owner.id, ResourceType.project, Scope.WRITE) await authz.resources_with_permission(regular_user2, project_owner.id, ResourceType.project, Scope.DELETE) await authz.resources_with_permission(regular_user2, project_owner.id, ResourceType.project, Scope.READ) - assert {public_project_id} == set( + assert {public_project_id_str} == set( await authz.resources_with_permission(anon_user, anon_user.id, ResourceType.project, Scope.READ) ) - assert {public_project_id} == set( + assert {public_project_id_str} == set( await authz.resources_with_permission(regular_user2, regular_user2.id, ResourceType.project, Scope.READ) ) await authz.upsert_project_members( @@ -264,7 +267,7 @@ async def test_listing_projects_with_access(app_config: Config, bootstrap_admins private_project1.id, [Member(Role.VIEWER, regular_user2.id, private_project_id1)], ) - assert {public_project_id, private_project_id1} == set( + assert {public_project_id_str, private_project_id1_str} == set( await authz.resources_with_permission(regular_user2, regular_user2.id, ResourceType.project, Scope.READ) ) assert ( @@ -290,12 +293,12 @@ async def test_listing_projects_with_access(app_config: Config, bootstrap_admins # Test project deletion changes = await authz._remove_project(project_owner, private_project1) await authz.client.WriteRelationships(changes.apply) - assert private_project_id1 not in set( + assert private_project_id1_str not in set( await authz.resources_with_permission(admin_user, project_owner.id, ResourceType.project, Scope.READ) ) - assert private_project_id1 not in set( + assert private_project_id1_str not in set( await authz.resources_with_permission(admin_user, regular_user2.id, ResourceType.project, Scope.READ) ) - assert private_project_id1 not in set( + assert private_project_id1_str not in set( await authz.resources_with_permission(admin_user, admin_user.id, ResourceType.project, Scope.DELETE) ) From 5abb7ddc0b312ecb9e7a74dbd9f91740fd31d35a Mon Sep 17 00:00:00 2001 From: Ralf Grubenmann Date: Wed, 21 Aug 2024 14:55:56 +0200 Subject: [PATCH 5/5] remove unneccessary str cast for ULID in queries --- components/renku_data_services/project/db.py | 19 +++++++++---------- components/renku_data_services/secrets/db.py | 6 +++--- components/renku_data_services/session/db.py | 6 +++--- components/renku_data_services/storage/db.py | 12 ++++++------ 4 files changed, 21 insertions(+), 22 deletions(-) diff --git a/components/renku_data_services/project/db.py b/components/renku_data_services/project/db.py index 701473d11..460e0b374 100644 --- a/components/renku_data_services/project/db.py +++ b/components/renku_data_services/project/db.py @@ -83,7 +83,7 @@ async def get_project(self, user: base_models.APIUser, project_id: ULID) -> mode ) async with self.session_maker() as session: - stmt = select(schemas.ProjectORM).where(schemas.ProjectORM.id == str(project_id)) + stmt = select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id) result = await session.execute(stmt) project_orm = result.scalars().first() @@ -196,10 +196,10 @@ async def update_project( project_id_str: str = str(project_id) if not session: raise errors.ProgrammingError(message="A database session is required") - result = await session.scalars(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id_str)) + result = await session.scalars(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id)) project = result.one_or_none() if project is None: - raise errors.MissingResourceError(message=f"The project with id '{project_id_str}' cannot be found") + raise errors.MissingResourceError(message=f"The project with id '{project_id}' cannot be found") old_project = project.dump() required_scope = Scope.WRITE @@ -215,7 +215,7 @@ async def update_project( authorized = await self.authz.has_permission(user, ResourceType.project, project_id, required_scope) if not authorized: raise errors.MissingResourceError( - message=f"Project with id '{project_id_str}' does not exist or you do not have access to it." + message=f"Project with id '{project_id}' does not exist or you do not have access to it." ) current_etag = project.dump().etag @@ -228,7 +228,7 @@ async def update_project( for r in payload["repositories"] ] # Trigger update for ``updated_at`` column - await session.execute(update(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id_str).values()) + await session.execute(update(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id).values()) if "keywords" in payload and not payload["keywords"]: payload["keywords"] = None @@ -279,17 +279,16 @@ async def delete_project( message=f"Project with id '{project_id}' does not exist or you do not have access to it." ) - project_id_str = str(project_id) - result = await session.execute(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id_str)) + result = await session.execute(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id)) project = result.scalar_one_or_none() if project is None: return None - await session.execute(delete(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id_str)) + await session.execute(delete(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id)) await session.execute( - delete(storage_schemas.CloudStorageORM).where(storage_schemas.CloudStorageORM.project_id == project_id_str) + delete(storage_schemas.CloudStorageORM).where(storage_schemas.CloudStorageORM.project_id == str(project_id)) ) return project.dump() @@ -327,7 +326,7 @@ async def decorated_func( message="The decorator that checks if a project exists requires a database session in the " f"keyword arguments, but instead it got {type(session)}" ) - stmt = select(schemas.ProjectORM.id).where(schemas.ProjectORM.id == str(project_id)) + stmt = select(schemas.ProjectORM.id).where(schemas.ProjectORM.id == project_id) res = await session.scalar(stmt) if not res: raise errors.MissingResourceError( diff --git a/components/renku_data_services/secrets/db.py b/components/renku_data_services/secrets/db.py index dd0e06efe..d5c7f9510 100644 --- a/components/renku_data_services/secrets/db.py +++ b/components/renku_data_services/secrets/db.py @@ -38,7 +38,7 @@ async def get_user_secrets(self, requested_by: APIUser, kind: SecretKind) -> lis async def get_secret_by_id(self, requested_by: APIUser, secret_id: ULID) -> Secret | None: """Get a specific user secret from the database.""" async with self.session_maker() as session: - stmt = select(SecretORM).where(SecretORM.user_id == requested_by.id).where(SecretORM.id == str(secret_id)) + stmt = select(SecretORM).where(SecretORM.user_id == requested_by.id).where(SecretORM.id == secret_id) res = await session.execute(stmt) orm = res.scalar_one_or_none() if orm is None: @@ -89,7 +89,7 @@ async def update_secret( async with self.session_maker() as session, session.begin(): result = await session.execute( - select(SecretORM).where(SecretORM.id == str(secret_id)).where(SecretORM.user_id == requested_by.id) + select(SecretORM).where(SecretORM.id == secret_id).where(SecretORM.user_id == requested_by.id) ) secret = result.scalar_one_or_none() if secret is None: @@ -104,7 +104,7 @@ async def delete_secret(self, requested_by: APIUser, secret_id: ULID) -> None: async with self.session_maker() as session, session.begin(): result = await session.execute( - select(SecretORM).where(SecretORM.id == str(secret_id)).where(SecretORM.user_id == requested_by.id) + select(SecretORM).where(SecretORM.id == secret_id).where(SecretORM.user_id == requested_by.id) ) secret = result.scalar_one_or_none() if secret is None: diff --git a/components/renku_data_services/session/db.py b/components/renku_data_services/session/db.py index 8e488b7f1..b3144dbef 100644 --- a/components/renku_data_services/session/db.py +++ b/components/renku_data_services/session/db.py @@ -154,7 +154,7 @@ async def get_launcher(self, user: base_models.APIUser, launcher_id: ULID) -> mo """Get one session launcher from the database.""" async with self.session_maker() as session: res = await session.scalars( - select(schemas.SessionLauncherORM).where(schemas.SessionLauncherORM.id == str(launcher_id)) + select(schemas.SessionLauncherORM).where(schemas.SessionLauncherORM.id == launcher_id) ) launcher = res.one_or_none() @@ -252,7 +252,7 @@ async def update_launcher( async with self.session_maker() as session, session.begin(): res = await session.scalars( - select(schemas.SessionLauncherORM).where(schemas.SessionLauncherORM.id == str(launcher_id)) + select(schemas.SessionLauncherORM).where(schemas.SessionLauncherORM.id == launcher_id) ) launcher = res.one_or_none() if launcher is None: @@ -330,7 +330,7 @@ async def delete_launcher(self, user: base_models.APIUser, launcher_id: ULID) -> async with self.session_maker() as session, session.begin(): res = await session.scalars( - select(schemas.SessionLauncherORM).where(schemas.SessionLauncherORM.id == str(launcher_id)) + select(schemas.SessionLauncherORM).where(schemas.SessionLauncherORM.id == launcher_id) ) launcher = res.one_or_none() diff --git a/components/renku_data_services/storage/db.py b/components/renku_data_services/storage/db.py index 7f0b8ce41..00fde4a49 100644 --- a/components/renku_data_services/storage/db.py +++ b/components/renku_data_services/storage/db.py @@ -67,7 +67,7 @@ async def get_storage( stmt = select(schemas.CloudStorageORM) if project_id is not None: - stmt = stmt.where(schemas.CloudStorageORM.project_id == str(project_id)) + stmt = stmt.where(schemas.CloudStorageORM.project_id == project_id) if id is not None: stmt = stmt.where(schemas.CloudStorageORM.storage_id == id) if name is not None: @@ -120,7 +120,7 @@ async def update_storage(self, storage_id: ULID, user: base_models.APIUser, **kw """Update a cloud storage entry.""" async with self.session_maker() as session, session.begin(): res = await session.execute( - select(schemas.CloudStorageORM).where(schemas.CloudStorageORM.storage_id == str(storage_id)) + select(schemas.CloudStorageORM).where(schemas.CloudStorageORM.storage_id == storage_id) ) storage = res.scalars().one_or_none() @@ -153,7 +153,7 @@ async def delete_storage(self, storage_id: ULID, user: base_models.APIUser) -> N """Delete a cloud storage entry.""" async with self.session_maker() as session, session.begin(): res = await session.execute( - select(schemas.CloudStorageORM).where(schemas.CloudStorageORM.storage_id == str(storage_id)) + select(schemas.CloudStorageORM).where(schemas.CloudStorageORM.storage_id == storage_id) ) storage = res.one_or_none() @@ -223,7 +223,7 @@ async def get_storage_secrets(self, storage_id: ULID, user: base_models.APIUser) stmt = ( select(schemas.CloudStorageSecretsORM) .where(schemas.CloudStorageSecretsORM.user_id == user.id) - .where(schemas.CloudStorageSecretsORM.storage_id == str(storage_id)) + .where(schemas.CloudStorageSecretsORM.storage_id == storage_id) ) result = await session.execute(stmt) storage_secrets_orm = result.scalars().all() @@ -237,13 +237,13 @@ async def delete_storage_secrets(self, storage_id: ULID, user: base_models.APIUs delete(SecretORM) .where(schemas.CloudStorageSecretsORM.secret_id == SecretORM.id) .where(schemas.CloudStorageSecretsORM.user_id == user.id) - .where(schemas.CloudStorageSecretsORM.storage_id == str(storage_id)) + .where(schemas.CloudStorageSecretsORM.storage_id == storage_id) ) await session.execute(stmt) stmt = ( delete(schemas.CloudStorageSecretsORM) .where(schemas.CloudStorageSecretsORM.user_id == user.id) - .where(schemas.CloudStorageSecretsORM.storage_id == str(storage_id)) + .where(schemas.CloudStorageSecretsORM.storage_id == storage_id) ) await session.execute(stmt)