Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 205 additions & 0 deletions tests/unit/test_authz_seeding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
"""Tests for RBAC scope seeding."""

from uuid import uuid4

import pytest
from sqlalchemy import func, select

from tracecat.authz.enums import ScopeSource
from tracecat.authz.seeding import (
PRESET_ROLE_DEFINITIONS,
SYSTEM_SCOPE_DEFINITIONS,
seed_registry_scopes,
seed_system_roles_for_all_orgs,
seed_system_scopes,
)
from tracecat.db.models import Organization, Role, RoleScope, Scope


@pytest.mark.anyio
async def test_seed_system_scopes(session):
"""Test that system scopes are seeded correctly."""
# Seed system scopes
inserted_count = await seed_system_scopes(session)

# Should have inserted all scopes on first run
assert inserted_count == len(SYSTEM_SCOPE_DEFINITIONS)

# Verify scopes exist in database
result = await session.execute(
select(Scope).where(
Scope.source == ScopeSource.PLATFORM,
Scope.organization_id.is_(None),
)
)
scopes = result.scalars().all()
assert len(scopes) == len(SYSTEM_SCOPE_DEFINITIONS)

# Verify each scope has correct attributes
scope_names = {s.name for s in scopes}
expected_names = {name for name, _, _, _ in SYSTEM_SCOPE_DEFINITIONS}
assert scope_names == expected_names


@pytest.mark.anyio
async def test_seed_system_scopes_idempotent(session):
"""Test that seeding is idempotent - running twice doesn't duplicate."""
# Seed once
first_count = await seed_system_scopes(session)
assert first_count == len(SYSTEM_SCOPE_DEFINITIONS)

# Seed again
second_count = await seed_system_scopes(session)
assert second_count == 0 # No new scopes inserted

# Verify count is still the same
result = await session.execute(
select(Scope).where(
Scope.source == ScopeSource.PLATFORM,
Scope.organization_id.is_(None),
)
)
scopes = result.scalars().all()
assert len(scopes) == len(SYSTEM_SCOPE_DEFINITIONS)


@pytest.mark.anyio
async def test_seed_registry_scopes(session):
"""Test bulk seeding of registry scopes."""
action_keys = [
"tools.okta.list_users",
"tools.okta.create_user",
"tools.zendesk.create_ticket",
"core.http_request",
]

org_count_result = await session.execute(select(func.count(Organization.id)))
org_count = org_count_result.scalar_one()

inserted_count = await seed_registry_scopes(session, action_keys)
await session.commit()

assert inserted_count == len(action_keys) * (1 + org_count)

# Verify scopes exist
result = await session.execute(
select(Scope).where(
Scope.source == ScopeSource.PLATFORM,
Scope.organization_id.is_(None),
)
)
scopes = result.scalars().all()
assert len(scopes) >= len(action_keys)

scope_names = {s.name for s in scopes}
for key in action_keys:
assert f"action:{key}:execute" in scope_names

custom_scope_result = await session.execute(
select(func.count(Scope.id)).where(Scope.source == ScopeSource.CUSTOM)
)
assert custom_scope_result.scalar_one() == len(action_keys) * org_count


@pytest.mark.anyio
async def test_seed_registry_scopes_idempotent(session):
"""Test that bulk seeding is idempotent."""
action_keys = ["tools.test.action1", "tools.test.action2"]

org_count_result = await session.execute(select(func.count(Organization.id)))
org_count = org_count_result.scalar_one()

# First seed
first_count = await seed_registry_scopes(session, action_keys)
await session.commit()
assert first_count == len(action_keys) * (1 + org_count)

# Second seed
second_count = await seed_registry_scopes(session, action_keys)
await session.commit()
assert second_count == 0


@pytest.mark.anyio
async def test_seed_registry_scopes_empty(session):
"""Test bulk seeding with empty list."""
inserted_count = await seed_registry_scopes(session, [])
assert inserted_count == 0


@pytest.mark.anyio
async def test_seed_system_roles_for_all_orgs_creates_roles_and_links(session):
"""Seed preset roles for all orgs and link expected system scopes."""
# Ensure scope IDs exist for role->scope links.
await seed_system_scopes(session)

# Add an extra org so the function processes multiple orgs in one call.
extra_org = Organization(
id=uuid4(),
name="Extra test org",
slug=f"extra-test-org-{uuid4().hex[:8]}",
is_active=True,
)
session.add(extra_org)
await session.flush()

# Capture target org IDs in this isolated session.
org_result = await session.execute(select(Organization.id))
org_ids = {org_id for (org_id,) in org_result.tuples().all()}

created_by_org = await seed_system_roles_for_all_orgs(session)

for org_id in org_ids:
assert created_by_org[org_id] == len(PRESET_ROLE_DEFINITIONS)

roles_result = await session.execute(
select(Role.id, Role.organization_id, Role.slug).where(
Role.organization_id.in_(org_ids),
Role.slug.in_(PRESET_ROLE_DEFINITIONS),
)
)
roles = roles_result.tuples().all()
assert len(roles) == len(org_ids) * len(PRESET_ROLE_DEFINITIONS)

role_scope_count_stmt = (
select(Role.slug, func.count(RoleScope.scope_id))
.select_from(Role)
.join(RoleScope, RoleScope.role_id == Role.id)
.where(
Role.organization_id.in_(org_ids),
Role.slug.in_(PRESET_ROLE_DEFINITIONS),
)
.group_by(Role.slug)
)
role_scope_count_result = await session.execute(role_scope_count_stmt)
role_scope_counts = dict(role_scope_count_result.tuples().all())
expected_org_count = len(org_ids)
for slug, role_def in PRESET_ROLE_DEFINITIONS.items():
assert role_scope_counts[slug] == len(role_def.scopes) * expected_org_count


@pytest.mark.anyio
async def test_seed_system_roles_for_all_orgs_idempotent(session):
"""Running role seeding twice should not create duplicate roles."""
await seed_system_scopes(session)

first = await seed_system_roles_for_all_orgs(session)
second = await seed_system_roles_for_all_orgs(session)

assert sum(first.values()) > 0
assert all(created == 0 for created in second.values())


@pytest.mark.anyio
async def test_system_scope_definitions_format(session):
"""Test that all system scope definitions follow the expected format."""
for name, resource, action, description in SYSTEM_SCOPE_DEFINITIONS:
# Name should contain resource and action
assert ":" in name, f"Scope name should contain colon: {name}"

# Description should be non-empty
assert description, f"Scope {name} should have a description"

# Resource and action should be non-empty
assert resource, f"Scope {name} should have a resource"
assert action, f"Scope {name} should have an action"
11 changes: 4 additions & 7 deletions tests/unit/test_rbac_scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
ORG_ADMIN_SCOPES,
ORG_MEMBER_SCOPES,
ORG_OWNER_SCOPES,
ORG_ROLE_SCOPES,
PRESET_ROLE_SCOPES,
VIEWER_SCOPES,
)
Expand Down Expand Up @@ -200,10 +199,13 @@ def test_editor_includes_viewer(self):
def test_admin_includes_editor(self):
assert EDITOR_SCOPES.issubset(ADMIN_SCOPES)

def test_system_role_mapping(self):
def test_preset_role_mapping(self):
assert PRESET_ROLE_SCOPES["workspace-viewer"] == VIEWER_SCOPES
assert PRESET_ROLE_SCOPES["workspace-editor"] == EDITOR_SCOPES
assert PRESET_ROLE_SCOPES["workspace-admin"] == ADMIN_SCOPES
assert PRESET_ROLE_SCOPES["organization-owner"] == ORG_OWNER_SCOPES
assert PRESET_ROLE_SCOPES["organization-admin"] == ORG_ADMIN_SCOPES
assert PRESET_ROLE_SCOPES["organization-member"] == ORG_MEMBER_SCOPES


class TestOrgRoleScopes:
Expand All @@ -224,11 +226,6 @@ def test_admin_has_billing_read(self):
def test_member_has_minimal_scopes(self):
assert ORG_MEMBER_SCOPES == frozenset({"org:read", "org:member:read"})

def test_org_role_mapping(self):
assert ORG_ROLE_SCOPES["organization-owner"] == ORG_OWNER_SCOPES
assert ORG_ROLE_SCOPES["organization-admin"] == ORG_ADMIN_SCOPES
assert ORG_ROLE_SCOPES["organization-member"] == ORG_MEMBER_SCOPES


class TestRequireScopeDecorator:
"""Tests for the @require_scope decorator."""
Expand Down
15 changes: 15 additions & 0 deletions tracecat/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
auth_backend,
fastapi_users,
)
from tracecat.authz.seeding import seed_all_system_data
from tracecat.cases.attachments.internal_router import (
router as internal_case_attachments_router,
)
Expand All @@ -74,6 +75,7 @@
from tracecat.cases.triggers.consumer import start_case_trigger_consumer
from tracecat.contexts import ctx_role
from tracecat.db.dependencies import AsyncDBSession
from tracecat.db.engine import get_async_session_context_manager
from tracecat.editor.router import router as editor_router
from tracecat.exceptions import EntitlementRequired, ScopeDeniedError, TracecatException
from tracecat.feature_flags import (
Expand Down Expand Up @@ -154,6 +156,9 @@ async def lifespan(app: FastAPI):

await ensure_default_organization()

async with get_async_session_context_manager() as session:
await setup_rbac_defaults(session)

# Spawn platform registry sync as background task (non-blocking)
# Uses leader election to prevent race conditions across multiple API processes
registry_sync_task = asyncio.create_task(
Expand Down Expand Up @@ -236,6 +241,16 @@ async def setup_workspace_defaults(session: AsyncSession, admin_role: Role):
logger.info("Default workspace already exists, skipping")


async def setup_rbac_defaults(session: AsyncSession):
"""Seed system scopes and roles for RBAC."""
try:
result = await seed_all_system_data(session)
logger.info("RBAC defaults seeded", **result)
except Exception as e:
logger.warning("Failed to seed RBAC defaults", error=str(e))
# Don't fail startup if seeding fails - RBAC tables may not exist yet


# Catch-all exception handler to prevent stack traces from leaking
def validation_exception_handler(request: Request, exc: Exception) -> Response:
"""Improves visiblity of 422 errors."""
Expand Down
Loading