Skip to content

Commit ed1dadb

Browse files
committed
refactor(rbac): drop test singular scope seed
1 parent 7197c57 commit ed1dadb

File tree

3 files changed

+147
-110
lines changed

3 files changed

+147
-110
lines changed

tests/unit/test_authz_seeding.py

Lines changed: 77 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
"""Tests for RBAC scope seeding."""
22

3+
from uuid import uuid4
4+
35
import pytest
4-
from sqlalchemy import select
6+
from sqlalchemy import func, select
57

68
from tracecat.authz.enums import ScopeSource
79
from tracecat.authz.seeding import (
10+
PRESET_ROLE_DEFINITIONS,
811
SYSTEM_SCOPE_DEFINITIONS,
9-
seed_registry_scope,
10-
seed_registry_scopes_bulk,
12+
seed_registry_scopes,
13+
seed_system_roles_for_all_orgs,
1114
seed_system_scopes,
1215
)
13-
from tracecat.db.models import Scope
16+
from tracecat.db.models import Organization, Role, RoleScope, Scope
1417

1518

1619
@pytest.mark.anyio
@@ -61,41 +64,7 @@ async def test_seed_system_scopes_idempotent(session):
6164

6265

6366
@pytest.mark.anyio
64-
async def test_seed_registry_scope(session):
65-
"""Test seeding a single registry scope."""
66-
action_key = "tools.test_integration.test_action"
67-
68-
scope = await seed_registry_scope(session, action_key, "Test action scope")
69-
await session.commit()
70-
71-
assert scope is not None
72-
assert scope.name == f"action:{action_key}:execute"
73-
assert scope.resource == "action"
74-
assert scope.action == "execute"
75-
assert scope.source == ScopeSource.PLATFORM
76-
assert scope.source_ref == action_key
77-
assert scope.organization_id is None
78-
79-
80-
@pytest.mark.anyio
81-
async def test_seed_registry_scope_idempotent(session):
82-
"""Test that seeding the same registry scope twice returns existing scope."""
83-
action_key = "tools.test_integration.test_action_idempotent"
84-
85-
# Seed first time
86-
scope1 = await seed_registry_scope(session, action_key)
87-
await session.commit()
88-
89-
# Seed second time
90-
scope2 = await seed_registry_scope(session, action_key)
91-
92-
assert scope1 is not None
93-
assert scope2 is not None
94-
assert scope1.id == scope2.id
95-
96-
97-
@pytest.mark.anyio
98-
async def test_seed_registry_scopes_bulk(session):
67+
async def test_seed_registry_scopes(session):
9968
"""Test bulk seeding of registry scopes."""
10069
action_keys = [
10170
"tools.okta.list_users",
@@ -104,7 +73,7 @@ async def test_seed_registry_scopes_bulk(session):
10473
"core.http_request",
10574
]
10675

107-
inserted_count = await seed_registry_scopes_bulk(session, action_keys)
76+
inserted_count = await seed_registry_scopes(session, action_keys)
10877
await session.commit()
10978

11079
assert inserted_count == len(action_keys)
@@ -125,28 +94,91 @@ async def test_seed_registry_scopes_bulk(session):
12594

12695

12796
@pytest.mark.anyio
128-
async def test_seed_registry_scopes_bulk_idempotent(session):
97+
async def test_seed_registry_scopes_idempotent(session):
12998
"""Test that bulk seeding is idempotent."""
13099
action_keys = ["tools.test.action1", "tools.test.action2"]
131100

132101
# First seed
133-
first_count = await seed_registry_scopes_bulk(session, action_keys)
102+
first_count = await seed_registry_scopes(session, action_keys)
134103
await session.commit()
135104
assert first_count == 2
136105

137106
# Second seed
138-
second_count = await seed_registry_scopes_bulk(session, action_keys)
107+
second_count = await seed_registry_scopes(session, action_keys)
139108
await session.commit()
140109
assert second_count == 0
141110

142111

143112
@pytest.mark.anyio
144-
async def test_seed_registry_scopes_bulk_empty(session):
113+
async def test_seed_registry_scopes_empty(session):
145114
"""Test bulk seeding with empty list."""
146-
inserted_count = await seed_registry_scopes_bulk(session, [])
115+
inserted_count = await seed_registry_scopes(session, [])
147116
assert inserted_count == 0
148117

149118

119+
@pytest.mark.anyio
120+
async def test_seed_system_roles_for_all_orgs_creates_roles_and_links(session):
121+
"""Seed preset roles for all orgs and link expected system scopes."""
122+
# Ensure scope IDs exist for role->scope links.
123+
await seed_system_scopes(session)
124+
125+
# Add an extra org so the function processes multiple orgs in one call.
126+
extra_org = Organization(
127+
id=uuid4(),
128+
name="Extra test org",
129+
slug=f"extra-test-org-{uuid4().hex[:8]}",
130+
is_active=True,
131+
)
132+
session.add(extra_org)
133+
await session.flush()
134+
135+
# Capture target org IDs in this isolated session.
136+
org_result = await session.execute(select(Organization.id))
137+
org_ids = {org_id for (org_id,) in org_result.tuples().all()}
138+
139+
created_by_org = await seed_system_roles_for_all_orgs(session)
140+
141+
for org_id in org_ids:
142+
assert created_by_org[org_id] == len(PRESET_ROLE_DEFINITIONS)
143+
144+
roles_result = await session.execute(
145+
select(Role.id, Role.organization_id, Role.slug).where(
146+
Role.organization_id.in_(org_ids),
147+
Role.slug.in_(PRESET_ROLE_DEFINITIONS),
148+
)
149+
)
150+
roles = roles_result.tuples().all()
151+
assert len(roles) == len(org_ids) * len(PRESET_ROLE_DEFINITIONS)
152+
153+
role_scope_count_stmt = (
154+
select(Role.slug, func.count(RoleScope.scope_id))
155+
.select_from(Role)
156+
.join(RoleScope, RoleScope.role_id == Role.id)
157+
.where(
158+
Role.organization_id.in_(org_ids),
159+
Role.slug.in_(PRESET_ROLE_DEFINITIONS),
160+
)
161+
.group_by(Role.slug)
162+
)
163+
role_scope_count_result = await session.execute(role_scope_count_stmt)
164+
role_scope_counts = dict(role_scope_count_result.tuples().all())
165+
expected_org_count = len(org_ids)
166+
for slug, role_def in PRESET_ROLE_DEFINITIONS.items():
167+
assert role_scope_counts[slug] == len(role_def.scopes) * expected_org_count
168+
169+
170+
@pytest.mark.anyio
171+
async def test_seed_system_roles_for_all_orgs_idempotent(session):
172+
"""Running role seeding twice should not create duplicate roles."""
173+
await seed_system_scopes(session)
174+
175+
first = await seed_system_roles_for_all_orgs(session)
176+
second = await seed_system_roles_for_all_orgs(session)
177+
178+
assert sum(first.values()) > 0
179+
assert all(created == 0 for created in second.values())
180+
181+
150182
@pytest.mark.anyio
151183
async def test_system_scope_definitions_format(session):
152184
"""Test that all system scope definitions follow the expected format."""

tracecat/authz/seeding.py

Lines changed: 68 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ async def seed_system_scopes(session: AsyncSession) -> int:
243243
result = await session.execute(stmt)
244244
await session.commit()
245245

246-
inserted_count = result.rowcount if result.rowcount else 0 # pyright: ignore[reportAttributeAccessIssue]
246+
inserted_count = result.rowcount or 0 # pyright: ignore[reportAttributeAccessIssue]
247247
logger.info(
248248
"System scopes seeded",
249249
inserted=inserted_count,
@@ -252,61 +252,7 @@ async def seed_system_scopes(session: AsyncSession) -> int:
252252
return inserted_count
253253

254254

255-
async def seed_registry_scope(
256-
session: AsyncSession,
257-
action_key: str,
258-
description: str | None = None,
259-
) -> Scope | None:
260-
"""Seed a single registry action scope.
261-
262-
Creates a scope for `action:{action_key}:execute` if it doesn't exist.
263-
Registry scopes have organization_id=NULL and source='registry'.
264-
Uses upsert (ON CONFLICT DO NOTHING) for concurrency safety.
265-
266-
Args:
267-
session: Database session
268-
action_key: The action key (e.g., "tools.okta.list_users")
269-
description: Optional description for the scope
270-
271-
Returns:
272-
The created or existing Scope
273-
"""
274-
scope_name = f"action:{action_key}:execute"
275-
scope_id = uuid4()
276-
277-
# Use upsert for concurrency safety
278-
stmt = pg_insert(Scope).values(
279-
id=scope_id,
280-
name=scope_name,
281-
resource="action",
282-
action="execute",
283-
description=description or f"Execute {action_key} action",
284-
source=ScopeSource.PLATFORM,
285-
source_ref=action_key,
286-
organization_id=None,
287-
)
288-
stmt = stmt.on_conflict_do_nothing(
289-
index_elements=["name"], index_where=Scope.organization_id.is_(None)
290-
)
291-
result = await session.execute(stmt)
292-
await session.flush()
293-
294-
# Re-query to get the scope (whether newly inserted or already existing)
295-
select_stmt = select(Scope).where(
296-
Scope.name == scope_name, Scope.organization_id.is_(None)
297-
)
298-
select_result = await session.execute(select_stmt)
299-
scope = select_result.scalar_one_or_none()
300-
301-
if result.rowcount and result.rowcount > 0: # pyright: ignore[reportAttributeAccessIssue]
302-
logger.debug(
303-
"Registry scope created", scope_name=scope_name, action_key=action_key
304-
)
305-
306-
return scope
307-
308-
309-
async def seed_registry_scopes_bulk(
255+
async def seed_registry_scopes(
310256
session: AsyncSession,
311257
action_keys: list[str],
312258
) -> int:
@@ -413,7 +359,7 @@ async def seed_system_roles_for_org(
413359
index_elements=["organization_id", "slug"]
414360
)
415361
result = await session.execute(role_stmt)
416-
roles_created = result.rowcount if result.rowcount else 0 # pyright: ignore[reportAttributeAccessIssue]
362+
roles_created = result.rowcount or 0 # pyright: ignore[reportAttributeAccessIssue]
417363

418364
# Re-query to get actual role IDs (may differ if roles already existed)
419365
existing_roles_stmt = select(Role.id, Role.slug).where(
@@ -490,14 +436,73 @@ async def seed_system_roles_for_all_orgs(session: AsyncSession) -> dict[UUID, in
490436
logger.info("No organizations found, skipping system role seeding")
491437
return {}
492438

493-
results: dict[UUID, int] = {}
494-
total_created = 0
495-
439+
# 1) Upsert all preset roles for all orgs in one bulk query.
440+
role_values = []
496441
for org_id in org_ids:
497-
roles_created = await seed_system_roles_for_org(session, org_id)
498-
results[org_id] = roles_created
499-
total_created += roles_created
442+
for slug, role_def in PRESET_ROLE_DEFINITIONS.items():
443+
role_values.append(
444+
{
445+
"id": uuid4(),
446+
"name": role_def.name,
447+
"slug": slug,
448+
"description": role_def.description,
449+
"organization_id": org_id,
450+
"created_by": None,
451+
}
452+
)
453+
454+
role_insert_stmt = pg_insert(Role).values(role_values)
455+
role_insert_stmt = role_insert_stmt.on_conflict_do_nothing(
456+
index_elements=["organization_id", "slug"]
457+
).returning(Role.organization_id)
458+
role_insert_result = await session.execute(role_insert_stmt)
459+
460+
# Return shape remains {org_id: roles_created_for_org}.
461+
results: dict[UUID, int] = dict.fromkeys(org_ids, 0)
462+
for (organization_id,) in role_insert_result.tuples().all():
463+
results[organization_id] += 1
464+
465+
# 2) Fetch all relevant global scopes once.
466+
scope_stmt = select(Scope.id, Scope.name).where(Scope.organization_id.is_(None))
467+
scope_result = await session.execute(scope_stmt)
468+
scope_id_by_name: dict[str, UUID] = {
469+
scope_name: scope_id for scope_id, scope_name in scope_result.tuples().all()
470+
}
471+
472+
# 3) Fetch all preset roles for those orgs and bulk insert role-scope links.
473+
role_stmt = select(Role.id, Role.slug).where(
474+
Role.organization_id.in_(org_ids),
475+
Role.slug.in_(PRESET_ROLE_DEFINITIONS),
476+
)
477+
role_result = await session.execute(role_stmt)
478+
479+
role_scope_values = []
480+
for role_id, role_slug in role_result.tuples().all():
481+
if role_slug is None:
482+
continue
483+
role_def = PRESET_ROLE_DEFINITIONS.get(role_slug)
484+
if role_def is None:
485+
continue
486+
for scope_name in role_def.scopes:
487+
scope_id = scope_id_by_name.get(scope_name)
488+
if scope_id is None:
489+
logger.warning(
490+
"Scope not found for system role",
491+
scope_name=scope_name,
492+
role_slug=role_slug,
493+
)
494+
continue
495+
role_scope_values.append({"role_id": role_id, "scope_id": scope_id})
496+
497+
if role_scope_values:
498+
role_scope_stmt = pg_insert(RoleScope).values(role_scope_values)
499+
role_scope_stmt = role_scope_stmt.on_conflict_do_nothing(
500+
index_elements=["role_id", "scope_id"]
501+
)
502+
await session.execute(role_scope_stmt)
503+
await session.commit()
500504

505+
total_created = sum(results.values())
501506
logger.info(
502507
"System roles seeded for all organizations",
503508
num_orgs=len(org_ids),

tracecat/registry/sync/jobs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sqlalchemy.exc import DBAPIError
1313
from sqlalchemy.ext.asyncio import AsyncSession
1414

15-
from tracecat.authz.seeding import seed_registry_scopes_bulk
15+
from tracecat.authz.seeding import seed_registry_scopes
1616
from tracecat.db.engine import get_async_session_context_manager
1717
from tracecat.db.locks import (
1818
derive_lock_key_from_parts,
@@ -195,7 +195,7 @@ async def _seed_registry_scopes(
195195
action_keys = [f"{action.namespace}.{action.name}" for action in actions]
196196

197197
try:
198-
inserted = await seed_registry_scopes_bulk(session, action_keys)
198+
inserted = await seed_registry_scopes(session, action_keys)
199199
await session.commit()
200200
logger.info("Registry scopes seeded", inserted=inserted, total=len(action_keys))
201201
except DBAPIError as e:

0 commit comments

Comments
 (0)