Skip to content

Commit ea98fdd

Browse files
committed
refactor: add sqlalchemy ulid type to all relevant models
1 parent 8a9ed32 commit ea98fdd

File tree

24 files changed

+126
-67
lines changed

24 files changed

+126
-67
lines changed

bases/renku_data_services/background_jobs/core.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
SubjectFilter,
1212
WriteRelationshipsRequest,
1313
)
14+
from ulid import ULID
1415

1516
from renku_data_services.authz.authz import Authz, ResourceType, _AuthzConverter, _Relation
1617
from renku_data_services.authz.models import Scope
@@ -117,7 +118,7 @@ async def fix_mismatched_project_namespace_ids(config: SyncConfig) -> None:
117118
relation=rel.relationship.relation,
118119
subject=SubjectReference(
119120
object=ObjectReference(
120-
object_type=ResourceType.group.value, object_id=correct_group_id
121+
object_type=ResourceType.group.value, object_id=str(correct_group_id)
121122
)
122123
),
123124
),
@@ -169,7 +170,7 @@ async def migrate_groups_make_all_public(config: SyncConfig) -> None:
169170
all_users = SubjectReference(object=_AuthzConverter.all_users())
170171
all_anon_users = SubjectReference(object=_AuthzConverter.anonymous_users())
171172
for group_id in groups_to_process:
172-
group_res = _AuthzConverter.group(group_id)
173+
group_res = _AuthzConverter.group(ULID.from_str(group_id))
173174
all_users_are_viewers = Relationship(
174175
resource=group_res,
175176
relation=_Relation.public_viewer.value,
@@ -228,7 +229,7 @@ async def migrate_user_namespaces_make_all_public(config: SyncConfig) -> None:
228229
all_users = SubjectReference(object=_AuthzConverter.all_users())
229230
all_anon_users = SubjectReference(object=_AuthzConverter.anonymous_users())
230231
for ns_id in namespaces_to_process:
231-
namespace_res = _AuthzConverter.user_namespace(ns_id)
232+
namespace_res = _AuthzConverter.user_namespace(ULID.from_str(ns_id))
232233
all_users_are_viewers = Relationship(
233234
resource=namespace_res,
234235
relation=_Relation.public_viewer.value,

components/renku_data_services/authz/authz.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass, field
66
from enum import StrEnum
77
from functools import wraps
8-
from typing import ClassVar, Concatenate, ParamSpec, Protocol, TypeVar
8+
from typing import ClassVar, Concatenate, ParamSpec, Protocol, TypeVar, cast
99

1010
from authzed.api.v1 import AsyncClient
1111
from authzed.api.v1.core_pb2 import ObjectReference, Relationship, RelationshipUpdate, SubjectReference, ZedToken
@@ -163,12 +163,12 @@ def all_users() -> ObjectReference:
163163
return ObjectReference(object_type=ResourceType.user, object_id="*")
164164

165165
@staticmethod
166-
def group(id: str) -> ObjectReference:
167-
return ObjectReference(object_type=ResourceType.group, object_id=id)
166+
def group(id: ULID) -> ObjectReference:
167+
return ObjectReference(object_type=ResourceType.group, object_id=str(id))
168168

169169
@staticmethod
170-
def user_namespace(id: str) -> ObjectReference:
171-
return ObjectReference(object_type=ResourceType.user_namespace, object_id=id)
170+
def user_namespace(id: ULID) -> ObjectReference:
171+
return ObjectReference(object_type=ResourceType.user_namespace, object_id=str(id))
172172

173173
@staticmethod
174174
def to_object(resource_type: ResourceType, resource_id: str | ULID | int) -> ObjectReference:
@@ -179,9 +179,9 @@ def to_object(resource_type: ResourceType, resource_id: str | ULID | int) -> Obj
179179
return _AuthzConverter.user(sid)
180180
case (ResourceType.anonymous_user, _):
181181
return _AuthzConverter.anonymous_users()
182-
case (ResourceType.user_namespace, rid) if isinstance(rid, str):
182+
case (ResourceType.user_namespace, rid) if isinstance(rid, ULID):
183183
return _AuthzConverter.user_namespace(rid)
184-
case (ResourceType.group, rid) if isinstance(rid, str):
184+
case (ResourceType.group, rid) if isinstance(rid, ULID):
185185
return _AuthzConverter.group(rid)
186186
raise errors.ProgrammingError(
187187
message=f"Unexpected or unknown resource type when checking permissions {resource_type}"
@@ -580,7 +580,7 @@ def _add_project(self, project: Project) -> _AuthzChange:
580580
object=(
581581
_AuthzConverter.user_namespace(project.namespace.id)
582582
if project.namespace.kind == NamespaceKind.user
583-
else _AuthzConverter.group(project.namespace.underlying_resource_id)
583+
else _AuthzConverter.group(cast(ULID, project.namespace.underlying_resource_id))
584584
)
585585
)
586586
project_in_platform = Relationship(
@@ -759,10 +759,14 @@ async def _update_project_namespace(
759759
else SubjectReference(object=_AuthzConverter.user_namespace(project.namespace.id))
760760
)
761761
old_namespace_sub = (
762-
SubjectReference(object=_AuthzConverter.group(current_namespace.relationship.subject.object.object_id))
762+
SubjectReference(
763+
object=_AuthzConverter.group(ULID.from_str(current_namespace.relationship.subject.object.object_id))
764+
)
763765
if current_namespace.relationship.subject.object.object_type == ResourceType.group.value
764766
else SubjectReference(
765-
object=_AuthzConverter.user_namespace(current_namespace.relationship.subject.object.object_id)
767+
object=_AuthzConverter.user_namespace(
768+
ULID.from_str(current_namespace.relationship.subject.object.object_id)
769+
)
766770
)
767771
)
768772
new_namespace = Relationship(
@@ -1093,7 +1097,7 @@ async def _remove_group(
10931097
message="Cannot remove a group in the authorization database if the group has no ID"
10941098
)
10951099
consistency = Consistency(at_least_as_fresh=zed_token) if zed_token else Consistency(fully_consistent=True)
1096-
rel_filter = RelationshipFilter(resource_type=ResourceType.group.value, optional_resource_id=group.id)
1100+
rel_filter = RelationshipFilter(resource_type=ResourceType.group.value, optional_resource_id=str(group.id))
10971101
responses = self.client.ReadRelationships(
10981102
ReadRelationshipsRequest(consistency=consistency, relationship_filter=rel_filter)
10991103
)
@@ -1113,7 +1117,7 @@ async def upsert_group_members(
11131117
self,
11141118
user: base_models.APIUser,
11151119
resource_type: ResourceType,
1116-
resource_id: str,
1120+
resource_id: ULID,
11171121
members: list[Member],
11181122
*,
11191123
zed_token: ZedToken | None = None,
@@ -1124,8 +1128,9 @@ async def upsert_group_members(
11241128
add_members: list[RelationshipUpdate] = []
11251129
undo: list[RelationshipUpdate] = []
11261130
output: list[MembershipChange] = []
1131+
resource_id_str = str(resource_id)
11271132
expected_user_roles = {_Relation.viewer.value, _Relation.owner.value, _Relation.editor.value}
1128-
existing_owners_rels = await self._get_resource_owners(resource_type, resource_id, consistency)
1133+
existing_owners_rels = await self._get_resource_owners(resource_type, resource_id_str, consistency)
11291134
n_existing_owners = len(existing_owners_rels)
11301135
for member in members:
11311136
rel = Relationship(
@@ -1135,7 +1140,7 @@ async def upsert_group_members(
11351140
)
11361141
existing_rel_filter = RelationshipFilter(
11371142
resource_type=resource_type.value,
1138-
optional_resource_id=resource_id,
1143+
optional_resource_id=resource_id_str,
11391144
optional_subject_filter=SubjectFilter(
11401145
subject_type=ResourceType.user, optional_subject_id=member.user_id
11411146
),
@@ -1231,7 +1236,7 @@ async def remove_group_members(
12311236
self,
12321237
user: base_models.APIUser,
12331238
resource_type: ResourceType,
1234-
resource_id: str,
1239+
resource_id: ULID,
12351240
user_ids: list[str],
12361241
*,
12371242
zed_token: ZedToken | None = None,
@@ -1242,12 +1247,13 @@ async def remove_group_members(
12421247
remove_members: list[RelationshipUpdate] = []
12431248
output: list[MembershipChange] = []
12441249
existing_owners_rels: list[ReadRelationshipsResponse] | None = None
1250+
resource_id_str = str(resource_id)
12451251
for user_id in user_ids:
12461252
if user_id == "*":
12471253
raise errors.ValidationError(message="Cannot remove a group member with ID '*'")
12481254
existing_rel_filter = RelationshipFilter(
12491255
resource_type=resource_type.value,
1250-
optional_resource_id=resource_id,
1256+
optional_resource_id=resource_id_str,
12511257
optional_subject_filter=SubjectFilter(subject_type=ResourceType.user, optional_subject_id=user_id),
12521258
)
12531259
existing_rels: AsyncIterable[ReadRelationshipsResponse] = self.client.ReadRelationships(
@@ -1258,7 +1264,9 @@ async def remove_group_members(
12581264
async for existing_rel in existing_rels:
12591265
if existing_rel.relationship.relation == _Relation.owner.value:
12601266
if existing_owners_rels is None:
1261-
existing_owners_rels = await self._get_resource_owners(resource_type, resource_id, consistency)
1267+
existing_owners_rels = await self._get_resource_owners(
1268+
resource_type, resource_id_str, consistency
1269+
)
12621270
if len(existing_owners_rels) == 1:
12631271
raise errors.ValidationError(
12641272
message="You are trying to remove the single last owner of the group, "

components/renku_data_services/authz/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from dataclasses import dataclass
44
from enum import Enum
55

6+
from ulid import ULID
7+
68
from renku_data_services.errors import errors
79
from renku_data_services.namespace.apispec import GroupRole
810

@@ -56,7 +58,7 @@ class Member:
5658

5759
role: Role
5860
user_id: str
59-
resource_id: str
61+
resource_id: str | ULID
6062

6163

6264
class Change(Enum):

components/renku_data_services/connected_services/apispec_base.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Base models for API specifications."""
22

3-
from pydantic import BaseModel, Field
3+
from pydantic import BaseModel, Field, field_validator
4+
from ulid import ULID
45

56

67
class BaseAPISpec(BaseModel):
@@ -14,6 +15,12 @@ class Config:
1415
# this rust crate does not support lookahead regex syntax but we need it in this component
1516
regex_engine = "python-re"
1617

18+
@field_validator("id", mode="before", check_fields=False)
19+
@classmethod
20+
def serialize_id(cls, id: str | ULID) -> str:
21+
"""Custom serializer that can handle ULIDs."""
22+
return str(id)
23+
1724

1825
class AuthorizeParams(BaseAPISpec):
1926
"""The schema for the query parameters used in the authorize request."""

components/renku_data_services/connected_services/blueprints.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sanic.log import logger
88
from sanic.response import JSONResponse
99
from sanic_ext import validate
10+
from ulid import ULID
1011

1112
import renku_data_services.base_models as base_models
1213
from renku_data_services.base_api.auth import authenticate, only_admins, only_authenticated
@@ -150,34 +151,34 @@ def get_one(self) -> BlueprintFactoryResponse:
150151
"""Get a specific OAuth2 connection."""
151152

152153
@authenticate(self.authenticator)
153-
async def _get_one(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse:
154+
async def _get_one(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse:
154155
connection = await self.connected_services_repo.get_oauth2_connection(
155156
connection_id=connection_id, user=user
156157
)
157158
return validated_json(apispec.Connection, connection)
158159

159-
return "/oauth2/connections/<connection_id>", ["GET"], _get_one
160+
return "/oauth2/connections/<connection_id:ulid>", ["GET"], _get_one
160161

161162
def get_account(self) -> BlueprintFactoryResponse:
162163
"""Get the account information for a specific OAuth2 connection."""
163164

164165
@authenticate(self.authenticator)
165-
async def _get_account(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse:
166+
async def _get_account(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse:
166167
account = await self.connected_services_repo.get_oauth2_connected_account(
167168
connection_id=connection_id, user=user
168169
)
169170
return validated_json(apispec.ConnectedAccount, account)
170171

171-
return "/oauth2/connections/<connection_id>/account", ["GET"], _get_account
172+
return "/oauth2/connections/<connection_id:ulid>/account", ["GET"], _get_account
172173

173174
def get_token(self) -> BlueprintFactoryResponse:
174175
"""Get the access token for a specific OAuth2 connection."""
175176

176177
@authenticate(self.authenticator)
177-
async def _get_token(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse:
178+
async def _get_token(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse:
178179
token = await self.connected_services_repo.get_oauth2_connection_token(
179180
connection_id=connection_id, user=user
180181
)
181182
return json(token.dump_for_api())
182183

183-
return "/oauth2/connections/<connection_id>/token", ["GET"], _get_token
184+
return "/oauth2/connections/<connection_id:ulid>/token", ["GET"], _get_token

components/renku_data_services/connected_services/db.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sqlalchemy import select
1212
from sqlalchemy.ext.asyncio import AsyncSession
1313
from sqlalchemy.orm import selectinload
14+
from ulid import ULID
1415

1516
import renku_data_services.base_models as base_models
1617
from renku_data_services import errors
@@ -282,7 +283,7 @@ async def get_oauth2_connections(
282283
connections = result.all()
283284
return [c.dump() for c in connections]
284285

285-
async def get_oauth2_connection(self, connection_id: str, user: base_models.APIUser) -> models.OAuth2Connection:
286+
async def get_oauth2_connection(self, connection_id: ULID, user: base_models.APIUser) -> models.OAuth2Connection:
286287
"""Get one OAuth2 connection from the database."""
287288
if not user.is_authenticated or user.id is None:
288289
raise errors.MissingResourceError(
@@ -303,7 +304,7 @@ async def get_oauth2_connection(self, connection_id: str, user: base_models.APIU
303304
return connection.dump()
304305

305306
async def get_oauth2_connected_account(
306-
self, connection_id: str, user: base_models.APIUser
307+
self, connection_id: ULID, user: base_models.APIUser
307308
) -> models.ConnectedAccount:
308309
"""Get the account information from a OAuth2 connection."""
309310
async with self.get_async_oauth2_client(connection_id=connection_id, user=user) as (oauth2_client, _, adapter):
@@ -316,7 +317,9 @@ async def get_oauth2_connected_account(
316317
account = adapter.api_validate_account_response(response)
317318
return account
318319

319-
async def get_oauth2_connection_token(self, connection_id: str, user: base_models.APIUser) -> models.OAuth2TokenSet:
320+
async def get_oauth2_connection_token(
321+
self, connection_id: ULID, user: base_models.APIUser
322+
) -> models.OAuth2TokenSet:
320323
"""Get the OAuth2 access token from one connection from the database."""
321324
async with self.get_async_oauth2_client(connection_id=connection_id, user=user) as (oauth2_client, _, _):
322325
await oauth2_client.ensure_active_token(oauth2_client.token)
@@ -325,7 +328,7 @@ async def get_oauth2_connection_token(self, connection_id: str, user: base_model
325328

326329
@asynccontextmanager
327330
async def get_async_oauth2_client(
328-
self, connection_id: str, user: base_models.APIUser
331+
self, connection_id: ULID, user: base_models.APIUser
329332
) -> AsyncGenerator[tuple[AsyncOAuth2Client, schemas.OAuth2ConnectionORM, ProviderAdapter], None]:
330333
"""Get the AsyncOAuth2Client for the given connection_id and user."""
331334
if not user.is_authenticated or user.id is None:

components/renku_data_services/connected_services/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from datetime import UTC, datetime
55
from typing import Any
66

7+
from ulid import ULID
8+
79
from renku_data_services.connected_services.apispec import ConnectionStatus, ProviderKind
810

911

@@ -28,7 +30,7 @@ class OAuth2Client:
2830
class OAuth2Connection:
2931
"""OAuth2 connection model."""
3032

31-
id: str
33+
id: ULID
3234
provider_id: str
3335
status: ConnectionStatus
3436

components/renku_data_services/connected_services/orm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from renku_data_services.connected_services import models
1313
from renku_data_services.connected_services.apispec import ConnectionStatus, ProviderKind
14+
from renku_data_services.utils.sqlalchemy import ULIDType
1415

1516
JSONVariant = JSON().with_variant(JSONB(), "postgresql")
1617

@@ -72,7 +73,7 @@ class OAuth2ConnectionORM(BaseORM):
7273
"""An OAuth2 connection."""
7374

7475
__tablename__ = "oauth2_connections"
75-
id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False)
76+
id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False)
7677
user_id: Mapped[str] = mapped_column("user_id", String())
7778
client_id: Mapped[str] = mapped_column(ForeignKey(OAuth2ClientORM.id, ondelete="CASCADE"), index=True)
7879
client: Mapped[OAuth2ClientORM] = relationship(init=False, repr=False)

0 commit comments

Comments
 (0)