Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions bases/renku_data_services/background_jobs/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
SubjectFilter,
WriteRelationshipsRequest,
)
from ulid import ULID

from renku_data_services.authz.authz import Authz, ResourceType, _AuthzConverter, _Relation
from renku_data_services.authz.models import Scope
Expand Down Expand Up @@ -133,7 +134,7 @@ async def fix_mismatched_project_namespace_ids(config: SyncConfig) -> None:
relation=rel.relationship.relation,
subject=SubjectReference(
object=ObjectReference(
object_type=ResourceType.group.value, object_id=correct_group_id
object_type=ResourceType.group.value, object_id=str(correct_group_id)
)
),
),
Expand Down Expand Up @@ -185,7 +186,7 @@ async def migrate_groups_make_all_public(config: SyncConfig) -> None:
all_users = SubjectReference(object=_AuthzConverter.all_users())
all_anon_users = SubjectReference(object=_AuthzConverter.anonymous_users())
for group_id in groups_to_process:
group_res = _AuthzConverter.group(group_id)
group_res = _AuthzConverter.group(ULID.from_str(group_id))
all_users_are_viewers = Relationship(
resource=group_res,
relation=_Relation.public_viewer.value,
Expand Down Expand Up @@ -244,7 +245,7 @@ async def migrate_user_namespaces_make_all_public(config: SyncConfig) -> None:
all_users = SubjectReference(object=_AuthzConverter.all_users())
all_anon_users = SubjectReference(object=_AuthzConverter.anonymous_users())
for ns_id in namespaces_to_process:
namespace_res = _AuthzConverter.user_namespace(ns_id)
namespace_res = _AuthzConverter.user_namespace(ULID.from_str(ns_id))
all_users_are_viewers = Relationship(
resource=namespace_res,
relation=_Relation.public_viewer.value,
Expand Down
2 changes: 1 addition & 1 deletion bases/renku_data_services/data_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

def register_all_handlers(app: Sanic, config: Config) -> Sanic:
"""Register all handlers on the application."""
app.router.register_pattern("ulid", ULID.from_str, r"^[0-9A-HJKMNP-TV-Z]{26}$")
app.router.register_pattern("ulid", ULID.from_str, r"^[0-7][0-9A-HJKMNP-TV-Z]{25}$")
app.router.register_pattern("renku_slug", str, r"^[a-zA-Z0-9][a-zA-Z0-9\-_.]*$")

url_prefix = "/api/data"
Expand Down
100 changes: 58 additions & 42 deletions components/renku_data_services/authz/authz.py

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion components/renku_data_services/authz/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass
from enum import Enum

from ulid import ULID

from renku_data_services.errors import errors
from renku_data_services.namespace.apispec import GroupRole

Expand Down Expand Up @@ -56,7 +58,7 @@ class Member:

role: Role
user_id: str
resource_id: str
resource_id: str | ULID


class Change(Enum):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ components:
type: string
minLength: 26
maxLength: 26
pattern: "^[A-Z0-9]{26}$" # This is case-insensitive
pattern: "^[0-7][0-9A-HJKMNP-TV-Z]{25}$" # This is case-insensitive
ProviderId:
description: ID of a OAuth2 provider, e.g. "gitlab.com".
type: string
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: api.spec.yaml
# timestamp: 2024-08-06T05:55:35+00:00
# timestamp: 2024-08-13T13:29:50+00:00

from __future__ import annotations

Expand Down Expand Up @@ -153,7 +153,7 @@ class Connection(BaseAPISpec):
description="ULID identifier",
max_length=26,
min_length=26,
pattern="^[A-Z0-9]{26}$",
pattern="^[0-7][0-9A-HJKMNP-TV-Z]{25}$",
)
provider_id: str = Field(
...,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Base models for API specifications."""

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from ulid import ULID


class BaseAPISpec(BaseModel):
Expand All @@ -14,6 +15,12 @@ class Config:
# this rust crate does not support lookahead regex syntax but we need it in this component
regex_engine = "python-re"

@field_validator("id", mode="before", check_fields=False)
@classmethod
def serialize_id(cls, id: str | ULID) -> str:
"""Custom serializer that can handle ULIDs."""
return str(id)


class AuthorizeParams(BaseAPISpec):
"""The schema for the query parameters used in the authorize request."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sanic.log import logger
from sanic.response import JSONResponse
from sanic_ext import validate
from ulid import ULID

import renku_data_services.base_models as base_models
from renku_data_services.base_api.auth import authenticate, only_admins, only_authenticated
Expand Down Expand Up @@ -150,34 +151,34 @@ def get_one(self) -> BlueprintFactoryResponse:
"""Get a specific OAuth2 connection."""

@authenticate(self.authenticator)
async def _get_one(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse:
async def _get_one(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse:
connection = await self.connected_services_repo.get_oauth2_connection(
connection_id=connection_id, user=user
)
return validated_json(apispec.Connection, connection)

return "/oauth2/connections/<connection_id>", ["GET"], _get_one
return "/oauth2/connections/<connection_id:ulid>", ["GET"], _get_one

def get_account(self) -> BlueprintFactoryResponse:
"""Get the account information for a specific OAuth2 connection."""

@authenticate(self.authenticator)
async def _get_account(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse:
async def _get_account(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse:
account = await self.connected_services_repo.get_oauth2_connected_account(
connection_id=connection_id, user=user
)
return validated_json(apispec.ConnectedAccount, account)

return "/oauth2/connections/<connection_id>/account", ["GET"], _get_account
return "/oauth2/connections/<connection_id:ulid>/account", ["GET"], _get_account

def get_token(self) -> BlueprintFactoryResponse:
"""Get the access token for a specific OAuth2 connection."""

@authenticate(self.authenticator)
async def _get_token(_: Request, user: base_models.APIUser, connection_id: str) -> JSONResponse:
async def _get_token(_: Request, user: base_models.APIUser, connection_id: ULID) -> JSONResponse:
token = await self.connected_services_repo.get_oauth2_connection_token(
connection_id=connection_id, user=user
)
return json(token.dump_for_api())

return "/oauth2/connections/<connection_id>/token", ["GET"], _get_token
return "/oauth2/connections/<connection_id:ulid>/token", ["GET"], _get_token
11 changes: 7 additions & 4 deletions components/renku_data_services/connected_services/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from ulid import ULID

import renku_data_services.base_models as base_models
from renku_data_services import errors
Expand Down Expand Up @@ -282,7 +283,7 @@ async def get_oauth2_connections(
connections = result.all()
return [c.dump() for c in connections]

async def get_oauth2_connection(self, connection_id: str, user: base_models.APIUser) -> models.OAuth2Connection:
async def get_oauth2_connection(self, connection_id: ULID, user: base_models.APIUser) -> models.OAuth2Connection:
"""Get one OAuth2 connection from the database."""
if not user.is_authenticated or user.id is None:
raise errors.MissingResourceError(
Expand All @@ -303,7 +304,7 @@ async def get_oauth2_connection(self, connection_id: str, user: base_models.APIU
return connection.dump()

async def get_oauth2_connected_account(
self, connection_id: str, user: base_models.APIUser
self, connection_id: ULID, user: base_models.APIUser
) -> models.ConnectedAccount:
"""Get the account information from a OAuth2 connection."""
async with self.get_async_oauth2_client(connection_id=connection_id, user=user) as (oauth2_client, _, adapter):
Expand All @@ -316,7 +317,9 @@ async def get_oauth2_connected_account(
account = adapter.api_validate_account_response(response)
return account

async def get_oauth2_connection_token(self, connection_id: str, user: base_models.APIUser) -> models.OAuth2TokenSet:
async def get_oauth2_connection_token(
self, connection_id: ULID, user: base_models.APIUser
) -> models.OAuth2TokenSet:
"""Get the OAuth2 access token from one connection from the database."""
async with self.get_async_oauth2_client(connection_id=connection_id, user=user) as (oauth2_client, _, _):
await oauth2_client.ensure_active_token(oauth2_client.token)
Expand All @@ -325,7 +328,7 @@ async def get_oauth2_connection_token(self, connection_id: str, user: base_model

@asynccontextmanager
async def get_async_oauth2_client(
self, connection_id: str, user: base_models.APIUser
self, connection_id: ULID, user: base_models.APIUser
) -> AsyncGenerator[tuple[AsyncOAuth2Client, schemas.OAuth2ConnectionORM, ProviderAdapter], None]:
"""Get the AsyncOAuth2Client for the given connection_id and user."""
if not user.is_authenticated or user.id is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from datetime import UTC, datetime
from typing import Any

from ulid import ULID

from renku_data_services.connected_services.apispec import ConnectionStatus, ProviderKind


Expand All @@ -28,7 +30,7 @@ class OAuth2Client:
class OAuth2Connection:
"""OAuth2 connection model."""

id: str
id: ULID
provider_id: str
status: ConnectionStatus

Expand Down
3 changes: 2 additions & 1 deletion components/renku_data_services/connected_services/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from renku_data_services.connected_services import models
from renku_data_services.connected_services.apispec import ConnectionStatus, ProviderKind
from renku_data_services.utils.sqlalchemy import ULIDType

JSONVariant = JSON().with_variant(JSONB(), "postgresql")

Expand Down Expand Up @@ -72,7 +73,7 @@ class OAuth2ConnectionORM(BaseORM):
"""An OAuth2 connection."""

__tablename__ = "oauth2_connections"
id: Mapped[str] = mapped_column("id", String(26), primary_key=True, default_factory=lambda: str(ULID()), init=False)
id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False)
user_id: Mapped[str] = mapped_column("user_id", String())
client_id: Mapped[str] = mapped_column(ForeignKey(OAuth2ClientORM.id, ondelete="CASCADE"), index=True)
client: Mapped[OAuth2ClientORM] = relationship(init=False, repr=False)
Expand Down
2 changes: 1 addition & 1 deletion components/renku_data_services/crc/apispec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# generated by datamodel-codegen:
# filename: api.spec.yaml
# timestamp: 2024-08-07T07:12:13+00:00
# timestamp: 2024-08-13T13:29:45+00:00

from __future__ import annotations

Expand Down
32 changes: 18 additions & 14 deletions components/renku_data_services/message_queue/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,14 @@ def to_events(
raise errors.EventError(
message=f"Cannot create an event of type {event_type} for a project which has no ID"
)
project_id_str = str(project.id)
match event_type:
case v2.ProjectCreated:
return [
Event(
"project.created",
v2.ProjectCreated(
id=project.id,
id=project_id_str,
name=project.name,
namespace=project.namespace.slug,
slug=project.slug,
Expand All @@ -56,7 +57,7 @@ def to_events(
Event(
"projectAuth.added",
v2.ProjectMemberAdded(
projectId=project.id,
projectId=project_id_str,
userId=project.created_by,
role=v2.MemberRole.OWNER,
),
Expand All @@ -67,7 +68,7 @@ def to_events(
Event(
"project.updated",
v2.ProjectUpdated(
id=project.id,
id=project_id_str,
name=project.name,
namespace=project.namespace.slug,
slug=project.slug,
Expand All @@ -79,7 +80,7 @@ def to_events(
)
]
case v2.ProjectRemoved:
return [Event("project.removed", v2.ProjectRemoved(id=project.id))]
return [Event("project.removed", v2.ProjectRemoved(id=project_id_str))]
case _:
raise errors.EventError(message=f"Trying to convert a project to an unknown event type {event_type}")

Expand Down Expand Up @@ -145,13 +146,14 @@ class _ProjectAuthzEventConverter:
def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event]:
output: list[Event] = []
for change in member_changes:
resource_id = str(change.member.resource_id)
match change.change:
case authz_models.Change.UPDATE:
output.append(
Event(
"projectAuth.updated",
v2.ProjectMemberUpdated(
projectId=change.member.resource_id,
projectId=resource_id,
userId=change.member.user_id,
role=_convert_member_role(change.member.role),
),
Expand All @@ -162,7 +164,7 @@ def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event
Event(
"projectAuth.removed",
v2.ProjectMemberRemoved(
projectId=change.member.resource_id,
projectId=resource_id,
userId=change.member.user_id,
),
)
Expand All @@ -172,7 +174,7 @@ def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event
Event(
"projectAuth.added",
v2.ProjectMemberAdded(
projectId=change.member.resource_id,
projectId=resource_id,
userId=change.member.user_id,
role=_convert_member_role(change.member.role),
),
Expand All @@ -191,13 +193,14 @@ class _GroupAuthzEventConverter:
def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event]:
output: list[Event] = []
for change in member_changes:
resource_id = str(change.member.resource_id)
match change.change:
case authz_models.Change.UPDATE:
output.append(
Event(
"memberGroup.updated",
v2.ProjectMemberUpdated(
projectId=change.member.resource_id,
projectId=resource_id,
userId=change.member.user_id,
role=_convert_member_role(change.member.role),
),
Expand All @@ -208,7 +211,7 @@ def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event
Event(
"memberGroup.removed",
v2.ProjectMemberRemoved(
projectId=change.member.resource_id,
projectId=resource_id,
userId=change.member.user_id,
),
)
Expand All @@ -218,7 +221,7 @@ def to_events(member_changes: list[authz_models.MembershipChange]) -> list[Event
Event(
"memberGroup.added",
v2.ProjectMemberAdded(
projectId=change.member.resource_id,
projectId=resource_id,
userId=change.member.user_id,
role=_convert_member_role(change.member.role),
),
Expand All @@ -239,32 +242,33 @@ def to_events(group: group_models.Group, event_type: type[AvroModel] | type[even
raise errors.ProgrammingError(
message="Cannot send group events to the message queue for a group that does not have an ID"
)
group_id = str(group.id)
match event_type:
case v2.GroupAdded:
return [
Event(
"group.added",
v2.GroupAdded(
id=group.id, name=group.name, description=group.description, namespace=group.slug
id=group_id, name=group.name, description=group.description, namespace=group.slug
),
),
Event(
"memberGroup.added",
v2.GroupMemberAdded(
groupId=group.id,
groupId=group_id,
userId=group.created_by,
role=v2.MemberRole.OWNER,
),
),
]
case v2.GroupRemoved:
return [Event("group.removed", v2.GroupRemoved(id=group.id))]
return [Event("group.removed", v2.GroupRemoved(id=group_id))]
case v2.GroupUpdated:
return [
Event(
"group.updated",
v2.GroupUpdated(
id=group.id, name=group.name, description=group.description, namespace=group.slug
id=group_id, name=group.name, description=group.description, namespace=group.slug
),
)
]
Expand Down
Loading
Loading