Skip to content

Commit 7c8c52e

Browse files
committed
refactor(rbac): batch seeding roles
1 parent 96db98a commit 7c8c52e

File tree

2 files changed

+171
-24
lines changed

2 files changed

+171
-24
lines changed

tests/unit/test_authz_seeding.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,13 @@ async def test_seed_registry_scopes(session):
7373
"core.http_request",
7474
]
7575

76+
org_count_result = await session.execute(select(func.count(Organization.id)))
77+
org_count = org_count_result.scalar_one()
78+
7679
inserted_count = await seed_registry_scopes(session, action_keys)
7780
await session.commit()
7881

79-
assert inserted_count == len(action_keys)
82+
assert inserted_count == len(action_keys) * (1 + org_count)
8083

8184
# Verify scopes exist
8285
result = await session.execute(
@@ -92,16 +95,24 @@ async def test_seed_registry_scopes(session):
9295
for key in action_keys:
9396
assert f"action:{key}:execute" in scope_names
9497

98+
custom_scope_result = await session.execute(
99+
select(func.count(Scope.id)).where(Scope.source == ScopeSource.CUSTOM)
100+
)
101+
assert custom_scope_result.scalar_one() == len(action_keys) * org_count
102+
95103

96104
@pytest.mark.anyio
97105
async def test_seed_registry_scopes_idempotent(session):
98106
"""Test that bulk seeding is idempotent."""
99107
action_keys = ["tools.test.action1", "tools.test.action2"]
100108

109+
org_count_result = await session.execute(select(func.count(Organization.id)))
110+
org_count = org_count_result.scalar_one()
111+
101112
# First seed
102113
first_count = await seed_registry_scopes(session, action_keys)
103114
await session.commit()
104-
assert first_count == 2
115+
assert first_count == len(action_keys) * (1 + org_count)
105116

106117
# Second seed
107118
second_count = await seed_registry_scopes(session, action_keys)

tracecat/authz/seeding.py

Lines changed: 158 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
Seeding is idempotent - existing scopes/roles are not duplicated.
99
"""
1010

11-
from typing import NamedTuple
11+
from typing import NamedTuple, TypedDict
1212
from uuid import UUID, uuid4
1313

1414
from sqlalchemy import select
@@ -197,7 +197,18 @@ class RoleDefinition(NamedTuple):
197197
),
198198
}
199199

200-
PRESET_ROLE_SLUGS: frozenset[str] = frozenset(PRESET_ROLE_DEFINITIONS)
200+
# Flush threshold for the custom-scope seeder: once this many pending rows have
# accumulated they are written with a single INSERT and the batch is reset.
_CUSTOM_SCOPE_BATCH_ROWS = 5_000
201+
202+
203+
class ScopeInsertRow(TypedDict):
    """Column values for one scope row handed to a bulk ``INSERT``.

    The keys mirror the dict produced by ``_build_registry_scope_row``.
    """

    # Fresh UUID generated per row by the row builder.
    id: UUID
    # Scope name; registry scopes use ``action:<action_key>:execute``.
    name: str
    # Resource segment of the scope; ``"action"`` for registry scopes.
    resource: str
    # Action segment of the scope; ``"execute"`` for registry scopes.
    action: str
    # Human-readable summary, e.g. ``Execute <action_key> action``.
    description: str
    # Ownership category of the scope (platform or custom).
    source: ScopeSource
    # The registry action key the scope was derived from.
    source_ref: str
    # Owning organization; ``None`` for platform scopes.
    organization_id: UUID | None
201212

202213

203214
# =============================================================================
@@ -256,47 +267,172 @@ async def seed_registry_scopes(
256267
session: AsyncSession,
257268
action_keys: list[str],
258269
) -> int:
259-
"""Seed registry action scopes in bulk.
270+
"""Seed registry scopes.
271+
272+
Current behavior has two explicit steps:
273+
1. Seed platform registry scopes.
274+
2. Seed custom registry scopes for every organization.
275+
"""
276+
platform_inserted = await _seed_registry_scopes(
277+
session,
278+
action_keys,
279+
source=ScopeSource.PLATFORM,
280+
organization_id=None,
281+
)
282+
custom_inserted = await _seed_custom_registry_scopes(session, action_keys)
283+
return platform_inserted + custom_inserted
284+
285+
286+
async def seed_platform_registry_scopes(
    session: AsyncSession,
    action_keys: list[str],
) -> int:
    """Seed platform-level registry action scopes in bulk.

    Delegates to ``_seed_registry_scopes`` with the platform source and no
    owning organization.

    Args:
        session: Database session used for the insert.
        action_keys: Registry action keys to create execute scopes for.

    Returns:
        Number of scopes inserted, as reported by the underlying seeder.
    """
    inserted = await _seed_registry_scopes(
        session,
        action_keys,
        source=ScopeSource.PLATFORM,
        organization_id=None,
    )
    return inserted
260297

261-
Creates scopes for all action keys that don't already exist.
262-
Uses PostgreSQL upsert for efficiency.
298+
299+
async def _seed_custom_registry_scopes(
    session: AsyncSession,
    action_keys: list[str],
) -> int:
    """Seed custom registry scopes for all organizations using chunked upserts.

    One custom scope row is built per (organization, action key) pair. Rows are
    written in batches so a large organizations x actions product never ends
    up in a single statement.

    Args:
        session: Database session used for the lookups and inserts.
        action_keys: Registry action keys to create execute scopes for.

    Returns:
        Total number of rows inserted across all batches. Returns 0 when there
        are no action keys or no organizations.
    """
    if not action_keys:
        return 0

    org_result = await session.execute(select(Organization.id))
    org_ids = [org_id for (org_id,) in org_result.tuples().all()]
    if not org_ids:
        return 0

    logger.info(
        "Seeding registry scopes",
        num_actions=len(action_keys),
        source=ScopeSource.CUSTOM.value,
        num_organizations=len(org_ids),
    )

    # A multi-row ``VALUES`` insert binds one parameter per column per row, so
    # the parameter count is rows x columns. asyncpg rejects statements with
    # more than 32,767 arguments; with the current 8-column rows a 5,000-row
    # batch would bind 40,000. Cap the batch by row width so a full batch stays
    # under that ceiling.
    # NOTE(review): assumes the asyncpg driver - confirm. The cap is harmless
    # for drivers with a higher limit. Column defaults added by the ORM on top
    # of the row dict, if any, would add further parameters.
    max_bind_params = 32_767
    columns_per_row = len(
        _build_registry_scope_row(
            action_key=action_keys[0],
            source=ScopeSource.CUSTOM,
            organization_id=org_ids[0],
        )
    )
    batch_rows = max(
        1,
        min(_CUSTOM_SCOPE_BATCH_ROWS, max_bind_params // columns_per_row),
    )

    inserted_count = 0
    batch_values: list[ScopeInsertRow] = []
    for org_id in org_ids:
        for key in action_keys:
            batch_values.append(
                _build_registry_scope_row(
                    action_key=key,
                    source=ScopeSource.CUSTOM,
                    organization_id=org_id,
                )
            )
            if len(batch_values) >= batch_rows:
                inserted_count += await _upsert_registry_scope_rows(
                    session=session,
                    values=batch_values,
                    source=ScopeSource.CUSTOM,
                )
                # The statement has already executed, so the list can be reused.
                batch_values.clear()

    # Flush whatever is left over after the last full batch.
    if batch_values:
        inserted_count += await _upsert_registry_scope_rows(
            session=session,
            values=batch_values,
            source=ScopeSource.CUSTOM,
        )

    logger.info(
        "Registry scopes seeded",
        inserted=inserted_count,
        total=len(org_ids) * len(action_keys),
        source=ScopeSource.CUSTOM.value,
    )
    return inserted_count
353+
354+
355+
async def _seed_registry_scopes(
    session: AsyncSession,
    action_keys: list[str],
    *,
    source: ScopeSource,
    organization_id: UUID | None,
) -> int:
    """Seed registry action scopes with explicit source and ownership.

    Builds one scope row per action key and writes them all in a single
    conflict-tolerant insert.

    Args:
        session: Database session
        action_keys: List of action keys (e.g., ["tools.okta.list_users", "core.http_request"])
        source: Scope ownership category (platform or custom).
        organization_id: Target organization for custom scopes, None for platform.

    Returns:
        Number of scopes inserted.
    """
    if not action_keys:
        return 0

    # Include the source so platform and custom runs can be told apart, matching
    # the fields logged by the custom seeder and by the completion message.
    logger.info(
        "Seeding registry scopes",
        num_actions=len(action_keys),
        source=source.value,
    )

    values = [
        _build_registry_scope_row(
            action_key=key, source=source, organization_id=organization_id
        )
        for key in action_keys
    ]

    return await _upsert_registry_scope_rows(
        session=session,
        values=values,
        source=source,
    )
294393

394+
395+
def _build_registry_scope_row(
    *, action_key: str, source: ScopeSource, organization_id: UUID | None
) -> ScopeInsertRow:
    """Build the insert row for the execute scope of one registry action key.

    Args:
        action_key: Registry action key the scope is derived from.
        source: Ownership category stored on the row.
        organization_id: Owning organization, or None.

    Returns:
        A row dict with a freshly generated UUID as its ``id``.
    """
    scope_name = f"action:{action_key}:execute"
    summary = f"Execute {action_key} action"
    row: ScopeInsertRow = {
        "id": uuid4(),
        "name": scope_name,
        "resource": "action",
        "action": "execute",
        "description": summary,
        "source": source,
        "source_ref": action_key,
        "organization_id": organization_id,
    }
    return row
409+
410+
411+
async def _upsert_registry_scope_rows(
    *,
    session: AsyncSession,
    values: list[ScopeInsertRow],
    source: ScopeSource,
) -> int:
    """Insert scope rows with conflict handling for platform vs org-scoped scopes.

    Rows that collide with an existing scope are skipped, not updated.

    Args:
        session: Database session used to execute the insert.
        values: Rows to insert; an empty list is a no-op.
        source: Selects the conflict target. Platform rows conflict on ``name``
            where ``organization_id`` is NULL; any other source conflicts on
            ``(organization_id, name)``.

    Returns:
        Number of rows inserted as reported by the driver, never negative.
    """
    if not values:
        return 0

    stmt = pg_insert(Scope).values(values)
    if source == ScopeSource.PLATFORM:
        stmt = stmt.on_conflict_do_nothing(
            index_elements=["name"], index_where=Scope.organization_id.is_(None)
        )
    else:
        stmt = stmt.on_conflict_do_nothing(index_elements=["organization_id", "name"])

    result = await session.execute(stmt)
    # ``rowcount`` may be None or -1 when the driver cannot report it. Clamp to
    # zero so an unknown count never subtracts from totals summed by callers.
    reported = result.rowcount  # pyright: ignore[reportAttributeAccessIssue]
    inserted_count = reported if reported and reported > 0 else 0

    logger.info(
        "Registry scopes seeded",
        inserted=inserted_count,
        total=len(values),
        source=source.value,
    )
    return inserted_count
302438

0 commit comments

Comments
 (0)