Skip to content

Commit c4ba9cc

Browse files
committed
refactor(rbac): batch seeding roles
1 parent f294c23 commit c4ba9cc

File tree

2 files changed

+150
-29
lines changed

2 files changed

+150
-29
lines changed

tests/unit/test_authz_seeding.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,13 @@ async def test_seed_registry_scopes(session):
7373
"core.http_request",
7474
]
7575

76+
org_count_result = await session.execute(select(func.count(Organization.id)))
77+
org_count = org_count_result.scalar_one()
78+
7679
inserted_count = await seed_registry_scopes(session, action_keys)
7780
await session.commit()
7881

79-
assert inserted_count == len(action_keys)
82+
assert inserted_count == len(action_keys) * (1 + org_count)
8083

8184
# Verify scopes exist
8285
result = await session.execute(
@@ -92,16 +95,24 @@ async def test_seed_registry_scopes(session):
9295
for key in action_keys:
9396
assert f"action:{key}:execute" in scope_names
9497

98+
custom_scope_result = await session.execute(
99+
select(func.count(Scope.id)).where(Scope.source == ScopeSource.CUSTOM)
100+
)
101+
assert custom_scope_result.scalar_one() == len(action_keys) * org_count
102+
95103

96104
@pytest.mark.anyio
97105
async def test_seed_registry_scopes_idempotent(session):
98106
"""Test that bulk seeding is idempotent."""
99107
action_keys = ["tools.test.action1", "tools.test.action2"]
100108

109+
org_count_result = await session.execute(select(func.count(Organization.id)))
110+
org_count = org_count_result.scalar_one()
111+
101112
# First seed
102113
first_count = await seed_registry_scopes(session, action_keys)
103114
await session.commit()
104-
assert first_count == 2
115+
assert first_count == len(action_keys) * (1 + org_count)
105116

106117
# Second seed
107118
second_count = await seed_registry_scopes(session, action_keys)

tracecat/authz/seeding.py

Lines changed: 137 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
Seeding is idempotent - existing scopes/roles are not duplicated.
99
"""
1010

11-
from typing import NamedTuple
11+
from typing import NamedTuple, TypedDict
1212
from uuid import UUID, uuid4
1313

1414
from sqlalchemy import select
@@ -197,7 +197,18 @@ class RoleDefinition(NamedTuple):
197197
),
198198
}
199199

200-
PRESET_ROLE_SLUGS: frozenset[str] = frozenset(PRESET_ROLE_DEFINITIONS)
200+
_CUSTOM_SCOPE_BATCH_ROWS = 5_000
201+
202+
203+
class ScopeInsertRow(TypedDict):
204+
id: UUID
205+
name: str
206+
resource: str
207+
action: str
208+
description: str
209+
source: ScopeSource
210+
source_ref: str
211+
organization_id: UUID | None
201212

202213

203214
# =============================================================================
@@ -256,47 +267,146 @@ async def seed_registry_scopes(
256267
session: AsyncSession,
257268
action_keys: list[str],
258269
) -> int:
259-
"""Seed registry action scopes in bulk.
270+
"""Seed registry scopes.
260271
261-
Creates scopes for all action keys that don't already exist.
262-
Uses PostgreSQL upsert for efficiency.
272+
Current behavior has two explicit steps:
273+
1. Seed platform registry scopes.
274+
2. Seed custom registry scopes
275+
"""
276+
platform_inserted = await _seed_platform_registry_scopes(
277+
session,
278+
action_keys,
279+
)
280+
custom_inserted = await _seed_custom_registry_scopes(session, action_keys)
281+
return platform_inserted + custom_inserted
263282

264-
Args:
265-
session: Database session
266-
action_keys: List of action keys (e.g., ["tools.okta.list_users", "core.http_request"])
267283

268-
Returns:
269-
Number of scopes inserted
270-
"""
284+
async def _seed_platform_registry_scopes(
285+
session: AsyncSession,
286+
action_keys: list[str],
287+
) -> int:
288+
"""Seed platform registry action scopes in bulk."""
271289
if not action_keys:
272290
return 0
273291

274-
logger.info("Seeding registry scopes", num_actions=len(action_keys))
292+
logger.info(
293+
"Seeding registry scopes",
294+
num_actions=len(action_keys),
295+
)
275296

276297
values = [
277-
{
278-
"id": uuid4(),
279-
"name": f"action:{key}:execute",
280-
"resource": "action",
281-
"action": "execute",
282-
"description": f"Execute {key} action",
283-
"source": ScopeSource.PLATFORM,
284-
"source_ref": key,
285-
"organization_id": None,
286-
}
298+
_build_registry_scope_row(
299+
action_key=key,
300+
source=ScopeSource.PLATFORM,
301+
organization_id=None,
302+
)
287303
for key in action_keys
288304
]
289305

290-
stmt = pg_insert(Scope).values(values)
291-
stmt = stmt.on_conflict_do_nothing(
292-
index_elements=["name"], index_where=Scope.organization_id.is_(None)
306+
return await _upsert_registry_scope_rows(
307+
session=session,
308+
values=values,
309+
source=ScopeSource.PLATFORM,
293310
)
294311

312+
313+
async def _seed_custom_registry_scopes(
314+
session: AsyncSession,
315+
action_keys: list[str],
316+
) -> int:
317+
"""Seed custom registry scopes for all organizations using chunked upserts."""
318+
if not action_keys:
319+
return 0
320+
321+
org_stmt = select(Organization.id)
322+
org_result = await session.execute(org_stmt)
323+
org_ids = [org_id for (org_id,) in org_result.tuples().all()]
324+
if not org_ids:
325+
return 0
326+
327+
logger.info(
328+
"Seeding registry scopes",
329+
num_actions=len(action_keys),
330+
source=ScopeSource.CUSTOM.value,
331+
num_organizations=len(org_ids),
332+
)
333+
334+
inserted_count = 0
335+
batch_values: list[ScopeInsertRow] = []
336+
for org_id in org_ids:
337+
for key in action_keys:
338+
batch_values.append(
339+
_build_registry_scope_row(
340+
action_key=key,
341+
source=ScopeSource.CUSTOM,
342+
organization_id=org_id,
343+
)
344+
)
345+
if len(batch_values) >= _CUSTOM_SCOPE_BATCH_ROWS:
346+
inserted_count += await _upsert_registry_scope_rows(
347+
session=session,
348+
values=batch_values,
349+
source=ScopeSource.CUSTOM,
350+
)
351+
batch_values.clear()
352+
353+
if batch_values:
354+
inserted_count += await _upsert_registry_scope_rows(
355+
session=session,
356+
values=batch_values,
357+
source=ScopeSource.CUSTOM,
358+
)
359+
360+
logger.info(
361+
"Registry scopes seeded",
362+
inserted=inserted_count,
363+
total=len(org_ids) * len(action_keys),
364+
source=ScopeSource.CUSTOM.value,
365+
)
366+
return inserted_count
367+
368+
369+
def _build_registry_scope_row(
370+
*, action_key: str, source: ScopeSource, organization_id: UUID | None
371+
) -> ScopeInsertRow:
372+
"""Build a single scope insert row for a registry action key."""
373+
return {
374+
"id": uuid4(),
375+
"name": f"action:{action_key}:execute",
376+
"resource": "action",
377+
"action": "execute",
378+
"description": f"Execute {action_key} action",
379+
"source": source,
380+
"source_ref": action_key,
381+
"organization_id": organization_id,
382+
}
383+
384+
385+
async def _upsert_registry_scope_rows(
386+
*,
387+
session: AsyncSession,
388+
values: list[ScopeInsertRow],
389+
source: ScopeSource,
390+
) -> int:
391+
"""Insert scope rows with conflict handling for platform vs org-scoped scopes."""
392+
if not values:
393+
return 0
394+
stmt = pg_insert(Scope).values(values)
395+
if source == ScopeSource.PLATFORM:
396+
stmt = stmt.on_conflict_do_nothing(
397+
index_elements=["name"], index_where=Scope.organization_id.is_(None)
398+
)
399+
else:
400+
stmt = stmt.on_conflict_do_nothing(index_elements=["organization_id", "name"])
401+
295402
result = await session.execute(stmt)
296-
inserted_count = result.rowcount if result.rowcount else 0 # pyright: ignore[reportAttributeAccessIssue]
403+
inserted_count = result.rowcount or 0 # pyright: ignore[reportAttributeAccessIssue]
297404

298405
logger.info(
299-
"Registry scopes seeded", inserted=inserted_count, total=len(action_keys)
406+
"Registry scopes seeded",
407+
inserted=inserted_count,
408+
total=len(values),
409+
source=source.value,
300410
)
301411
return inserted_count
302412

0 commit comments

Comments
 (0)