Skip to content

Commit 4fce9c2

Browse files
committed
fix(rbac): review and linting
1 parent 2d47913 commit 4fce9c2

File tree

6 files changed

+136
-24
lines changed

6 files changed

+136
-24
lines changed

alembic/versions/4bb7e59026f3_drop_role_from_membership_and_.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,55 @@
2222

2323
def upgrade() -> None:
2424
# ### commands auto generated by Alembic - please adjust! ###
25-
op.add_column("invitation", sa.Column("role_id", sa.UUID(), nullable=False))
25+
#
26+
# This migration converts enum-based roles to the new Role table.
27+
# Existing invitations are migrated by looking up the corresponding role
28+
# by slug in the organization's role table.
29+
#
30+
# Enum to slug mapping:
31+
# - WorkspaceRole: VIEWER -> viewer, EDITOR -> editor, ADMIN -> admin
32+
# - OrgRole: MEMBER -> member, ADMIN -> admin, OWNER -> owner
33+
34+
# Step 1: Add role_id columns as nullable
35+
op.add_column("invitation", sa.Column("role_id", sa.UUID(), nullable=True))
36+
op.add_column(
37+
"organization_invitation", sa.Column("role_id", sa.UUID(), nullable=True)
38+
)
39+
40+
# Step 2: Populate role_id for existing workspace invitations
41+
# Map old enum values to role slugs and join through workspace to get org
42+
op.execute(
43+
"""
44+
UPDATE invitation i
45+
SET role_id = r.id
46+
FROM workspace w, role r
47+
WHERE i.workspace_id = w.id
48+
AND r.organization_id = w.organization_id
49+
AND r.slug = LOWER(i.role::text)
50+
"""
51+
)
52+
53+
# Step 3: Populate role_id for existing organization invitations
54+
op.execute(
55+
"""
56+
UPDATE organization_invitation oi
57+
SET role_id = r.id
58+
FROM role r
59+
WHERE r.organization_id = oi.organization_id
60+
AND r.slug = LOWER(oi.role::text)
61+
"""
62+
)
63+
64+
# Step 4: Delete any invitations that couldn't be migrated (no matching role)
65+
# This handles edge cases where roles weren't seeded yet
66+
op.execute("DELETE FROM invitation WHERE role_id IS NULL")
67+
op.execute("DELETE FROM organization_invitation WHERE role_id IS NULL")
68+
69+
# Step 5: Alter columns to NOT NULL
70+
op.alter_column("invitation", "role_id", nullable=False)
71+
op.alter_column("organization_invitation", "role_id", nullable=False)
72+
73+
# Step 6: Create indexes and foreign keys
2674
op.create_index(
2775
op.f("ix_invitation_role_id"), "invitation", ["role_id"], unique=False
2876
)
@@ -34,11 +82,6 @@ def upgrade() -> None:
3482
["id"],
3583
ondelete="RESTRICT",
3684
)
37-
op.drop_column("invitation", "role")
38-
op.drop_column("membership", "role")
39-
op.add_column(
40-
"organization_invitation", sa.Column("role_id", sa.UUID(), nullable=False)
41-
)
4285
op.create_index(
4386
op.f("ix_organization_invitation_role_id"),
4487
"organization_invitation",
@@ -53,8 +96,14 @@ def upgrade() -> None:
5396
["id"],
5497
ondelete="RESTRICT",
5598
)
99+
100+
# Step 7: Drop old enum-based role columns
101+
op.drop_column("invitation", "role")
102+
op.drop_column("membership", "role")
56103
op.drop_column("organization_invitation", "role")
57104
op.drop_column("organization_membership", "role")
105+
106+
# Step 8: Drop the old enum types
58107
sa.Enum("MEMBER", "ADMIN", "OWNER", name="orgrole").drop(op.get_bind())
59108
sa.Enum("VIEWER", "EDITOR", "ADMIN", name="workspacerole").drop(op.get_bind())
60109
# ### end Alembic commands ###

tracecat/authz/seeding.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
ORG_OWNER_SCOPES,
2626
VIEWER_SCOPES,
2727
)
28-
from tracecat.db.models import Role, RoleScope, Scope
28+
from tracecat.db.models import Organization, Role, RoleScope, Scope
2929
from tracecat.logger import logger
3030

3131
if TYPE_CHECKING:
@@ -360,7 +360,9 @@ async def seed_system_roles_for_org(
360360
# Get all system scope names -> IDs
361361
scope_stmt = select(Scope.id, Scope.name).where(Scope.organization_id.is_(None))
362362
scope_result = await session.execute(scope_stmt)
363-
scope_name_to_id: dict[str, UUID] = {name: id_ for id_, name in scope_result.all()}
363+
scope_name_to_id: dict[str, UUID] = {
364+
name: id_ for id_, name in scope_result.tuples().all()
365+
}
364366

365367
roles_created = 0
366368

@@ -445,8 +447,6 @@ async def seed_system_roles_for_all_orgs(session: AsyncSession) -> dict[UUID, in
445447
Returns:
446448
Dict mapping organization_id to number of roles created
447449
"""
448-
from tracecat.db.models import Organization
449-
450450
logger.info("Seeding system roles for all organizations")
451451

452452
# Get all organizations

tracecat/authz/service.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Sequence
44
from dataclasses import dataclass
55

6-
from sqlalchemy import select
6+
from sqlalchemy import or_, select
77
from sqlalchemy.ext.asyncio import AsyncSession
88

99
from tracecat.auth.types import Role
@@ -83,17 +83,37 @@ async def list_workspace_members(
8383
return []
8484

8585
# Get role assignments for these users in this workspace
86+
# Include both workspace-specific assignments and org-wide assignments (workspace_id IS NULL)
8687
user_ids = [u.id for u in users]
8788
role_stmt = (
88-
select(UserRoleAssignment.user_id, RoleModel.slug)
89+
select(
90+
UserRoleAssignment.user_id,
91+
UserRoleAssignment.workspace_id,
92+
RoleModel.slug,
93+
)
8994
.join(RoleModel, UserRoleAssignment.role_id == RoleModel.id)
9095
.where(
9196
UserRoleAssignment.user_id.in_(user_ids),
92-
UserRoleAssignment.workspace_id == workspace_id,
97+
or_(
98+
UserRoleAssignment.workspace_id == workspace_id,
99+
UserRoleAssignment.workspace_id.is_(None),
100+
),
93101
)
94102
)
95103
role_result = await self.session.execute(role_stmt)
96-
user_role_map = {row[0]: row[1] for row in role_result.all()}
104+
105+
# Build map preferring workspace-specific assignments over org-wide
106+
user_role_map: dict[UserID, str] = {}
107+
for uid, ws_id, slug in role_result.tuples().all():
108+
if slug is None:
109+
# Skip assignments without a slug (custom roles)
110+
continue
111+
if ws_id is not None:
112+
# Workspace-specific assignment takes precedence
113+
user_role_map[uid] = slug
114+
elif uid not in user_role_map:
115+
# Org-wide assignment as fallback
116+
user_role_map[uid] = slug
97117

98118
return [
99119
WorkspaceMember(

tracecat/organization/schemas.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from datetime import datetime
2+
from typing import Self
23
from uuid import UUID
34

4-
from pydantic import BaseModel, EmailStr
5+
from pydantic import BaseModel, EmailStr, model_validator
56

67
from tracecat.identifiers import OrganizationID, UserID
78
from tracecat.invitations.enums import InvitationStatus
@@ -48,6 +49,13 @@ class OrgInvitationCreate(BaseModel):
4849
role_slug: str | None = None
4950
"""Slug of the role to grant (e.g., 'admin', 'member', 'owner')."""
5051

52+
@model_validator(mode="after")
53+
def validate_role_specified(self) -> Self:
54+
"""Ensure at least one of role_id or role_slug is provided."""
55+
if self.role_id is None and self.role_slug is None:
56+
raise ValueError("Either role_id or role_slug must be provided")
57+
return self
58+
5159

5260
class OrgInvitationRead(BaseModel):
5361
"""Response model for organization invitation."""

tracecat/registry/sync/jobs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import tracecat_registry
1111
from packaging.version import Version
12+
from sqlalchemy.exc import DBAPIError
1213
from sqlalchemy.ext.asyncio import AsyncSession
1314

1415
from tracecat.authz.seeding import seed_registry_scopes_bulk
@@ -196,7 +197,7 @@ async def _seed_registry_scopes(
196197
inserted = await seed_registry_scopes_bulk(session, action_keys)
197198
await session.commit()
198199
logger.info("Registry scopes seeded", inserted=inserted, total=len(action_keys))
199-
except Exception as e:
200+
except DBAPIError as e:
200201
logger.warning("Failed to seed registry scopes", error=str(e))
201-
# Don't fail the sync if scope seeding fails
202+
# Don't fail the sync if scope seeding fails due to DB errors
202203
await session.rollback()

tracecat/workspaces/router.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
HTTPException,
77
status,
88
)
9-
from sqlalchemy import select
9+
from sqlalchemy import or_, select
1010
from sqlalchemy.exc import IntegrityError, NoResultFound
1111

1212
from tracecat.auth.credentials import RoleACL
@@ -253,17 +253,37 @@ async def list_workspace_memberships(
253253
return []
254254

255255
# Look up roles from RBAC tables
256+
# Include both workspace-specific assignments and org-wide assignments (workspace_id IS NULL)
256257
user_ids = [m.user_id for m in memberships]
257258
role_stmt = (
258-
select(UserRoleAssignment.user_id, RoleModel.slug)
259+
select(
260+
UserRoleAssignment.user_id,
261+
UserRoleAssignment.workspace_id,
262+
RoleModel.slug,
263+
)
259264
.join(RoleModel, UserRoleAssignment.role_id == RoleModel.id)
260265
.where(
261266
UserRoleAssignment.user_id.in_(user_ids),
262-
UserRoleAssignment.workspace_id == workspace_id,
267+
or_(
268+
UserRoleAssignment.workspace_id == workspace_id,
269+
UserRoleAssignment.workspace_id.is_(None),
270+
),
263271
)
264272
)
265273
role_result = await session.execute(role_stmt)
266-
user_role_map = {row[0]: row[1] for row in role_result.all()}
274+
275+
# Build map preferring workspace-specific assignments over org-wide
276+
user_role_map: dict[UserID, str] = {}
277+
for uid, ws_id, slug in role_result.tuples().all():
278+
if slug is None:
279+
# Skip assignments without a slug (custom roles)
280+
continue
281+
if ws_id is not None:
282+
# Workspace-specific assignment takes precedence
283+
user_role_map[uid] = slug
284+
elif uid not in user_role_map:
285+
# Org-wide assignment as fallback
286+
user_role_map[uid] = slug
267287

268288
return [
269289
WorkspaceMembershipRead(
@@ -359,16 +379,30 @@ async def get_workspace_membership(
359379
membership = membership_with_org.membership
360380

361381
# Look up role from RBAC tables
382+
# Include both workspace-specific assignments and org-wide assignments (workspace_id IS NULL)
362383
role_stmt = (
363-
select(RoleModel.slug)
384+
select(UserRoleAssignment.workspace_id, RoleModel.slug)
364385
.join(UserRoleAssignment, UserRoleAssignment.role_id == RoleModel.id)
365386
.where(
366387
UserRoleAssignment.user_id == user_id,
367-
UserRoleAssignment.workspace_id == workspace_id,
388+
or_(
389+
UserRoleAssignment.workspace_id == workspace_id,
390+
UserRoleAssignment.workspace_id.is_(None),
391+
),
368392
)
369393
)
370394
role_result = await session.execute(role_stmt)
371-
slug = role_result.scalar_one_or_none()
395+
rows = role_result.tuples().all()
396+
397+
# Prefer workspace-specific assignment over org-wide
398+
slug: str | None = None
399+
for ws_id, role_slug in rows:
400+
if ws_id is not None:
401+
# Workspace-specific assignment takes precedence
402+
slug = role_slug
403+
break
404+
# Org-wide assignment as fallback
405+
slug = role_slug
372406
workspace_role = SLUG_TO_WORKSPACE_ROLE.get(slug or "", WorkspaceRole.VIEWER)
373407

374408
return WorkspaceMembershipRead(

0 commit comments

Comments
 (0)