Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions components/renku_data_services/base_api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def decorator(
async def decorated_function(request: Request, *args: _P.args, **kwargs: _P.kwargs) -> _T:
    """Authenticate the caller, record its id on the request, then run the wrapped handler."""
    # A missing header is passed to the authenticator as an empty token.
    raw_token = request.headers.get(authenticator.token_field) or ""
    user = await authenticator.authenticate(raw_token, request)
    # Stored on the request context so code that only has access to the request
    # (not to `user`) can still read the id of the authenticated user.
    request.ctx.keycloak_user_id = user.id
    return await f(request, user, *args, **kwargs)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""create audit tables

Revision ID: 4fdd23ef602d
Revises: 559b1fc46cfe
Create Date: 2025-03-21 10:46:16.687230

"""

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import INET
from sqlalchemy.dialects.postgresql.ext import ExcludeConstraint
from sqlalchemy.dialects.postgresql.json import JSONB

from renku_data_services.project.orm import ProjectORM
from renku_data_services.utils.sanic_pgaudit import versioning_manager

# revision identifiers, used by Alembic.
revision = "4fdd23ef602d"
down_revision = "559b1fc46cfe"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Create the audit tables in the ``common`` schema and enable auditing for projects.

    Steps, in order:
      1. install the ``btree_gist`` extension and the ``jsonb_change_key_name`` helper function,
      2. create the ``common.transaction`` and ``common.activity`` tables,
      3. let the versioning manager install its audit functions and operators,
      4. turn on audit tracking for the already existing projects table.
    """
    # set up versioning manager tables
    # btree_gist provides the gist operator classes used by the exclusion constraint below
    op.execute("CREATE EXTENSION IF NOT EXISTS btree_gist")
    op.execute("""CREATE OR REPLACE FUNCTION jsonb_change_key_name(data jsonb, old_key text, new_key text)
RETURNS jsonb
IMMUTABLE
LANGUAGE sql
AS $$
SELECT ('{'||string_agg(to_json(CASE WHEN key = old_key THEN new_key ELSE key END)||':'||value, ',')||'}')::jsonb
FROM (
SELECT *
FROM jsonb_each(data)
) t;
$$;""")
    op.create_table(
        "transaction",
        sa.Column("id", sa.BigInteger, primary_key=True),
        sa.Column("native_transaction_id", sa.BigInteger),
        sa.Column("issued_at", sa.DateTime),
        sa.Column("client_addr", INET),
        sa.Column("actor_id", sa.Text(), nullable=True),
        # The same native transaction id may not appear twice within one hour of `issued_at`.
        ExcludeConstraint(
            (sa.Column("native_transaction_id"), "="),
            (
                sa.func.tsrange(
                    sa.Column("issued_at") - sa.text("INTERVAL '1 hour'"),
                    sa.Column("issued_at"),
                ),
                "&&",
            ),
            using="gist",
            name="transaction_unique_native_tx_id",
        ),
        schema="common",
    )
    op.create_table(
        "activity",
        sa.Column("id", sa.BigInteger, primary_key=True),
        sa.Column("schema_name", sa.Text),
        sa.Column("table_name", sa.Text),
        sa.Column("relid", sa.Integer),
        sa.Column("issued_at", sa.DateTime),
        # No `index=True` here: the index is created explicitly right after the table.
        # Declaring it in both places would make the migration try to create
        # "ix_activity_native_transaction_id" twice.
        sa.Column("native_transaction_id", sa.BigInteger),
        sa.Column("verb", sa.Text),
        sa.Column("old_data", JSONB, default={}, server_default="{}"),
        sa.Column("changed_data", JSONB, default={}, server_default="{}"),
        sa.Column("transaction_id", sa.BigInteger),
        sa.ForeignKeyConstraint(
            ["transaction_id"],
            ["common.transaction.id"],
        ),
        schema="common",
    )
    op.create_index(
        op.f("ix_activity_native_transaction_id"), "activity", ["native_transaction_id"], schema="common", unique=False
    )

    # set up versioning manager triggers
    bind = op.get_bind()
    versioning_manager.create_audit_table(None, bind)
    versioning_manager.create_operators(None, bind)

    # manually set up version tracking for projects
    # pgsql_audit does this automatically, but with a "table.after_create" trigger, so this doesn't work for
    # existing tables
    query = versioning_manager.build_audit_table_query(
        table=ProjectORM.__table__, exclude_columns=ProjectORM.__versioned__.get("exclude")
    )
    op.execute(query)


def downgrade() -> None:
    """Drop the audit tables, and the explicit index on ``activity``, created by ``upgrade``.

    NOTE(review): only the two tables and the index are removed. The ``btree_gist``
    extension, the ``jsonb_change_key_name`` function and whatever
    ``versioning_manager.create_audit_table`` / ``create_operators`` / the audit-table
    query installed in ``upgrade`` are left in place. Confirm whether those need to be
    removed here as well, in particular anything on the projects table that still
    writes to ``common.activity``.
    """
    audit_schema = "common"
    op.drop_index(op.f("ix_activity_native_transaction_id"), table_name="activity", schema=audit_schema)
    # "activity" has a foreign key to "transaction", so it is dropped first.
    for table_name in ("activity", "transaction"):
        op.drop_table(table_name, schema=audit_schema)
12 changes: 12 additions & 0 deletions components/renku_data_services/project/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,23 @@
from renku_data_services.project.apispec import Visibility
from renku_data_services.secrets.orm import SecretORM
from renku_data_services.users.orm import UserORM
from renku_data_services.utils.sanic_pgaudit import versioning_manager
from renku_data_services.utils.sqlalchemy import PurePosixPathType, ULIDType

if TYPE_CHECKING:
from renku_data_services.namespace.orm import EntitySlugORM


class CommonORM(MappedAsDataclass, DeclarativeBase):
    """Declarative base for ORM classes whose tables live in the ``common`` schema."""

    # Tables declared on this base get schema="common" from the shared metadata.
    metadata = MetaData(schema="common")
    # Registry object defined outside this view of the file.
    registry = COMMON_ORM_REGISTRY


# Hand the common-schema declarative base to the audit versioning manager at import time.
# NOTE(review): presumably this is what places the manager's own models in the "common"
# schema - confirm against postgresql_audit's `VersioningManager.init`.
versioning_manager.init(CommonORM)


class BaseORM(MappedAsDataclass, DeclarativeBase):
"""Base class for all ORM classes."""

Expand All @@ -36,6 +47,7 @@ class ProjectORM(BaseORM):

__tablename__ = "projects"
__table_args__ = (Index("ix_projects_project_template_id", "template_id"),)
__versioned__ = {"exclude": ["creation_date"]}
id: Mapped[ULID] = mapped_column("id", ULIDType, primary_key=True, default_factory=lambda: str(ULID()), init=False)
name: Mapped[str] = mapped_column("name", String(99))
visibility: Mapped[Visibility]
Expand Down
4 changes: 2 additions & 2 deletions components/renku_data_services/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Package for shared utility functionality."""

from renku_data_services.utils import core, etag
from renku_data_services.utils import core, cryptography, etag, middleware, sanic_pgaudit, sqlalchemy

__all__ = ["core", "etag"]
__all__ = ["core", "etag", "sanic_pgaudit", "sqlalchemy", "middleware", "cryptography"]
121 changes: 121 additions & 0 deletions components/renku_data_services/utils/sanic_pgaudit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""Postgres Audit log functionality.

This mostly just overrides the flask functionality already in pg audit so it works with sanic.
"""

from contextlib import contextmanager
from copy import copy
from typing import Any, Iterator, reveal_type

import sqlalchemy as sa
from postgresql_audit.base import ImproperlyConfigured
from postgresql_audit.base import VersioningManager as BaseVersioningManager
from sanic import Request, Sanic
from sanic.exceptions import SanicException
from sqlalchemy import FromClause, orm
from sqlalchemy.dialects.postgresql import array

from renku_data_services.users.orm import UserORM


def assign_actor(base, cls):
    """Attach a text ``actor_id`` column and an ``actor`` relationship to ``cls``.

    Postgresql_audit links the actor on the primary key by default; here the link is
    made through ``UserORM.keycloak_id`` instead. A class that already has an
    ``actor_id`` attribute is left untouched. ``base`` is not used.
    """
    if not hasattr(cls, "actor_id"):
        cls.actor_id = sa.Column("actor_id", sa.Text())
        cls.actor = orm.relationship(
            UserORM,
            primaryjoin=cls.actor_id == UserORM.keycloak_id,
            foreign_keys=[cls.actor_id],
        )


class SanicVersioningManager(BaseVersioningManager):
    """Custom version manager that integrates with Sanic to get user id."""

    _actor_cls = "UserORM"

    def get_transaction_values(self):
        """Gets values from Sanic for a pgsql_audit transaction.

        Starts from a copy of ``self.values``, adds the ``activity_values`` stored on the
        Sanic application context (if any) and, when no ``actor_id`` was supplied, falls
        back to ``default_actor_id``.

        Returns:
            The dictionary of values for the transaction row.
        """
        values = copy(self.values)
        try:
            ctx = Sanic.get_app().ctx
        except SanicException:
            # `Sanic.get_app()` raises when it cannot resolve a single registered app
            # (e.g. code running outside the web server); there are no extra values then.
            ctx = None
        extra_values = getattr(ctx, "activity_values", None)
        if extra_values is not None:
            values.update(extra_values)
        if "actor_id" not in values:
            actor_id = self.default_actor_id
            if actor_id is not None:
                values["actor_id"] = actor_id
        return values

    @property
    def default_actor_id(self):
        """Get user id from sanic.

        Returns:
            The ``keycloak_user_id`` stored on the current request context, or ``None``
            when there is no current request or the request carries no such attribute.
        """
        try:
            return Request.get_current().ctx.keycloak_user_id
        except (SanicException, AttributeError):
            # SanicException: raised by `Request.get_current()` outside of a request.
            # AttributeError: the request was never authenticated.
            return None

    def configure_versioned_classes(self):
        """Configures all versioned classes that were collected during instrumentation process.

        Note: we override this so we can use our own `assign_actor` method.
        """
        for cls in self.pending_classes:
            self.audit_table(cls.__table__, cls.__versioned__.get("exclude"))
        assign_actor(self.base, self.transaction_cls)

    def build_audit_table_query(self, table: sa.Table, exclude_columns: list[str] | None = None) -> sa.Select:
        """Builds a query that, when executed, turns on audit tracking for a table.

        Note: this is just a copy of the pgsql_audit function, but with support for tables in other schemas.

        Args:
            table: The table to audit; its name is schema-qualified when it has a schema.
            exclude_columns: Names of columns of ``table`` to pass as excluded columns.

        Returns:
            A ``SELECT audit_table(...)`` statement, with the function qualified by
            ``self.schema_name`` when that is set.

        Raises:
            ImproperlyConfigured: If an excluded column does not exist on ``table``.
        """
        # Only prefix the schema when there is one, otherwise the name would read "None.<table>".
        qualified_name = f"{table.schema}.{table.name}" if table.schema else table.name
        args: list[Any] = [qualified_name]
        if exclude_columns:
            for column in exclude_columns:
                if column not in table.c:
                    raise ImproperlyConfigured(
                        f"Could not configure versioning. Table '{table.name}' does "
                        f"not have a column named '{column}'."
                    )
            args.append(array(exclude_columns))

        if self.schema_name is None:
            func = sa.func.audit_table
        else:
            func = getattr(getattr(sa.func, self.schema_name), "audit_table")
        return sa.select(func(*args))


def merge_dicts(a: dict[str, Any], b: dict[str, Any]) -> dict[str, Any]:
    """Return a new dictionary holding the entries of ``a`` overlaid with those of ``b``.

    Neither input is modified; on duplicate keys the value from ``b`` wins.

    This is from the pg audit flask implementation, but can't be imported without also importing flask.
    """
    return {**a, **b}


@contextmanager
def activity_values(**values: Any) -> Iterator[None]:
    """Context manager that allows tracking child changes on the parent.

    While the block is active, ``values`` (merged over any values set by an enclosing
    ``activity_values`` block) are stored as ``activity_values`` on the Sanic application
    context. The previous state is restored on exit, also when the block raises.
    When no Sanic application can be resolved the block runs without storing anything.

    NOTE(review): the values live on the application-wide ``ctx`` rather than on a
    request, so concurrent requests handled by the same app would see each other's
    values - confirm this is intended.

    Example:
        with activity_values(target_id=str(article.id)):
            article.tags = [Tag(name='Some tag')]
            db.session.commit()
    """
    try:
        ctx = Sanic.get_app().ctx
    except SanicException:
        # `Sanic.get_app()` raises when it cannot resolve a single registered app.
        ctx = None
    if ctx is None:
        yield  # Needed for contextmanager
        return

    not_set = object()  # distinguishes "attribute absent" from any stored value
    previous_value = getattr(ctx, "activity_values", not_set)
    ctx.activity_values = values if previous_value is not_set else merge_dicts(previous_value, values)
    try:
        yield
    finally:
        # Always restore, otherwise an exception inside the block would leave these
        # values in place for unrelated later transactions.
        if previous_value is not_set:
            del ctx.activity_values
        else:
            ctx.activity_values = previous_value


# Module-level manager instance imported by the project ORM module and the audit migration.
# With schema_name="common", `build_audit_table_query` calls the `common.audit_table` function.
versioning_manager = SanicVersioningManager(schema_name="common")
Loading
Loading