Skip to content

Commit f0f545f

Browse files
committed
fix(invitations): create user role assignment on invite
1 parent 1d4467b commit f0f545f

File tree

5 files changed

+187
-68
lines changed

5 files changed

+187
-68
lines changed

alembic/versions/4bb7e59026f3_drop_role_from_membership_and_.py

Lines changed: 86 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -111,18 +111,31 @@ def upgrade() -> None:
111111

112112
def downgrade() -> None:
113113
# ### commands auto generated by Alembic - please adjust! ###
114+
#
115+
# This downgrade restores the original enum-based role columns by:
116+
# 1. Creating the enum types
117+
# 2. Adding role columns as nullable
118+
# 3. Populating role values from role_id via the role table
119+
# 4. Making the columns NOT NULL
120+
# 5. Dropping the role_id columns and constraints
121+
#
122+
# Note: The original schema did NOT have server_default on invitation/org_invitation
123+
# role columns, but membership/org_membership did have defaults.
124+
125+
# Step 1: Create enum types
114126
sa.Enum("VIEWER", "EDITOR", "ADMIN", name="workspacerole").create(op.get_bind())
115127
sa.Enum("MEMBER", "ADMIN", "OWNER", name="orgrole").create(op.get_bind())
128+
129+
# Step 2: Add role columns as nullable
116130
op.add_column(
117131
"organization_membership",
118132
sa.Column(
119133
"role",
120134
postgresql.ENUM(
121135
"MEMBER", "ADMIN", "OWNER", name="orgrole", create_type=False
122136
),
123-
server_default=sa.text("'MEMBER'::orgrole"),
124137
autoincrement=False,
125-
nullable=False,
138+
nullable=True,
126139
),
127140
)
128141
op.add_column(
@@ -133,28 +146,18 @@ def downgrade() -> None:
133146
"MEMBER", "ADMIN", "OWNER", name="orgrole", create_type=False
134147
),
135148
autoincrement=False,
136-
nullable=False,
149+
nullable=True,
137150
),
138151
)
139-
op.drop_constraint(
140-
op.f("fk_organization_invitation_role_id_role"),
141-
"organization_invitation",
142-
type_="foreignkey",
143-
)
144-
op.drop_index(
145-
op.f("ix_organization_invitation_role_id"), table_name="organization_invitation"
146-
)
147-
op.drop_column("organization_invitation", "role_id")
148152
op.add_column(
149153
"membership",
150154
sa.Column(
151155
"role",
152156
postgresql.ENUM(
153157
"VIEWER", "EDITOR", "ADMIN", name="workspacerole", create_type=False
154158
),
155-
server_default=sa.text("'EDITOR'::workspacerole"),
156159
autoincrement=False,
157-
nullable=False,
160+
nullable=True,
158161
),
159162
)
160163
op.add_column(
@@ -165,9 +168,77 @@ def downgrade() -> None:
165168
"VIEWER", "EDITOR", "ADMIN", name="workspacerole", create_type=False
166169
),
167170
autoincrement=False,
168-
nullable=False,
171+
nullable=True,
169172
),
170173
)
174+
175+
# Step 3: Populate role values from role_id via the role table
176+
# Organization membership: role_id -> role.slug -> uppercase enum
177+
op.execute(
178+
"""
179+
UPDATE organization_membership om
180+
SET role = UPPER(r.slug)::orgrole
181+
FROM role r
182+
WHERE om.role_id = r.id
183+
"""
184+
)
185+
186+
# Organization invitation: role_id -> role.slug -> uppercase enum
187+
op.execute(
188+
"""
189+
UPDATE organization_invitation oi
190+
SET role = UPPER(r.slug)::orgrole
191+
FROM role r
192+
WHERE oi.role_id = r.id
193+
"""
194+
)
195+
196+
# Workspace membership: role_id -> role.slug -> uppercase enum
197+
op.execute(
198+
"""
199+
UPDATE membership m
200+
SET role = UPPER(r.slug)::workspacerole
201+
FROM role r
202+
WHERE m.role_id = r.id
203+
"""
204+
)
205+
206+
# Workspace invitation: role_id -> role.slug -> uppercase enum
207+
op.execute(
208+
"""
209+
UPDATE invitation i
210+
SET role = UPPER(r.slug)::workspacerole
211+
FROM role r
212+
WHERE i.role_id = r.id
213+
"""
214+
)
215+
216+
# Step 4: Make columns NOT NULL and add server_default where original schema had it
217+
op.alter_column("organization_membership", "role", nullable=False)
218+
op.alter_column(
219+
"organization_membership",
220+
"role",
221+
server_default=sa.text("'MEMBER'::orgrole"),
222+
)
223+
op.alter_column("organization_invitation", "role", nullable=False)
224+
op.alter_column("membership", "role", nullable=False)
225+
op.alter_column(
226+
"membership",
227+
"role",
228+
server_default=sa.text("'EDITOR'::workspacerole"),
229+
)
230+
op.alter_column("invitation", "role", nullable=False)
231+
232+
# Step 5: Drop role_id columns and constraints
233+
op.drop_constraint(
234+
op.f("fk_organization_invitation_role_id_role"),
235+
"organization_invitation",
236+
type_="foreignkey",
237+
)
238+
op.drop_index(
239+
op.f("ix_organization_invitation_role_id"), table_name="organization_invitation"
240+
)
241+
op.drop_column("organization_invitation", "role_id")
171242
op.drop_constraint(
172243
op.f("fk_invitation_role_id_role"), "invitation", type_="foreignkey"
173244
)

tracecat/authz/seeding.py

Lines changed: 64 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -244,30 +244,22 @@ async def seed_registry_scope(
244244
245245
Creates a scope for `action:{action_key}:execute` if it doesn't exist.
246246
Registry scopes have organization_id=NULL and source='registry'.
247+
Uses upsert (ON CONFLICT DO NOTHING) for concurrency safety.
247248
248249
Args:
249250
session: Database session
250251
action_key: The action key (e.g., "tools.okta.list_users")
251252
description: Optional description for the scope
252253
253254
Returns:
254-
The created or existing Scope, or None if upsert had no effect
255+
The created or existing Scope
255256
"""
256257
scope_name = f"action:{action_key}:execute"
258+
scope_id = uuid4()
257259

258-
# Check if scope already exists
259-
stmt = select(Scope).where(
260-
Scope.name == scope_name, Scope.organization_id.is_(None)
261-
)
262-
result = await session.execute(stmt)
263-
existing = result.scalar_one_or_none()
264-
265-
if existing:
266-
return existing
267-
268-
# Create new scope
269-
scope = Scope(
270-
id=uuid4(),
260+
# Use upsert for concurrency safety
261+
stmt = pg_insert(Scope).values(
262+
id=scope_id,
271263
name=scope_name,
272264
resource="action",
273265
action="execute",
@@ -276,10 +268,24 @@ async def seed_registry_scope(
276268
source_ref=action_key,
277269
organization_id=None,
278270
)
279-
session.add(scope)
271+
stmt = stmt.on_conflict_do_nothing(
272+
index_elements=["name"], index_where=Scope.organization_id.is_(None)
273+
)
274+
result = await session.execute(stmt)
280275
await session.flush()
281276

282-
logger.debug("Registry scope created", scope_name=scope_name, action_key=action_key)
277+
# Re-query to get the scope (whether newly inserted or already existing)
278+
select_stmt = select(Scope).where(
279+
Scope.name == scope_name, Scope.organization_id.is_(None)
280+
)
281+
select_result = await session.execute(select_stmt)
282+
scope = select_result.scalar_one_or_none()
283+
284+
if result.rowcount and result.rowcount > 0: # pyright: ignore[reportAttributeAccessIssue]
285+
logger.debug(
286+
"Registry scope created", scope_name=scope_name, action_key=action_key
287+
)
288+
283289
return scope
284290

285291

@@ -343,8 +349,9 @@ async def seed_system_roles_for_org(
343349
) -> int:
344350
"""Seed system roles (Admin, Editor, Viewer) for an organization.
345351
346-
Creates the three system roles with their associated scopes if they don't exist.
352+
Creates the system roles with their associated scopes if they don't exist.
347353
System roles are identified by their well-known slugs.
354+
Uses upsert (ON CONFLICT DO NOTHING) for concurrency safety.
348355
349356
Args:
350357
session: Database session
@@ -366,36 +373,52 @@ async def seed_system_roles_for_org(
366373

367374
roles_created = 0
368375

369-
for slug, name, description, scope_names in PRESET_ROLE_DEFINITIONS:
370-
# Check if role already exists
371-
existing_stmt = select(Role.id).where(
372-
Role.organization_id == organization_id,
373-
Role.slug == slug,
376+
# Prepare role values with pre-generated IDs
377+
role_values = []
378+
role_id_by_slug: dict[str, UUID] = {}
379+
for slug, name, description, _ in PRESET_ROLE_DEFINITIONS:
380+
role_id = uuid4()
381+
role_id_by_slug[slug] = role_id
382+
role_values.append(
383+
{
384+
"id": role_id,
385+
"name": name,
386+
"slug": slug,
387+
"description": description,
388+
"organization_id": organization_id,
389+
"created_by": None, # System-created
390+
}
374391
)
375-
existing_result = await session.execute(existing_stmt)
376-
existing_role_id = existing_result.scalar_one_or_none()
377392

378-
if existing_role_id is not None:
379-
logger.debug(
380-
"System role already exists",
393+
# Bulk upsert roles - concurrency safe
394+
role_stmt = pg_insert(Role).values(role_values)
395+
role_stmt = role_stmt.on_conflict_do_nothing(
396+
index_elements=["organization_id", "slug"]
397+
)
398+
result = await session.execute(role_stmt)
399+
roles_created = result.rowcount if result.rowcount else 0 # pyright: ignore[reportAttributeAccessIssue]
400+
401+
# Re-query to get actual role IDs (may differ if roles already existed)
402+
existing_roles_stmt = select(Role.id, Role.slug).where(
403+
Role.organization_id == organization_id,
404+
Role.slug.in_([slug for slug, _, _, _ in PRESET_ROLE_DEFINITIONS]),
405+
)
406+
existing_roles_result = await session.execute(existing_roles_stmt)
407+
actual_role_id_by_slug: dict[str | None, UUID] = {
408+
slug: role_id for role_id, slug in existing_roles_result.tuples().all()
409+
}
410+
411+
# Link scopes to roles
412+
for slug, _, _, scope_names in PRESET_ROLE_DEFINITIONS:
413+
role_id = actual_role_id_by_slug.get(slug)
414+
if role_id is None:
415+
logger.warning(
416+
"Role not found after upsert",
381417
slug=slug,
382418
organization_id=organization_id,
383419
)
384420
continue
385421

386-
# Create the role
387-
role = Role(
388-
id=uuid4(),
389-
name=name,
390-
slug=slug,
391-
description=description,
392-
organization_id=organization_id,
393-
created_by=None, # System-created
394-
)
395-
session.add(role)
396-
await session.flush() # Get the role ID
397-
398-
# Link scopes to the role
399422
role_scope_values = []
400423
for scope_name in scope_names:
401424
scope_id = scope_name_to_id.get(scope_name)
@@ -406,21 +429,13 @@ async def seed_system_roles_for_org(
406429
role_slug=slug,
407430
)
408431
continue
409-
role_scope_values.append({"role_id": role.id, "scope_id": scope_id})
432+
role_scope_values.append({"role_id": role_id, "scope_id": scope_id})
410433

411434
if role_scope_values:
412435
role_scope_stmt = pg_insert(RoleScope).values(role_scope_values)
413436
role_scope_stmt = role_scope_stmt.on_conflict_do_nothing()
414437
await session.execute(role_scope_stmt)
415438

416-
roles_created += 1
417-
logger.debug(
418-
"System role created",
419-
slug=slug,
420-
organization_id=organization_id,
421-
num_scopes=len(role_scope_values),
422-
)
423-
424439
await session.commit()
425440
logger.info(
426441
"System roles seeded for organization",

tracecat/authz/service.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ async def list_workspace_members(
7070
7171
Roles are looked up from UserRoleAssignment table.
7272
"""
73+
# Get workspace to determine organization_id
74+
workspace = await self.session.get(Workspace, workspace_id)
75+
if workspace is None:
76+
return []
77+
organization_id = workspace.organization_id
78+
7379
# Get all members of the workspace
7480
members_stmt = (
7581
select(User)
@@ -84,6 +90,7 @@ async def list_workspace_members(
8490

8591
# Get role assignments for these users in this workspace
8692
# Include both workspace-specific assignments and org-wide assignments (workspace_id IS NULL)
93+
# Filter by organization_id to ensure we only get roles from this org
8794
user_ids = [u.id for u in users]
8895
role_stmt = (
8996
select(
@@ -94,6 +101,7 @@ async def list_workspace_members(
94101
.join(RoleModel, UserRoleAssignment.role_id == RoleModel.id)
95102
.where(
96103
UserRoleAssignment.user_id.in_(user_ids),
104+
UserRoleAssignment.organization_id == organization_id,
97105
or_(
98106
UserRoleAssignment.workspace_id == workspace_id,
99107
UserRoleAssignment.workspace_id.is_(None),

tracecat/organization/service.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,22 @@ async def accept_invitation_for_user(
142142
# Shouldn't reach here, but handle gracefully
143143
raise TracecatAuthorizationError("Invitation is no longer valid")
144144

145-
# Create membership
145+
# Create membership (role is now tracked via UserRoleAssignment)
146146
membership = OrganizationMembership(
147147
user_id=user_id,
148148
organization_id=invitation.organization_id,
149-
role=invitation.role,
150149
)
151150
session.add(membership)
152151

152+
# Create role assignment from invitation
153+
role_assignment = UserRoleAssignment(
154+
organization_id=invitation.organization_id,
155+
user_id=user_id,
156+
workspace_id=None, # NULL = org-level assignment
157+
role_id=invitation.role_id,
158+
)
159+
session.add(role_assignment)
160+
153161
await session.commit()
154162
await session.refresh(membership)
155163
except TracecatAuthorizationError:
@@ -176,6 +184,8 @@ async def accept_invitation_for_user(
176184
)
177185

178186
return membership
187+
188+
179189
@dataclass(frozen=True)
180190
class MemberRoleInfo:
181191
"""Role information for an organization member."""

0 commit comments

Comments
 (0)