Skip to content

Commit 9e6223b

Browse files
committed
refactor: add custom ULID type for sqlalchemy
1 parent 8afb949 commit 9e6223b

File tree

5 files changed

+55
-23
lines changed

5 files changed

+55
-23
lines changed

components/renku_data_services/project/db.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from sqlalchemy import Select, delete, func, select, update
1212
from sqlalchemy.ext.asyncio import AsyncSession
13+
from ulid import ULID
1314

1415
import renku_data_services.base_models as base_models
1516
from renku_data_services import errors
@@ -73,7 +74,7 @@ async def get_projects(
7374
total_elements = results[1].scalar() or 0
7475
return [p.dump() for p in projects_orm], total_elements
7576

76-
async def get_project(self, user: base_models.APIUser, project_id: str) -> models.Project:
77+
async def get_project(self, user: base_models.APIUser, project_id: ULID) -> models.Project:
7778
"""Get one project from the database."""
7879
authorized = await self.authz.has_permission(user, ResourceType.project, project_id, Scope.READ)
7980
if not authorized:
@@ -82,7 +83,7 @@ async def get_project(self, user: base_models.APIUser, project_id: str) -> model
8283
)
8384

8485
async with self.session_maker() as session:
85-
stmt = select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id)
86+
stmt = select(schemas.ProjectORM).where(schemas.ProjectORM.id == str(project_id))
8687
result = await session.execute(stmt)
8788
project_orm = result.scalars().first()
8889

@@ -110,7 +111,10 @@ async def get_project_by_namespace_slug(
110111
raise errors.MissingResourceError(message=not_found_msg)
111112

112113
authorized = await self.authz.has_permission(
113-
user=user, resource_type=ResourceType.project, resource_id=project_orm.id, scope=Scope.READ
114+
user=user,
115+
resource_type=ResourceType.project,
116+
resource_id=project_orm.id,
117+
scope=Scope.READ,
114118
)
115119
if not authorized:
116120
raise errors.MissingResourceError(message=not_found_msg)
@@ -167,7 +171,7 @@ async def insert_project(
167171
creation_date=datetime.now(UTC).replace(microsecond=0),
168172
keywords=project.keywords,
169173
)
170-
project_slug = schemas.ProjectSlug(slug, project_id=project_orm.id, namespace_id=ns.id)
174+
project_slug = schemas.ProjectSlug(slug, project_id=str(project_orm.id), namespace_id=ns.id)
171175

172176
session.add(project_slug)
173177
session.add(project_orm)
@@ -182,19 +186,20 @@ async def insert_project(
182186
async def update_project(
183187
self,
184188
user: base_models.APIUser,
185-
project_id: str,
189+
project_id: ULID,
186190
payload: dict[str, Any],
187191
etag: str | None = None,
188192
*,
189193
session: AsyncSession | None = None,
190194
) -> models.ProjectUpdate:
191195
"""Update a project entry."""
196+
project_id_str: str = str(project_id)
192197
if not session:
193198
raise errors.ProgrammingError(message="A database session is required")
194-
result = await session.scalars(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id))
199+
result = await session.scalars(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id_str))
195200
project = result.one_or_none()
196201
if project is None:
197-
raise errors.MissingResourceError(message=f"The project with id '{project_id}' cannot be found")
202+
raise errors.MissingResourceError(message=f"The project with id '{project_id_str}' cannot be found")
198203
old_project = project.dump()
199204

200205
required_scope = Scope.WRITE
@@ -210,7 +215,7 @@ async def update_project(
210215
authorized = await self.authz.has_permission(user, ResourceType.project, project_id, required_scope)
211216
if not authorized:
212217
raise errors.MissingResourceError(
213-
message=f"Project with id '{project_id}' does not exist or you do not have access to it."
218+
message=f"Project with id '{project_id_str}' does not exist or you do not have access to it."
214219
)
215220

216221
current_etag = project.dump().etag
@@ -219,11 +224,11 @@ async def update_project(
219224

220225
if "repositories" in payload:
221226
payload["repositories"] = [
222-
schemas.ProjectRepositoryORM(url=r, project_id=project_id, project=project)
227+
schemas.ProjectRepositoryORM(url=r, project_id=project_id_str, project=project)
223228
for r in payload["repositories"]
224229
]
225230
# Trigger update for ``updated_at`` column
226-
await session.execute(update(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id).values())
231+
await session.execute(update(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id_str).values())
227232

228233
if "keywords" in payload and not payload["keywords"]:
229234
payload["keywords"] = None
@@ -263,7 +268,7 @@ async def update_project(
263268
@Authz.authz_change(AuthzOperation.delete, ResourceType.project)
264269
@dispatch_message(avro_schema_v2.ProjectRemoved)
265270
async def delete_project(
266-
self, user: base_models.APIUser, project_id: str, *, session: AsyncSession | None = None
271+
self, user: base_models.APIUser, project_id: ULID, *, session: AsyncSession | None = None
267272
) -> models.Project | None:
268273
"""Delete a project."""
269274
if not session:
@@ -274,16 +279,17 @@ async def delete_project(
274279
message=f"Project with id '{project_id}' does not exist or you do not have access to it."
275280
)
276281

277-
result = await session.execute(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id))
282+
project_id_str = str(project_id)
283+
result = await session.execute(select(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id_str))
278284
project = result.scalar_one_or_none()
279285

280286
if project is None:
281287
return None
282288

283-
await session.execute(delete(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id))
289+
await session.execute(delete(schemas.ProjectORM).where(schemas.ProjectORM.id == project_id_str))
284290

285291
await session.execute(
286-
delete(storage_schemas.CloudStorageORM).where(storage_schemas.CloudStorageORM.project_id == project_id)
292+
delete(storage_schemas.CloudStorageORM).where(storage_schemas.CloudStorageORM.project_id == project_id_str)
287293
)
288294

289295
return project.dump()
@@ -303,15 +309,15 @@ def _filter_by_namespace_slug(statement: Select[tuple[_T]], namespace: str) -> S
303309

304310

305311
def _project_exists(
306-
f: Callable[Concatenate[ProjectMemberRepository, base_models.APIUser, str, _P], Awaitable[_T]],
307-
) -> Callable[Concatenate[ProjectMemberRepository, base_models.APIUser, str, _P], Awaitable[_T]]:
312+
f: Callable[Concatenate[ProjectMemberRepository, base_models.APIUser, ULID, _P], Awaitable[_T]],
313+
) -> Callable[Concatenate[ProjectMemberRepository, base_models.APIUser, ULID, _P], Awaitable[_T]]:
308314
"""Checks if the project exists when adding or modifying project members."""
309315

310316
@functools.wraps(f)
311317
async def decorated_func(
312318
self: ProjectMemberRepository,
313319
user: base_models.APIUser,
314-
project_id: str,
320+
project_id: ULID,
315321
*args: _P.args,
316322
**kwargs: _P.kwargs,
317323
) -> _T:
@@ -321,7 +327,7 @@ async def decorated_func(
321327
message="The decorator that checks if a project exists requires a database session in the "
322328
f"keyword arguments, but instead it got {type(session)}"
323329
)
324-
stmt = select(schemas.ProjectORM.id).where(schemas.ProjectORM.id == project_id)
330+
stmt = select(schemas.ProjectORM.id).where(schemas.ProjectORM.id == str(project_id))
325331
res = await session.scalar(stmt)
326332
if not res:
327333
raise errors.MissingResourceError(
@@ -350,7 +356,7 @@ def __init__(
350356
@with_db_transaction
351357
@_project_exists
352358
async def get_members(
353-
self, user: base_models.APIUser, project_id: str, *, session: AsyncSession | None = None
359+
self, user: base_models.APIUser, project_id: ULID, *, session: AsyncSession | None = None
354360
) -> list[Member]:
355361
"""Get all members of a project."""
356362
members = await self.authz.members(user, ResourceType.project, project_id)
@@ -363,7 +369,7 @@ async def get_members(
363369
async def update_members(
364370
self,
365371
user: base_models.APIUser,
366-
project_id: str,
372+
project_id: ULID,
367373
members: list[Member],
368374
*,
369375
session: AsyncSession | None = None,
@@ -392,7 +398,7 @@ async def update_members(
392398
@_project_exists
393399
@dispatch_message(events.ProjectMembershipChanged)
394400
async def delete_members(
395-
self, user: base_models.APIUser, project_id: str, user_ids: list[str], *, session: AsyncSession | None = None
401+
self, user: base_models.APIUser, project_id: ULID, user_ids: list[str], *, session: AsyncSession | None = None
396402
) -> list[MembershipChange]:
397403
"""Delete members from a project."""
398404
if len(user_ids) == 0:

components/renku_data_services/project/orm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from renku_data_services.namespace.orm import NamespaceORM
1414
from renku_data_services.project import models
1515
from renku_data_services.project.apispec import Visibility
16+
from renku_data_services.utils.sqlalchemy import ULIDType
1617

1718
metadata_obj = MetaData(schema="projects") # Has to match alembic ini section name
1819

@@ -27,7 +28,7 @@ class ProjectORM(BaseORM):
2728
"""A Renku native project."""
2829

2930
__tablename__ = "projects"
30-
id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False)
31+
id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False)
3132
name: Mapped[str] = mapped_column("name", String(99))
3233
visibility: Mapped[Visibility]
3334
created_by_id: Mapped[str] = mapped_column("created_by_id", String())

components/renku_data_services/session/orm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ class SessionLauncherORM(BaseORM):
104104
project: Mapped[ProjectORM] = relationship(init=False)
105105
environment: Mapped[EnvironmentORM | None] = relationship(init=False)
106106

107-
project_id: Mapped[str] = mapped_column(
107+
project_id: Mapped[ULID] = mapped_column(
108108
"project_id", ForeignKey(ProjectORM.id, ondelete="CASCADE"), default=None, index=True
109109
)
110110
"""Id of the project this session belongs to."""

components/renku_data_services/utils/py.typed

Whitespace-only changes.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""Utilities for SQLAlchemy."""
2+
3+
from typing import cast
4+
5+
from sqlalchemy import Dialect, types
6+
from ulid import ULID
7+
8+
9+
class ULIDType(types.TypeDecorator):
10+
"""Wrapper type for ULID <--> str conversion."""
11+
12+
impl = types.String
13+
cache_ok = True
14+
15+
def process_bind_param(self, value: ULID | None, dialect: Dialect) -> str | None:
16+
"""Transform value for storing in the database."""
17+
if value is None:
18+
return None
19+
return str(value)
20+
21+
def process_result_value(self, value: str | None, dialect: Dialect) -> ULID | None:
22+
"""Transform string from database into ULID."""
23+
if value is None:
24+
return None
25+
return cast(ULID, ULID.from_str(value))

0 commit comments

Comments
 (0)