Skip to content

Commit e7932fc

Browse files
committed
refactor: add sqlalchemy ulid type to all relevant models
1 parent e11a1b8 commit e7932fc

File tree

24 files changed

+126
-67
lines changed

24 files changed

+126
-67
lines changed

bases/renku_data_services/background_jobs/core.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
SubjectFilter,
1212
WriteRelationshipsRequest,
1313
)
14+
from ulid import ULID
1415

1516
from renku_data_services.authz.authz import Authz, ResourceType, _AuthzConverter, _Relation
1617
from renku_data_services.authz.models import Scope
@@ -117,7 +118,7 @@ async def fix_mismatched_project_namespace_ids(config: SyncConfig) -> None:
117118
relation=rel.relationship.relation,
118119
subject=SubjectReference(
119120
object=ObjectReference(
120-
object_type=ResourceType.group.value, object_id=correct_group_id
121+
object_type=ResourceType.group.value, object_id=str(correct_group_id)
121122
)
122123
),
123124
),
@@ -169,7 +170,7 @@ async def migrate_groups_make_all_public(config: SyncConfig) -> None:
169170
all_users = SubjectReference(object=_AuthzConverter.all_users())
170171
all_anon_users = SubjectReference(object=_AuthzConverter.anonymous_users())
171172
for group_id in groups_to_process:
172-
group_res = _AuthzConverter.group(group_id)
173+
group_res = _AuthzConverter.group(ULID.from_str(group_id))
173174
all_users_are_viewers = Relationship(
174175
resource=group_res,
175176
relation=_Relation.public_viewer.value,
@@ -228,7 +229,7 @@ async def migrate_user_namespaces_make_all_public(config: SyncConfig) -> None:
228229
all_users = SubjectReference(object=_AuthzConverter.all_users())
229230
all_anon_users = SubjectReference(object=_AuthzConverter.anonymous_users())
230231
for ns_id in namespaces_to_process:
231-
namespace_res = _AuthzConverter.user_namespace(ns_id)
232+
namespace_res = _AuthzConverter.user_namespace(ULID.from_str(ns_id))
232233
all_users_are_viewers = Relationship(
233234
resource=namespace_res,
234235
relation=_Relation.public_viewer.value,

components/renku_data_services/authz/authz.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from dataclasses import dataclass, field
66
from enum import StrEnum
77
from functools import wraps
8-
from typing import ClassVar, Concatenate, ParamSpec, Protocol, TypeVar
8+
from typing import ClassVar, Concatenate, ParamSpec, Protocol, TypeVar, cast
99

1010
from authzed.api.v1 import AsyncClient
1111
from authzed.api.v1.core_pb2 import ObjectReference, Relationship, RelationshipUpdate, SubjectReference, ZedToken
@@ -162,12 +162,12 @@ def all_users() -> ObjectReference:
162162
return ObjectReference(object_type=ResourceType.user, object_id="*")
163163

164164
@staticmethod
165-
def group(id: str) -> ObjectReference:
166-
return ObjectReference(object_type=ResourceType.group, object_id=id)
165+
def group(id: ULID) -> ObjectReference:
166+
return ObjectReference(object_type=ResourceType.group, object_id=str(id))
167167

168168
@staticmethod
169-
def user_namespace(id: str) -> ObjectReference:
170-
return ObjectReference(object_type=ResourceType.user_namespace, object_id=id)
169+
def user_namespace(id: ULID) -> ObjectReference:
170+
return ObjectReference(object_type=ResourceType.user_namespace, object_id=str(id))
171171

172172
@staticmethod
173173
def to_object(resource_type: ResourceType, resource_id: str | int) -> ObjectReference:
@@ -178,9 +178,9 @@ def to_object(resource_type: ResourceType, resource_id: str | int) -> ObjectRefe
178178
return _AuthzConverter.user(sid)
179179
case (ResourceType.anonymous_user, _):
180180
return _AuthzConverter.anonymous_users()
181-
case (ResourceType.user_namespace, rid) if isinstance(rid, str):
181+
case (ResourceType.user_namespace, rid) if isinstance(rid, ULID):
182182
return _AuthzConverter.user_namespace(rid)
183-
case (ResourceType.group, rid) if isinstance(rid, str):
183+
case (ResourceType.group, rid) if isinstance(rid, ULID):
184184
return _AuthzConverter.group(rid)
185185
raise errors.ProgrammingError(
186186
message=f"Unexpected or unknown resource type when checking permissions {resource_type}"
@@ -575,7 +575,7 @@ def _add_project(self, project: Project) -> _AuthzChange:
575575
object=(
576576
_AuthzConverter.user_namespace(project.namespace.id)
577577
if project.namespace.kind == NamespaceKind.user
578-
else _AuthzConverter.group(project.namespace.underlying_resource_id)
578+
else _AuthzConverter.group(cast(ULID, project.namespace.underlying_resource_id))
579579
)
580580
)
581581
project_in_platform = Relationship(
@@ -752,10 +752,14 @@ async def _update_project_namespace(
752752
else SubjectReference(object=_AuthzConverter.user_namespace(project.namespace.id))
753753
)
754754
old_namespace_sub = (
755-
SubjectReference(object=_AuthzConverter.group(current_namespace.relationship.subject.object.object_id))
755+
SubjectReference(
756+
object=_AuthzConverter.group(ULID.from_str(current_namespace.relationship.subject.object.object_id))
757+
)
756758
if current_namespace.relationship.subject.object.object_type == ResourceType.group.value
757759
else SubjectReference(
758-
object=_AuthzConverter.user_namespace(current_namespace.relationship.subject.object.object_id)
760+
object=_AuthzConverter.user_namespace(
761+
ULID.from_str(current_namespace.relationship.subject.object.object_id)
762+
)
759763
)
760764
)
761765
new_namespace = Relationship(
@@ -1084,7 +1088,7 @@ async def _remove_group(
10841088
message="Cannot remove a group in the authorization database if the group has no ID"
10851089
)
10861090
consistency = Consistency(at_least_as_fresh=zed_token) if zed_token else Consistency(fully_consistent=True)
1087-
rel_filter = RelationshipFilter(resource_type=ResourceType.group.value, optional_resource_id=group.id)
1091+
rel_filter = RelationshipFilter(resource_type=ResourceType.group.value, optional_resource_id=str(group.id))
10881092
responses = self.client.ReadRelationships(
10891093
ReadRelationshipsRequest(consistency=consistency, relationship_filter=rel_filter)
10901094
)
@@ -1104,7 +1108,7 @@ async def upsert_group_members(
11041108
self,
11051109
user: base_models.APIUser,
11061110
resource_type: ResourceType,
1107-
resource_id: str,
1111+
resource_id: ULID,
11081112
members: list[Member],
11091113
*,
11101114
zed_token: ZedToken | None = None,
@@ -1115,8 +1119,9 @@ async def upsert_group_members(
11151119
add_members: list[RelationshipUpdate] = []
11161120
undo: list[RelationshipUpdate] = []
11171121
output: list[MembershipChange] = []
1122+
resource_id_str = str(resource_id)
11181123
expected_user_roles = {_Relation.viewer.value, _Relation.owner.value, _Relation.editor.value}
1119-
existing_owners_rels = await self._get_resource_owners(resource_type, resource_id, consistency)
1124+
existing_owners_rels = await self._get_resource_owners(resource_type, resource_id_str, consistency)
11201125
n_existing_owners = len(existing_owners_rels)
11211126
for member in members:
11221127
rel = Relationship(
@@ -1126,7 +1131,7 @@ async def upsert_group_members(
11261131
)
11271132
existing_rel_filter = RelationshipFilter(
11281133
resource_type=resource_type.value,
1129-
optional_resource_id=resource_id,
1134+
optional_resource_id=resource_id_str,
11301135
optional_subject_filter=SubjectFilter(
11311136
subject_type=ResourceType.user, optional_subject_id=member.user_id
11321137
),
@@ -1222,7 +1227,7 @@ async def remove_group_members(
12221227
self,
12231228
user: base_models.APIUser,
12241229
resource_type: ResourceType,
1225-
resource_id: str,
1230+
resource_id: ULID,
12261231
user_ids: list[str],
12271232
*,
12281233
zed_token: ZedToken | None = None,
@@ -1233,12 +1238,13 @@ async def remove_group_members(
12331238
remove_members: list[RelationshipUpdate] = []
12341239
output: list[MembershipChange] = []
12351240
existing_owners_rels: list[ReadRelationshipsResponse] | None = None
1241+
resource_id_str = str(resource_id)
12361242
for user_id in user_ids:
12371243
if user_id == "*":
12381244
raise errors.ValidationError(message="Cannot remove a group member with ID '*'")
12391245
existing_rel_filter = RelationshipFilter(
12401246
resource_type=resource_type.value,
1241-
optional_resource_id=resource_id,
1247+
optional_resource_id=resource_id_str,
12421248
optional_subject_filter=SubjectFilter(subject_type=ResourceType.user, optional_subject_id=user_id),
12431249
)
12441250
existing_rels: AsyncIterable[ReadRelationshipsResponse] = self.client.ReadRelationships(
@@ -1249,7 +1255,9 @@ async def remove_group_members(
12491255
async for existing_rel in existing_rels:
12501256
if existing_rel.relationship.relation == _Relation.owner.value:
12511257
if existing_owners_rels is None:
1252-
existing_owners_rels = await self._get_resource_owners(resource_type, resource_id, consistency)
1258+
existing_owners_rels = await self._get_resource_owners(
1259+
resource_type, resource_id_str, consistency
1260+
)
12531261
if len(existing_owners_rels) == 1:
12541262
raise errors.ValidationError(
12551263
message="You are trying to remove the single last owner of the group, "

components/renku_data_services/authz/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from dataclasses import dataclass
44
from enum import Enum
55

6+
from ulid import ULID
7+
68
from renku_data_services.errors import errors
79
from renku_data_services.namespace.apispec import GroupRole
810

@@ -56,7 +58,7 @@ class Member:
5658

5759
role: Role
5860
user_id: str
59-
resource_id: str
61+
resource_id: str | ULID
6062

6163

6264
class Change(Enum):

components/renku_data_services/connected_services/apispec_base.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Base models for API specifications."""
22

3-
from pydantic import BaseModel, Field
3+
from pydantic import BaseModel, Field, field_validator
4+
from ulid import ULID
45

56

67
class BaseAPISpec(BaseModel):
@@ -14,6 +15,12 @@ class Config:
1415
# this rust crate does not support lookahead regex syntax but we need it in this component
1516
regex_engine = "python-re"
1617

18+
@field_validator("id", mode="before", check_fields=False)
19+
@classmethod
20+
def serialize_id(cls, id: str | ULID) -> str:
21+
"""Custom serializer that can handle ULIDs."""
22+
return str(id)
23+
1724

1825
class AuthorizeParams(BaseAPISpec):
1926
"""The schema for the query parameters used in the authorize request."""

components/renku_data_services/connected_services/blueprints.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sanic.log import logger
88
from sanic.response import JSONResponse
99
from sanic_ext import validate
10+
from ulid import ULID
1011

1112
import renku_data_services.base_models as base_models
1213
from renku_data_services.base_api.auth import authenticate, only_admins, only_authenticated
@@ -150,34 +151,34 @@ def get_one(self) -> BlueprintFactoryResponse:
150151
"""Get a specific OAuth2 connection."""
151152

152153
@authenticate(self.authenticator)
153-
async def _get_one(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse:
154+
async def _get_one(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse:
154155
connection = await self.connected_services_repo.get_oauth2_connection(
155156
connection_id=connection_id, user=user
156157
)
157158
return validated_json(apispec.Connection, connection)
158159

159-
return "/oauth2/connections/<connection_id>", ["GET"], _get_one
160+
return "/oauth2/connections/<connection_id:ulid>", ["GET"], _get_one
160161

161162
def get_account(self) -> BlueprintFactoryResponse:
162163
"""Get the account information for a specific OAuth2 connection."""
163164

164165
@authenticate(self.authenticator)
165-
async def _get_account(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse:
166+
async def _get_account(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse:
166167
account = await self.connected_services_repo.get_oauth2_connected_account(
167168
connection_id=connection_id, user=user
168169
)
169170
return validated_json(apispec.ConnectedAccount, account)
170171

171-
return "/oauth2/connections/<connection_id>/account", ["GET"], _get_account
172+
return "/oauth2/connections/<connection_id:ulid>/account", ["GET"], _get_account
172173

173174
def get_token(self) -> BlueprintFactoryResponse:
174175
"""Get the access token for a specific OAuth2 connection."""
175176

176177
@authenticate(self.authenticator)
177-
async def _get_token(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse:
178+
async def _get_token(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse:
178179
token = await self.connected_services_repo.get_oauth2_connection_token(
179180
connection_id=connection_id, user=user
180181
)
181182
return json(token.dump_for_api())
182183

183-
return "/oauth2/connections/<connection_id>/token", ["GET"], _get_token
184+
return "/oauth2/connections/<connection_id:ulid>/token", ["GET"], _get_token

components/renku_data_services/connected_services/db.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from sqlalchemy import select
1212
from sqlalchemy.ext.asyncio import AsyncSession
1313
from sqlalchemy.orm import selectinload
14+
from ulid import ULID
1415

1516
import renku_data_services.base_models as base_models
1617
from renku_data_services import errors
@@ -282,7 +283,7 @@ async def get_oauth2_connections(
282283
connections = result.all()
283284
return [c.dump() for c in connections]
284285

285-
async def get_oauth2_connection(self, connection_id: str, user: base_models.APIUser) -> models.OAuth2Connection:
286+
async def get_oauth2_connection(self, connection_id: ULID, user: base_models.APIUser) -> models.OAuth2Connection:
286287
"""Get one OAuth2 connection from the database."""
287288
if not user.is_authenticated or user.id is None:
288289
raise errors.MissingResourceError(
@@ -303,7 +304,7 @@ async def get_oauth2_connection(self, connection_id: str, user: base_models.APIU
303304
return connection.dump()
304305

305306
async def get_oauth2_connected_account(
306-
self, connection_id: str, user: base_models.APIUser
307+
self, connection_id: ULID, user: base_models.APIUser
307308
) -> models.ConnectedAccount:
308309
"""Get the account information from a OAuth2 connection."""
309310
async with self.get_async_oauth2_client(connection_id=connection_id, user=user) as (oauth2_client, _, adapter):
@@ -316,7 +317,9 @@ async def get_oauth2_connected_account(
316317
account = adapter.api_validate_account_response(response)
317318
return account
318319

319-
async def get_oauth2_connection_token(self, connection_id: str, user: base_models.APIUser) -> models.OAuth2TokenSet:
320+
async def get_oauth2_connection_token(
321+
self, connection_id: ULID, user: base_models.APIUser
322+
) -> models.OAuth2TokenSet:
320323
"""Get the OAuth2 access token from one connection from the database."""
321324
async with self.get_async_oauth2_client(connection_id=connection_id, user=user) as (oauth2_client, _, _):
322325
await oauth2_client.ensure_active_token(oauth2_client.token)
@@ -325,7 +328,7 @@ async def get_oauth2_connection_token(self, connection_id: str, user: base_model
325328

326329
@asynccontextmanager
327330
async def get_async_oauth2_client(
328-
self, connection_id: str, user: base_models.APIUser
331+
self, connection_id: ULID, user: base_models.APIUser
329332
) -> AsyncGenerator[tuple[AsyncOAuth2Client, schemas.OAuth2ConnectionORM, ProviderAdapter], None]:
330333
"""Get the AsyncOAuth2Client for the given connection_id and user."""
331334
if not user.is_authenticated or user.id is None:

components/renku_data_services/connected_services/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from datetime import UTC, datetime
55
from typing import Any
66

7+
from ulid import ULID
8+
79
from renku_data_services.connected_services.apispec import ConnectionStatus, ProviderKind
810

911

@@ -28,7 +30,7 @@ class OAuth2Client:
2830
class OAuth2Connection:
2931
"""OAuth2 connection model."""
3032

31-
id: str
33+
id: ULID
3234
provider_id: str
3335
status: ConnectionStatus
3436

components/renku_data_services/connected_services/orm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from renku_data_services.connected_services import models
1313
from renku_data_services.connected_services.apispec import ConnectionStatus, ProviderKind
14+
from renku_data_services.utils.sqlalchemy import ULIDType
1415

1516
JSONVariant = JSON().with_variant(JSONB(), "postgresql")
1617

@@ -72,7 +73,7 @@ class OAuth2ConnectionORM(BaseORM):
7273
"""An OAuth2 connection."""
7374

7475
__tablename__ = "oauth2_connections"
75-
id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False)
76+
id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False)
7677
user_id: Mapped[str] = mapped_column("user_id", String())
7778
client_id: Mapped[str] = mapped_column(ForeignKey(OAuth2ClientORM.id, ondelete="CASCADE"), index=True)
7879
client: Mapped[OAuth2ClientORM] = relationship(init=False, repr=False)

0 commit comments

Comments
 (0)