Skip to content

Commit 7d34d43

Browse files
committed
fix(invitations): create user role assignment on invite
1 parent c25def1 commit 7d34d43

File tree

6 files changed

+161
-64
lines changed

6 files changed

+161
-64
lines changed

tests/unit/test_authz_seeding.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ async def test_seed_system_scopes(session):
2525
# Verify scopes exist in database
2626
result = await session.execute(
2727
select(Scope).where(
28-
Scope.source == ScopeSource.SYSTEM,
28+
Scope.source == ScopeSource.PLATFORM,
2929
Scope.organization_id.is_(None),
3030
)
3131
)
@@ -52,7 +52,7 @@ async def test_seed_system_scopes_idempotent(session):
5252
# Verify count is still the same
5353
result = await session.execute(
5454
select(Scope).where(
55-
Scope.source == ScopeSource.SYSTEM,
55+
Scope.source == ScopeSource.PLATFORM,
5656
Scope.organization_id.is_(None),
5757
)
5858
)
@@ -72,7 +72,7 @@ async def test_seed_registry_scope(session):
7272
assert scope.name == f"action:{action_key}:execute"
7373
assert scope.resource == "action"
7474
assert scope.action == "execute"
75-
assert scope.source == ScopeSource.REGISTRY
75+
assert scope.source == ScopeSource.PLATFORM
7676
assert scope.source_ref == action_key
7777
assert scope.organization_id is None
7878

@@ -112,7 +112,7 @@ async def test_seed_registry_scopes_bulk(session):
112112
# Verify scopes exist
113113
result = await session.execute(
114114
select(Scope).where(
115-
Scope.source == ScopeSource.REGISTRY,
115+
Scope.source == ScopeSource.PLATFORM,
116116
Scope.organization_id.is_(None),
117117
)
118118
)

tracecat/authz/controls.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import asyncio
22
import functools
33
import re
4+
import warnings
45
from collections.abc import Callable, Coroutine
56
from fnmatch import fnmatch
67
from typing import Any, Protocol, TypeVar, cast, runtime_checkable
78

8-
from tracecat.auth.types import Role
9+
from tracecat.auth.types import AccessLevel, Role
910
from tracecat.authz.enums import OrgRole, WorkspaceRole
1011
from tracecat.contexts import ctx_role
1112
from tracecat.exceptions import ScopeDeniedError, TracecatAuthorizationError
@@ -122,6 +123,54 @@ class HasRole(Protocol):
122123
role: Role
123124

124125

126+
def require_access_level(level: AccessLevel) -> Callable[[T], T]:
127+
"""Decorator that protects a `Service` method with a minimum access level requirement.
128+
129+
If the caller does not have at least the required access level, a TracecatAuthorizationError is raised.
130+
131+
.. deprecated::
132+
Use `@require_scope` instead. This decorator will be removed in a future version.
133+
"""
134+
warnings.warn(
135+
"require_access_level is deprecated, use require_scope instead",
136+
DeprecationWarning,
137+
stacklevel=2,
138+
)
139+
140+
def check(self: HasRole):
141+
if not hasattr(self, "role"):
142+
raise AttributeError("Service must have a 'role' attribute")
143+
144+
if not isinstance(self.role, Role):
145+
raise ValueError("Invalid role type")
146+
147+
user_role = self.role
148+
if user_role.access_level < level:
149+
raise TracecatAuthorizationError(
150+
f"User does not have required access level: {level.name}"
151+
)
152+
153+
def decorator(fn: T) -> T:
154+
if asyncio.iscoroutinefunction(fn):
155+
156+
@functools.wraps(fn)
157+
async def async_wrapper(self: HasRole, *args, **kwargs):
158+
check(self)
159+
return await fn(self, *args, **kwargs)
160+
161+
return cast(T, async_wrapper)
162+
else:
163+
164+
@functools.wraps(fn)
165+
def sync_wrapper(self: HasRole, *args, **kwargs):
166+
check(self)
167+
return fn(self, *args, **kwargs)
168+
169+
return cast(T, sync_wrapper)
170+
171+
return decorator
172+
173+
125174
def require_org_role(*roles: OrgRole) -> Callable[[T], T]:
126175
"""Decorator that protects a Service method with an org role requirement.
127176

tracecat/authz/seeding.py

Lines changed: 67 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ async def seed_system_scopes(session: AsyncSession) -> int:
210210
"resource": resource,
211211
"action": action,
212212
"description": description,
213-
"source": ScopeSource.SYSTEM,
213+
"source": ScopeSource.PLATFORM,
214214
"source_ref": None,
215215
"organization_id": None,
216216
}
@@ -244,42 +244,48 @@ async def seed_registry_scope(
244244
245245
Creates a scope for `action:{action_key}:execute` if it doesn't exist.
246246
Registry scopes have organization_id=NULL and source='registry'.
247+
Uses upsert (ON CONFLICT DO NOTHING) for concurrency safety.
247248
248249
Args:
249250
session: Database session
250251
action_key: The action key (e.g., "tools.okta.list_users")
251252
description: Optional description for the scope
252253
253254
Returns:
254-
The created or existing Scope, or None if upsert had no effect
255+
The created or existing Scope
255256
"""
256257
scope_name = f"action:{action_key}:execute"
258+
scope_id = uuid4()
257259

258-
# Check if scope already exists
259-
stmt = select(Scope).where(
260-
Scope.name == scope_name, Scope.organization_id.is_(None)
261-
)
262-
result = await session.execute(stmt)
263-
existing = result.scalar_one_or_none()
264-
265-
if existing:
266-
return existing
267-
268-
# Create new scope
269-
scope = Scope(
270-
id=uuid4(),
260+
# Use upsert for concurrency safety
261+
stmt = pg_insert(Scope).values(
262+
id=scope_id,
271263
name=scope_name,
272264
resource="action",
273265
action="execute",
274266
description=description or f"Execute {action_key} action",
275-
source=ScopeSource.REGISTRY,
267+
source=ScopeSource.PLATFORM,
276268
source_ref=action_key,
277269
organization_id=None,
278270
)
279-
session.add(scope)
271+
stmt = stmt.on_conflict_do_nothing(
272+
index_elements=["name"], index_where=Scope.organization_id.is_(None)
273+
)
274+
result = await session.execute(stmt)
280275
await session.flush()
281276

282-
logger.debug("Registry scope created", scope_name=scope_name, action_key=action_key)
277+
# Re-query to get the scope (whether newly inserted or already existing)
278+
select_stmt = select(Scope).where(
279+
Scope.name == scope_name, Scope.organization_id.is_(None)
280+
)
281+
select_result = await session.execute(select_stmt)
282+
scope = select_result.scalar_one_or_none()
283+
284+
if result.rowcount and result.rowcount > 0: # pyright: ignore[reportAttributeAccessIssue]
285+
logger.debug(
286+
"Registry scope created", scope_name=scope_name, action_key=action_key
287+
)
288+
283289
return scope
284290

285291

@@ -311,7 +317,7 @@ async def seed_registry_scopes_bulk(
311317
"resource": "action",
312318
"action": "execute",
313319
"description": f"Execute {key} action",
314-
"source": ScopeSource.REGISTRY,
320+
"source": ScopeSource.PLATFORM,
315321
"source_ref": key,
316322
"organization_id": None,
317323
}
@@ -343,8 +349,9 @@ async def seed_system_roles_for_org(
343349
) -> int:
344350
"""Seed system roles (Admin, Editor, Viewer) for an organization.
345351
346-
Creates the three system roles with their associated scopes if they don't exist.
352+
Creates the system roles with their associated scopes if they don't exist.
347353
System roles are identified by their well-known slugs.
354+
Uses upsert (ON CONFLICT DO NOTHING) for concurrency safety.
348355
349356
Args:
350357
session: Database session
@@ -366,36 +373,52 @@ async def seed_system_roles_for_org(
366373

367374
roles_created = 0
368375

369-
for slug, name, description, scope_names in PRESET_ROLE_DEFINITIONS:
370-
# Check if role already exists
371-
existing_stmt = select(Role.id).where(
372-
Role.organization_id == organization_id,
373-
Role.slug == slug,
376+
# Prepare role values with pre-generated IDs
377+
role_values = []
378+
role_id_by_slug: dict[str, UUID] = {}
379+
for slug, name, description, _ in PRESET_ROLE_DEFINITIONS:
380+
role_id = uuid4()
381+
role_id_by_slug[slug] = role_id
382+
role_values.append(
383+
{
384+
"id": role_id,
385+
"name": name,
386+
"slug": slug,
387+
"description": description,
388+
"organization_id": organization_id,
389+
"created_by": None, # System-created
390+
}
374391
)
375-
existing_result = await session.execute(existing_stmt)
376-
existing_role_id = existing_result.scalar_one_or_none()
377392

378-
if existing_role_id is not None:
379-
logger.debug(
380-
"System role already exists",
393+
# Bulk upsert roles - concurrency safe
394+
role_stmt = pg_insert(Role).values(role_values)
395+
role_stmt = role_stmt.on_conflict_do_nothing(
396+
index_elements=["organization_id", "slug"]
397+
)
398+
result = await session.execute(role_stmt)
399+
roles_created = result.rowcount if result.rowcount else 0 # pyright: ignore[reportAttributeAccessIssue]
400+
401+
# Re-query to get actual role IDs (may differ if roles already existed)
402+
existing_roles_stmt = select(Role.id, Role.slug).where(
403+
Role.organization_id == organization_id,
404+
Role.slug.in_([slug for slug, _, _, _ in PRESET_ROLE_DEFINITIONS]),
405+
)
406+
existing_roles_result = await session.execute(existing_roles_stmt)
407+
actual_role_id_by_slug: dict[str | None, UUID] = {
408+
slug: role_id for role_id, slug in existing_roles_result.tuples().all()
409+
}
410+
411+
# Link scopes to roles
412+
for slug, _, _, scope_names in PRESET_ROLE_DEFINITIONS:
413+
role_id = actual_role_id_by_slug.get(slug)
414+
if role_id is None:
415+
logger.warning(
416+
"Role not found after upsert",
381417
slug=slug,
382418
organization_id=organization_id,
383419
)
384420
continue
385421

386-
# Create the role
387-
role = Role(
388-
id=uuid4(),
389-
name=name,
390-
slug=slug,
391-
description=description,
392-
organization_id=organization_id,
393-
created_by=None, # System-created
394-
)
395-
session.add(role)
396-
await session.flush() # Get the role ID
397-
398-
# Link scopes to the role
399422
role_scope_values = []
400423
for scope_name in scope_names:
401424
scope_id = scope_name_to_id.get(scope_name)
@@ -406,21 +429,13 @@ async def seed_system_roles_for_org(
406429
role_slug=slug,
407430
)
408431
continue
409-
role_scope_values.append({"role_id": role.id, "scope_id": scope_id})
432+
role_scope_values.append({"role_id": role_id, "scope_id": scope_id})
410433

411434
if role_scope_values:
412435
role_scope_stmt = pg_insert(RoleScope).values(role_scope_values)
413436
role_scope_stmt = role_scope_stmt.on_conflict_do_nothing()
414437
await session.execute(role_scope_stmt)
415438

416-
roles_created += 1
417-
logger.debug(
418-
"System role created",
419-
slug=slug,
420-
organization_id=organization_id,
421-
num_scopes=len(role_scope_values),
422-
)
423-
424439
await session.commit()
425440
logger.info(
426441
"System roles seeded for organization",

tracecat/authz/service.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,12 @@ async def list_workspace_members(
7070
7171
Roles are looked up from UserRoleAssignment table.
7272
"""
73+
# Get workspace to determine organization_id
74+
workspace = await self.session.get(Workspace, workspace_id)
75+
if workspace is None:
76+
return []
77+
organization_id = workspace.organization_id
78+
7379
# Get all members of the workspace
7480
members_stmt = (
7581
select(User)
@@ -84,6 +90,7 @@ async def list_workspace_members(
8490

8591
# Get role assignments for these users in this workspace
8692
# Include both workspace-specific assignments and org-wide assignments (workspace_id IS NULL)
93+
# Filter by organization_id to ensure we only get roles from this org
8794
user_ids = [u.id for u in users]
8895
role_stmt = (
8996
select(
@@ -94,6 +101,7 @@ async def list_workspace_members(
94101
.join(RoleModel, UserRoleAssignment.role_id == RoleModel.id)
95102
.where(
96103
UserRoleAssignment.user_id.in_(user_ids),
104+
UserRoleAssignment.organization_id == organization_id,
97105
or_(
98106
UserRoleAssignment.workspace_id == workspace_id,
99107
UserRoleAssignment.workspace_id.is_(None),

tracecat/organization/service.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
get_user_db_context,
2323
get_user_manager_context,
2424
)
25-
from tracecat.authz.controls import require_access_level
25+
from tracecat.authz.controls import require_access_level, require_org_role
2626
from tracecat.authz.enums import OrgRole # Still used for privilege escalation check
2727
from tracecat.db.models import (
2828
AccessToken,
@@ -142,14 +142,22 @@ async def accept_invitation_for_user(
142142
# Shouldn't reach here, but handle gracefully
143143
raise TracecatAuthorizationError("Invitation is no longer valid")
144144

145-
# Create membership
145+
# Create membership (role is now tracked via UserRoleAssignment)
146146
membership = OrganizationMembership(
147147
user_id=user_id,
148148
organization_id=invitation.organization_id,
149-
role=invitation.role,
150149
)
151150
session.add(membership)
152151

152+
# Create role assignment from invitation
153+
role_assignment = UserRoleAssignment(
154+
organization_id=invitation.organization_id,
155+
user_id=user_id,
156+
workspace_id=None, # NULL = org-level assignment
157+
role_id=invitation.role_id,
158+
)
159+
session.add(role_assignment)
160+
153161
await session.commit()
154162
await session.refresh(membership)
155163
except TracecatAuthorizationError:
@@ -176,6 +184,8 @@ async def accept_invitation_for_user(
176184
)
177185

178186
return membership
187+
188+
179189
@dataclass(frozen=True)
180190
class MemberRoleInfo:
181191
"""Role information for an organization member."""

0 commit comments

Comments
 (0)