Skip to content

Commit 521c5c5

Browse files
committed
feat: RBAC entity creator, purger and inviter
1 parent d5fc84b commit 521c5c5

File tree

5 files changed

+348
-5
lines changed

5 files changed

+348
-5
lines changed

src/ai/backend/manager/repositories/base/creator.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,27 @@
44

55
from abc import ABC, abstractmethod
66
from dataclasses import dataclass
7-
from typing import TYPE_CHECKING, Generic, TypeVar
7+
from typing import Generic, TypeVar
88

9-
from ai.backend.manager.models.base import Base
9+
from sqlalchemy.ext.asyncio import AsyncSession as SASession
1010

11-
if TYPE_CHECKING:
12-
from sqlalchemy.ext.asyncio import AsyncSession as SASession
11+
from ai.backend.manager.data.permission.id import (
12+
ObjectId,
13+
ScopeId,
14+
)
1315

14-
TRow = TypeVar("TRow", bound=Base)
16+
17+
class RBACEntityRow(ABC):
18+
@abstractmethod
19+
def parsed_scope_id(self) -> ScopeId:
20+
pass
21+
22+
@abstractmethod
23+
def parsed_object_id(self) -> ObjectId:
24+
pass
25+
26+
27+
TRow = TypeVar("TRow", bound=RBACEntityRow)
1528

1629

1730
class CreatorSpec(ABC, Generic[TRow]):

src/ai/backend/manager/repositories/base/rbac_entity/__init__.py

Whitespace-only changes.
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from dataclasses import dataclass
5+
from typing import Generic, TypeVar
6+
7+
from sqlalchemy.ext.asyncio import AsyncSession as SASession
8+
9+
from ai.backend.manager.data.permission.id import (
10+
ObjectId,
11+
ScopeId,
12+
)
13+
from ai.backend.manager.models.rbac_models.association_scopes_entities import (
14+
AssociationScopesEntitiesRow,
15+
)
16+
17+
18+
class RBACEntityRow(ABC):
19+
@abstractmethod
20+
def parsed_object_id(self) -> ObjectId:
21+
pass
22+
23+
24+
TEntityRow = TypeVar("TEntityRow", bound=RBACEntityRow)
25+
26+
27+
class CreatorSpec(ABC, Generic[TEntityRow]):
28+
"""Abstract base class defining a row to insert.
29+
30+
Implementations specify what to create by providing:
31+
- A build_row() method that returns the ORM instance to insert
32+
"""
33+
34+
@abstractmethod
35+
def build_row(self) -> TEntityRow:
36+
"""Build ORM row instance to insert.
37+
38+
Returns:
39+
An ORM model instance to be inserted
40+
"""
41+
raise NotImplementedError
42+
43+
44+
@dataclass
45+
class Creator(Generic[TEntityRow]):
46+
"""Bundles RBAC-aware creator spec for insert operations.
47+
48+
Attributes:
49+
spec: CreatorSpec implementation defining what to create.
50+
rbac_context: RBAC context for the creation operation.
51+
"""
52+
53+
spec: CreatorSpec[TEntityRow]
54+
scope_id: ScopeId
55+
56+
57+
@dataclass
58+
class CreatorResult(Generic[TEntityRow]):
59+
"""Result of executing a create operation."""
60+
61+
row: TEntityRow
62+
63+
64+
async def execute_creator(
65+
db_sess: SASession,
66+
creator: Creator[TEntityRow],
67+
) -> CreatorResult[TEntityRow]:
68+
"""Execute INSERT with RBAC-aware creator.
69+
70+
Args:
71+
db_sess: Async SQLAlchemy session.
72+
creator: Creator instance with RBAC context and spec.
73+
74+
Returns:
75+
Result of the create operation.
76+
"""
77+
row = creator.spec.build_row()
78+
db_sess.add(row)
79+
await db_sess.flush()
80+
await db_sess.refresh(row)
81+
scope_id = creator.scope_id
82+
object_id = row.parsed_object_id()
83+
db_sess.add(
84+
AssociationScopesEntitiesRow(
85+
scope_type=scope_id.scope_type,
86+
scope_id=scope_id.scope_id,
87+
entity_type=object_id.entity_type,
88+
entity_id=object_id.entity_id,
89+
)
90+
)
91+
return CreatorResult(row=row)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
5+
import sqlalchemy as sa
6+
from sqlalchemy.ext.asyncio import AsyncSession as SASession
7+
8+
from ai.backend.common.data.permission.types import (
9+
OperationType,
10+
)
11+
from ai.backend.manager.data.permission.id import (
12+
ObjectId,
13+
ScopeId,
14+
)
15+
from ai.backend.manager.data.permission.types import RoleSource
16+
from ai.backend.manager.models.rbac_models.association_scopes_entities import (
17+
AssociationScopesEntitiesRow,
18+
)
19+
from ai.backend.manager.models.rbac_models.permission.object_permission import ObjectPermissionRow
20+
from ai.backend.manager.models.rbac_models.permission.permission_group import PermissionGroupRow
21+
from ai.backend.manager.models.rbac_models.role import RoleRow
22+
23+
24+
@dataclass
25+
class Inviter:
26+
entity_id: ObjectId
27+
scope_id: ScopeId
28+
operations: list[OperationType]
29+
30+
31+
async def execute_inviter(
32+
db_sess: SASession,
33+
inviter: Inviter,
34+
) -> None:
35+
scope_id = inviter.scope_id
36+
object_id = inviter.entity_id
37+
38+
scope_system_roles = await db_sess.scalars(
39+
sa.select(RoleRow)
40+
.select_from(sa.join(RoleRow, PermissionGroupRow, RoleRow.id == PermissionGroupRow.role_id))
41+
.where(
42+
sa.and_(
43+
RoleRow.source == RoleSource.SYSTEM,
44+
PermissionGroupRow.scope_id == scope_id.scope_id,
45+
PermissionGroupRow.scope_type == scope_id.scope_type,
46+
)
47+
)
48+
)
49+
role_ids = [role.id for role in scope_system_roles.all()]
50+
for role_id in role_ids:
51+
for operation in inviter.operations:
52+
db_sess.add(
53+
ObjectPermissionRow(
54+
role_id=role_id,
55+
entity_type=object_id.entity_type,
56+
entity_id=object_id.entity_id,
57+
operation=operation,
58+
)
59+
)
60+
db_sess.add(
61+
AssociationScopesEntitiesRow(
62+
scope_type=scope_id.scope_type,
63+
scope_id=scope_id.scope_id,
64+
entity_type=object_id.entity_type,
65+
entity_id=object_id.entity_id,
66+
)
67+
)
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Iterable
4+
from dataclasses import dataclass
5+
from typing import Generic, TypeVar
6+
from uuid import UUID
7+
8+
import sqlalchemy as sa
9+
from sqlalchemy.ext.asyncio import AsyncSession as SASession
10+
from sqlalchemy.orm import (
11+
contains_eager,
12+
selectinload,
13+
with_loader_criteria,
14+
)
15+
16+
from ai.backend.manager.data.permission.id import (
17+
ObjectId,
18+
ScopeId,
19+
)
20+
from ai.backend.manager.errors.repository import UnsupportedCompositePrimaryKeyError
21+
from ai.backend.manager.models.base import Base
22+
from ai.backend.manager.models.rbac_models.association_scopes_entities import (
23+
AssociationScopesEntitiesRow,
24+
)
25+
from ai.backend.manager.models.rbac_models.permission.object_permission import ObjectPermissionRow
26+
from ai.backend.manager.models.rbac_models.permission.permission import PermissionRow
27+
from ai.backend.manager.models.rbac_models.permission.permission_group import PermissionGroupRow
28+
from ai.backend.manager.models.rbac_models.role import RoleRow
29+
30+
TRow = TypeVar("TRow", bound=Base)
31+
32+
33+
@dataclass
34+
class Purger(Generic[TRow]):
35+
"""Single-row delete by primary key.
36+
37+
Attributes:
38+
row_class: ORM class for table access and PK detection.
39+
pk_value: Primary key value to identify the target row.
40+
"""
41+
42+
row_class: type[TRow]
43+
pk_value: UUID | str | int
44+
entity_id: ObjectId
45+
46+
47+
@dataclass
48+
class PurgerResult(Generic[TRow]):
49+
"""Result of executing a single-row delete operation."""
50+
51+
row: TRow
52+
53+
54+
async def _get_association_rows(
55+
db_sess: SASession,
56+
object_id: ObjectId,
57+
) -> list[AssociationScopesEntitiesRow]:
58+
assoc_scalars = await db_sess.scalars(
59+
sa.select(AssociationScopesEntitiesRow).where(
60+
sa.and_(
61+
AssociationScopesEntitiesRow.entity_id == object_id.entity_id,
62+
AssociationScopesEntitiesRow.entity_type == object_id.entity_type,
63+
)
64+
)
65+
)
66+
return assoc_scalars.all()
67+
68+
69+
async def _get_related_roles(
70+
db_sess: SASession,
71+
object_id: ObjectId,
72+
scopes: list[ScopeId],
73+
) -> list[RoleRow]:
74+
role_scalars = await db_sess.scalars(
75+
sa.select(RoleRow)
76+
.select_from(
77+
sa.join(RoleRow, ObjectPermissionRow, RoleRow.id == ObjectPermissionRow.role_id)
78+
)
79+
.where(
80+
sa.and_(
81+
ObjectPermissionRow.entity_id == object_id.entity_id,
82+
ObjectPermissionRow.entity_type == object_id.entity_type,
83+
)
84+
)
85+
.options(
86+
contains_eager(RoleRow.object_permission_rows),
87+
selectinload(RoleRow.permission_group_rows),
88+
with_loader_criteria(
89+
PermissionGroupRow,
90+
sa.and_(
91+
sa.not_(
92+
sa.exists(
93+
sa.select(PermissionRow.id).where(
94+
PermissionRow.permission_group_id == PermissionGroupRow.id
95+
)
96+
)
97+
),
98+
PermissionGroupRow.scope_id.in_([scope.scope_id for scope in scopes]), # type: ignore[attr-defined]
99+
PermissionGroupRow.scope_type.in_([scope.scope_type for scope in scopes]), # type: ignore[attr-defined]
100+
),
101+
),
102+
)
103+
)
104+
return role_scalars.all()
105+
106+
107+
async def _purge_related_rows(
108+
db_sess: SASession,
109+
object_permission_ids: Iterable[UUID],
110+
permission_group_ids: Iterable[UUID],
111+
association_ids: Iterable[UUID],
112+
) -> None:
113+
await db_sess.execute(
114+
sa.delete(ObjectPermissionRow).where(ObjectPermissionRow.id.in_(object_permission_ids)) # type: ignore[attr-defined]
115+
)
116+
await db_sess.execute(
117+
sa.delete(PermissionGroupRow).where(PermissionGroupRow.id.in_(permission_group_ids)) # type: ignore[attr-defined]
118+
)
119+
await db_sess.execute(
120+
sa.delete(AssociationScopesEntitiesRow).where(
121+
AssociationScopesEntitiesRow.id.in_(association_ids) # type: ignore[attr-defined]
122+
)
123+
)
124+
125+
126+
async def execute_purger(
127+
db_sess: SASession,
128+
purger: Purger[TRow],
129+
) -> PurgerResult[TRow] | None:
130+
row_class = purger.row_class
131+
table = row_class.__table__ # type: ignore[attr-defined]
132+
pk_columns = list(table.primary_key.columns)
133+
134+
if len(pk_columns) != 1:
135+
raise UnsupportedCompositePrimaryKeyError(
136+
f"Purger only supports single-column primary keys (table: {table.name})",
137+
)
138+
139+
scopes: list[ScopeId] = []
140+
object_id = purger.entity_id
141+
object_permission_ids: list[UUID] = []
142+
permission_group_ids: list[UUID] = []
143+
association_ids: list[UUID] = []
144+
145+
assoc_rows = await _get_association_rows(db_sess, object_id)
146+
for assoc_row in assoc_rows:
147+
association_ids.append(assoc_row.id)
148+
scopes.append(assoc_row.parsed_scope_id())
149+
150+
# Check all roles associated with the entity as object permission
151+
role_rows = await _get_related_roles(db_sess, object_id, scopes)
152+
for role_row in role_rows:
153+
for obj_perm_row in role_row.object_permission_rows:
154+
object_permission_ids.append(obj_perm_row.id)
155+
for perm_group_row in role_row.permission_group_rows:
156+
permission_group_ids.append(perm_group_row.id)
157+
await _purge_related_rows(
158+
db_sess,
159+
object_permission_ids,
160+
permission_group_ids,
161+
association_ids,
162+
)
163+
stmt = sa.delete(table).where(pk_columns[0] == purger.pk_value).returning(*table.columns)
164+
165+
result = await db_sess.execute(stmt)
166+
row_data = result.fetchone()
167+
168+
if row_data is None:
169+
return None
170+
171+
deleted_row: TRow = row_class(**dict(row_data._mapping))
172+
return PurgerResult(row=deleted_row)

0 commit comments

Comments
 (0)