Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions statgpt/admin/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ the [common README file](../common/README.md).
| OIDC_CLIENT_ID | Yes, if `$OIDC_AUTH_ENABLED` | OIDC Client ID | | |
| OIDC_ISSUER | Yes, if `$OIDC_AUTH_ENABLED` | OIDC Issuer | | |
| OIDC_USERNAME_CLAIM | Yes, if `$OIDC_AUTH_ENABLED` | OIDC Username Claim | | |
| OIDC_AUDIT_USER_ID_CLAIM | No | JWT claim(s) used to populate audit log `performed_by` (single value, comma-separated string, or JSON array) | | `oid,sub` |
| OIDC_AUDIT_PERFORMED_BY_NAME_CLAIM | No | JWT claim(s) used to populate audit log `performed_by_name` (single value, comma-separated string, or JSON array) | | `unique_name,email` |
| ADMIN_ROLES_CLAIM | Yes, if `$OIDC_AUTH_ENABLED` | OIDC Admin Roles Claim | | |
| ADMIN_ROLES_VALUES | Yes, if `$OIDC_AUTH_ENABLED` | OIDC Admin Roles Values | | |
| ADMIN_SCOPE_CLAIM_VALIDATION_ENABLED | No | If specified, the admin portal will check for scopes in the OIDC token; otherwise, this check will be skipped. | `true`, `false` | `true` |
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Add immutable audit logs table

Revision ID: 3b8a6a40f1cd
Revises: c7f068b2d47d
Create Date: 2026-02-11 12:00:00.000000

"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = '3b8a6a40f1cd'
down_revision: str | None = 'c7f068b2d47d'
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
    """Create the ``audit_logs`` table, its enum types, indexes and immutability trigger."""
    # Both enums are declared with create_type=False so that op.create_table does not
    # emit CREATE TYPE on its own; the types are created explicitly further down with
    # checkfirst=True, which skips creation when a type of that name already exists.
    audit_entity_type = postgresql.ENUM(
        'channel',
        'dataset',
        'data_source',
        'import_job',
        name='auditentitytype',
        create_type=False,
    )
    audit_action_type = postgresql.ENUM(
        'create',
        'update',
        'delete',
        name='auditactiontype',
        create_type=False,
    )
    audit_entity_type.create(op.get_bind(), checkfirst=True)
    audit_action_type.create(op.get_bind(), checkfirst=True)

    op.create_table(
        'audit_logs',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('entity_type', audit_entity_type, nullable=False),
        sa.Column('action_type', audit_action_type, nullable=False),
        sa.Column('item_id', sa.Integer(), nullable=False),
        sa.Column('entity_id', sa.String(), nullable=False),
        sa.Column('entity_name', sa.String(), nullable=False),
        sa.Column('performed_by', sa.String(), nullable=False),
        sa.Column('performed_by_name', sa.String(), nullable=False),
        # The only nullable payload column: the snapshot of the entity after the action.
        sa.Column('state_after', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column('trace_id', sa.String(), nullable=False),
        # NOTE(review): no nullable=False is given here, so the column accepts NULL even
        # though it has a server-side default — confirm this matches the ORM model.
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()')),
        sa.PrimaryKeyConstraint('id'),
    )
    # Non-unique lookup indexes: by creation time, and by entity type combined with
    # either the string entity_id or the integer item_id.
    op.create_index('ix_audit_logs_created_at', 'audit_logs', ['created_at'], unique=False)
    op.create_index(
        'ix_audit_logs_entity_type_entity_id',
        'audit_logs',
        ['entity_type', 'entity_id'],
        unique=False,
    )
    op.create_index(
        'ix_audit_logs_entity_type_item_id',
        'audit_logs',
        ['entity_type', 'item_id'],
        unique=False,
    )
    # Trigger function that unconditionally raises, used below to reject row changes.
    op.execute(
        """
CREATE OR REPLACE FUNCTION prevent_audit_log_mutation()
RETURNS trigger AS $$
BEGIN
RAISE EXCEPTION 'audit_logs rows are immutable';
END;
$$ LANGUAGE plpgsql;
"""
    )
    # Row-level trigger that fires before every UPDATE or DELETE on audit_logs.
    # NOTE(review): only UPDATE and DELETE are listed, so TRUNCATE is not intercepted
    # by this trigger — confirm whether that is acceptable for the audit trail.
    op.execute(
        """
CREATE TRIGGER trg_prevent_audit_log_mutation
BEFORE UPDATE OR DELETE ON audit_logs
FOR EACH ROW
EXECUTE FUNCTION prevent_audit_log_mutation();
"""
    )


def downgrade() -> None:
    """Remove what ``upgrade`` created: trigger, function, indexes, table and enum types."""
    # The trigger is dropped before the function it executes.
    op.execute("DROP TRIGGER IF EXISTS trg_prevent_audit_log_mutation ON audit_logs")
    op.execute("DROP FUNCTION IF EXISTS prevent_audit_log_mutation()")
    # Indexes are dropped in the reverse of the order in which upgrade() created them.
    op.drop_index('ix_audit_logs_entity_type_item_id', table_name='audit_logs')
    op.drop_index('ix_audit_logs_entity_type_entity_id', table_name='audit_logs')
    op.drop_index('ix_audit_logs_created_at', table_name='audit_logs')
    op.drop_table('audit_logs')
    # The enum types go last, after the table whose columns use them; checkfirst=True
    # skips the drop when the type does not exist.
    postgresql.ENUM(name='auditactiontype').drop(op.get_bind(), checkfirst=True)
    postgresql.ENUM(name='auditentitytype').drop(op.get_bind(), checkfirst=True)
2 changes: 2 additions & 0 deletions statgpt/admin/audit/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .context import get_audit_context, set_audit_context
from .decorators import audit_action
40 changes: 40 additions & 0 deletions statgpt/admin/audit/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from contextvars import ContextVar
from dataclasses import dataclass, field

from opentelemetry import trace


def _get_trace_id() -> str:
    """Return the trace ID of the current span as 32 zero-padded lowercase hex digits.

    Raises:
        RuntimeError: if the current span's context is not valid.
    """
    ctx = trace.get_current_span().get_span_context()
    if ctx.is_valid:
        return format(ctx.trace_id, "032x")
    raise RuntimeError("No valid span context available for trace ID")


@dataclass(frozen=True)
class AuditContext:
    """Immutable description of who is acting and under which trace.

    Constructing an instance without an explicit ``trace_id`` calls ``_get_trace_id``,
    which raises ``RuntimeError`` when there is no valid span context.
    """

    # Identifier of the acting user.
    performed_by: str
    # Display name of the acting user.
    performed_by_name: str
    # Trace ID; resolved from the current span at construction time when not supplied.
    trace_id: str = field(default_factory=_get_trace_id)


# Holds the AuditContext for the current execution context. It is declared without a
# default, so readers must pass their own fallback to .get() (get_audit_context does).
_audit_context_var: ContextVar[AuditContext] = ContextVar("admin_audit_context")


def set_audit_context(*, performed_by: str, performed_by_name: str) -> None:
    """Build an ``AuditContext`` for the given actor and store it in the context variable.

    The context's ``trace_id`` is left to its default factory, so this call raises
    ``RuntimeError`` when no valid span context is available.
    """
    new_context = AuditContext(
        performed_by=performed_by,
        performed_by_name=performed_by_name,
    )
    _audit_context_var.set(new_context)


def update_audit_context(audit_context: AuditContext) -> None:
    """Store an already-built ``AuditContext`` in the context variable, replacing any current one."""
    _audit_context_var.set(audit_context)


def get_audit_context() -> AuditContext:
    """Return the audit context stored for the current execution context.

    Raises:
        RuntimeError: if no audit context has been set.
    """
    current = _audit_context_var.get(None)
    if not current:
        raise RuntimeError("Audit context not set")
    return current
66 changes: 66 additions & 0 deletions statgpt/admin/audit/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import logging
from collections.abc import Awaitable, Callable
from functools import wraps

from sqlalchemy.ext.asyncio import AsyncSession

import statgpt.common.models as models
from statgpt.admin.audit.context import get_audit_context
from statgpt.common.schemas.auditable import Auditable
from statgpt.common.schemas.enums import AuditActionType, AuditEntityType

_log = logging.getLogger(__name__)


async def _persist_audit_log(
    *,
    session: AsyncSession,
    entity_type: AuditEntityType,
    action_type: AuditActionType,
    data: Auditable,
) -> None:
    """Write one ``AuditLog`` row describing ``data`` and commit the session.

    The actor and trace ID come from the current audit context. For DELETE actions no
    ``state_after`` snapshot is stored; for every other action it is taken from
    ``data.get_state_after()``.

    Raises:
        RuntimeError: if no audit context is set (raised by ``get_audit_context``).
    """
    audit_ctx = get_audit_context()

    if action_type is AuditActionType.DELETE:
        snapshot = None
    else:
        snapshot = data.get_state_after()

    record = models.AuditLog(
        entity_type=entity_type,
        action_type=action_type,
        item_id=data.get_item_id(),
        entity_id=data.get_entity_id(),
        entity_name=data.get_entity_name(),
        performed_by=audit_ctx.performed_by,
        performed_by_name=audit_ctx.performed_by_name,
        state_after=snapshot,
        trace_id=audit_ctx.trace_id,
    )
    session.add(record)
    # Commits the whole session, not just the audit row.
    await session.commit()


async def _rollback_if_failed(session: AsyncSession | None) -> None:
    """Roll ``session`` back only if it is stuck in a failed ("partial rollback") state.

    A flush that fails inside ``commit()`` leaves the session inactive until
    ``rollback()`` is called. A session that is still active is left untouched so
    that no pending work of the audited action is discarded. Never raises.
    """
    if session is None:
        return
    try:
        if not session.is_active:
            await session.rollback()
    except Exception:
        _log.exception("Failed to roll back session after audit log failure")


def audit_action(
    *,
    entity_type: AuditEntityType,
    action_type: AuditActionType,
):
    """Decorate an async service method so that each successful call is audit-logged.

    The decorated method must be an instance method whose ``self`` exposes the
    ``AsyncSession`` as ``self._session`` and whose return value is an ``Auditable``.
    After the method returns, an audit row is written via ``_persist_audit_log``.

    Audit logging is best-effort: any exception raised while persisting the audit row
    is logged and swallowed, and the method's result is still returned. If that
    failure left the session in a failed state, the session is rolled back so it stays
    usable for the caller.

    Args:
        entity_type: Kind of entity the decorated method acts on.
        action_type: Kind of action the decorated method performs.

    Returns:
        A decorator for async methods returning ``Auditable``.
    """

    def decorator(func: Callable[..., Awaitable[Auditable]]) -> Callable[..., Awaitable[Auditable]]:
        @wraps(func)
        async def wrapped(self, *args, **kwargs) -> Auditable:
            # Exceptions from the audited action itself propagate; nothing is logged then.
            result: Auditable = await func(self, *args, **kwargs)
            try:
                await _persist_audit_log(
                    session=self._session,
                    entity_type=entity_type,
                    action_type=action_type,
                    data=result,
                )
            except Exception:
                _log.exception(
                    f"Failed to persist audit log for {entity_type} action={action_type}"
                )
                # getattr: a missing ``_session`` is one of the failures swallowed above,
                # so the recovery step must not raise AttributeError itself.
                await _rollback_if_failed(getattr(self, "_session", None))

            return result

        return wrapped

    return decorator
21 changes: 19 additions & 2 deletions statgpt/admin/auth/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,35 @@ class TokenPayload:
def __init__(self, payload: dict) -> None:
    """Wrap a token's claims dict; the dict is stored as-is, not copied."""
    self._payload = payload

def _extract_claim(self, claim_name: str, claim_keys: list[str]) -> str:
for key in claim_keys:
value = self._payload.get(key)
if value is not None and value != "":
return str(value)
raise InvalidRequestError(f"{claim_name} claim {claim_keys} not found in token")

@property
def raw(self) -> dict:
    """Return the underlying claims dict (the stored object itself, not a copy)."""
    return self._payload

@property
def username(self) -> str:
    """Return the username claim from the token as a string.

    The claim key is read from ``oidc_auth_settings.oidc_username_claim``.

    Raises:
        InvalidRequestError: if the claim is missing or falsy (e.g. empty string).
    """
    username = self._payload.get(oidc_auth_settings.oidc_username_claim, None)
    if not username:
        raise InvalidRequestError(
            f"Username claim {oidc_auth_settings.oidc_username_claim} not found in token"
        )
    # Claims are not guaranteed to be strings; normalise to match the return type.
    return str(username)

@property
def user_id(self) -> str:
    """Return the user ID taken from the first usable claim in the configured list.

    The candidate claim keys come from ``oidc_auth_settings.oidc_audit_user_id_claims``.
    ``_extract_claim`` raises ``InvalidRequestError`` when none of them is present.
    """
    return self._extract_claim("User ID", oidc_auth_settings.oidc_audit_user_id_claims)

@property
def performed_by_name(self) -> str:
    """Return the display name taken from the first usable claim in the configured list.

    The candidate claim keys come from
    ``oidc_auth_settings.oidc_audit_performed_by_name_claims``. ``_extract_claim``
    raises ``InvalidRequestError`` when none of them is present.
    """
    return self._extract_claim(
        "Performed by name", oidc_auth_settings.oidc_audit_performed_by_name_claims
    )


class Jwks:
Expand Down
21 changes: 15 additions & 6 deletions statgpt/admin/auth/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from fastapi.security import OAuth2PasswordBearer
from jwt import InvalidTokenError

from statgpt.admin.audit.context import set_audit_context
from statgpt.admin.auth.oidc import JwtTokenVerifier, TokenValidationError, TokenValidator
from statgpt.admin.settings.oidc_auth import oidc_auth_settings
from statgpt.common.config import logger
Expand All @@ -13,13 +14,14 @@
)


@dataclass(frozen=True)
class User:
    """Immutable identity of the caller resolved from the bearer token.

    When OIDC authentication is disabled, all three fields are ``"Anonymous"``.
    """

    # Stable user identifier; passed to the audit context as ``performed_by``.
    id: str
    # Value of the configured username claim.
    username: str
    # Display name; passed to the audit context as ``performed_by_name``.
    name: str


async def require_jwt_auth(token: str = Depends(oauth2_scheme)) -> User:

async def _require_jwt_auth(token: str = Depends(oauth2_scheme)) -> User:
if oidc_auth_settings.oidc_auth_enabled:
try:
payload = JwtTokenVerifier.create().verify(token)
Expand All @@ -29,13 +31,20 @@ async def require_jwt_auth(token: str = Depends(oauth2_scheme)) -> User:
logger.info(f"Unauthorized token: {str(e)}")
raise HTTPException(status_code=403, detail=str(e))

return User(payload.username)
return User(
id=payload.user_id, username=payload.username, name=payload.performed_by_name
)
except InvalidTokenError as e:
logger.info(f"Invalid Bearer token: {str(e)}")
logger.warning(f"Invalid Bearer token: {str(e)}")
raise HTTPException(
status_code=401,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
else:
return User("Anonymous")
return User(id="Anonymous", username="Anonymous", name="Anonymous")


async def require_jwt_auth(user: User = Depends(_require_jwt_auth)) -> User:
    """Resolve the authenticated user and record it as the current audit context.

    Token handling is delegated to the ``_require_jwt_auth`` dependency; this wrapper
    stores the user's ``id`` and ``name`` via ``set_audit_context`` and returns the
    same ``User`` unchanged.
    """
    # NOTE(review): set_audit_context builds an AuditContext whose default trace_id
    # lookup raises RuntimeError without a valid span — confirm tracing is always
    # active for routes that use this dependency.
    set_audit_context(performed_by=user.id, performed_by_name=user.name)
    return user
10 changes: 9 additions & 1 deletion statgpt/admin/routers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from fastapi import APIRouter

from .audit_log import router as audit_log_router
from .channel import router as channel_router
from .data_source import router as data_source_router
from .dataset import router as dataset_router
Expand All @@ -10,5 +11,12 @@

# The channel terms routes are included into the channel router itself.
channel_router.include_router(channel_terms_router)

# Register each sub-router on the package-level router exactly once.
for r in (
    channel_router,
    data_source_router,
    dataset_router,
    terms_router,
    audit_log_router,
    health_check,
):
    router.include_router(r)
Loading
Loading