Skip to content

Commit 3f32e6d

Browse files
authored
feat(BA-3490): Implement ScalingGroup Keypair Association Actions (#7655)
1 parent eb8641f commit 3f32e6d

File tree

11 files changed

+444
-4
lines changed

11 files changed

+444
-4
lines changed

changes/7655.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Implement ScalingGroup Keypair Association Actions

src/ai/backend/manager/repositories/scaling_group/creators.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from dataclasses import dataclass, field
55
from typing import Any, Optional, override
66

7+
from ai.backend.common.types import AccessKey
78
from ai.backend.manager.models.scaling_group import (
89
ScalingGroupForDomainRow,
10+
ScalingGroupForKeypairsRow,
911
ScalingGroupOpts,
1012
ScalingGroupRow,
1113
)
@@ -58,3 +60,18 @@ def build_row(self) -> ScalingGroupForDomainRow:
5860
scaling_group=self.scaling_group,
5961
domain=self.domain,
6062
)
63+
64+
65+
@dataclass
66+
class ScalingGroupForKeypairsCreatorSpec(CreatorSpec[ScalingGroupForKeypairsRow]):
67+
"""CreatorSpec for associating a scaling group with a keypair."""
68+
69+
scaling_group: str
70+
access_key: AccessKey
71+
72+
@override
73+
def build_row(self) -> ScalingGroupForKeypairsRow:
74+
return ScalingGroupForKeypairsRow(
75+
scaling_group=self.scaling_group,
76+
access_key=self.access_key,
77+
)

src/ai/backend/manager/repositories/scaling_group/db_source/db_source.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,11 @@
1111
from ai.backend.manager.errors.resource import ScalingGroupNotFound
1212
from ai.backend.manager.models.endpoint import EndpointRow
1313
from ai.backend.manager.models.routing import RoutingRow
14-
from ai.backend.manager.models.scaling_group import ScalingGroupForDomainRow, ScalingGroupRow
14+
from ai.backend.manager.models.scaling_group import (
15+
ScalingGroupForDomainRow,
16+
ScalingGroupForKeypairsRow,
17+
ScalingGroupRow,
18+
)
1519
from ai.backend.manager.models.session import SessionRow
1620
from ai.backend.manager.repositories.base import BatchQuerier, execute_batch_querier
1721
from ai.backend.manager.repositories.base.creator import (
@@ -189,3 +193,37 @@ async def check_scaling_group_domain_association_exists(
189193
)
190194
result = await session.scalar(query)
191195
return (result or 0) > 0
196+
197+
async def associate_scaling_group_with_keypairs(
198+
self,
199+
bulk_creator: BulkCreator[ScalingGroupForKeypairsRow],
200+
) -> None:
201+
"""Associates a scaling group with multiple keypairs."""
202+
async with self._db.begin_session() as session:
203+
await execute_bulk_creator(session, bulk_creator)
204+
205+
async def disassociate_scaling_group_with_keypairs(
206+
self,
207+
purger: BatchPurger[ScalingGroupForKeypairsRow],
208+
) -> None:
209+
"""Disassociates a scaling group from multiple keypairs."""
210+
async with self._db.begin_session() as session:
211+
await execute_batch_purger(session, purger)
212+
213+
async def check_scaling_group_keypair_association_exists(
214+
self,
215+
scaling_group_name: str,
216+
access_key: str,
217+
) -> bool:
218+
"""Checks if a scaling group is associated with a keypair."""
219+
async with self._db.begin_readonly_session() as session:
220+
query = sa.select(
221+
sa.exists().where(
222+
sa.and_(
223+
ScalingGroupForKeypairsRow.scaling_group == scaling_group_name,
224+
ScalingGroupForKeypairsRow.access_key == access_key,
225+
)
226+
)
227+
)
228+
result = await session.execute(query)
229+
return result.scalar() or False

src/ai/backend/manager/repositories/scaling_group/purgers.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55

66
import sqlalchemy as sa
77

8-
from ai.backend.manager.models.scaling_group import ScalingGroupForDomainRow
8+
from ai.backend.common.types import AccessKey
9+
from ai.backend.manager.models.scaling_group import (
10+
ScalingGroupForDomainRow,
11+
ScalingGroupForKeypairsRow,
12+
)
913
from ai.backend.manager.repositories.base.purger import BatchPurger, BatchPurgerSpec
1014

1115

@@ -26,6 +30,23 @@ def build_subquery(self) -> sa.sql.Select[tuple[ScalingGroupForDomainRow]]:
2630
)
2731

2832

33+
@dataclass
34+
class ScalingGroupForKeypairsPurgerSpec(BatchPurgerSpec[ScalingGroupForKeypairsRow]):
35+
"""PurgerSpec for disassociating a scaling group from a keypair."""
36+
37+
scaling_group: str
38+
access_key: AccessKey
39+
40+
@override
41+
def build_subquery(self) -> sa.sql.Select[tuple[ScalingGroupForKeypairsRow]]:
42+
return sa.select(ScalingGroupForKeypairsRow).where(
43+
sa.and_(
44+
ScalingGroupForKeypairsRow.scaling_group == self.scaling_group,
45+
ScalingGroupForKeypairsRow.access_key == self.access_key,
46+
)
47+
)
48+
49+
2950
def create_scaling_group_for_domain_purger(
3051
scaling_group: str,
3152
domain: str,
@@ -38,3 +59,17 @@ def create_scaling_group_for_domain_purger(
3859
),
3960
batch_size=1, # We expect only one row to be deleted
4061
)
62+
63+
64+
def create_scaling_group_for_keypairs_purger(
65+
scaling_group: str,
66+
access_key: AccessKey,
67+
) -> BatchPurger[ScalingGroupForKeypairsRow]:
68+
"""Create a BatchPurger for disassociating a scaling group from a keypair."""
69+
return BatchPurger(
70+
spec=ScalingGroupForKeypairsPurgerSpec(
71+
scaling_group=scaling_group,
72+
access_key=access_key,
73+
),
74+
batch_size=1, # We expect only one row to be deleted
75+
)

src/ai/backend/manager/repositories/scaling_group/repository.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@
1212
)
1313
from ai.backend.common.resilience.policies.retry import BackoffStrategy
1414
from ai.backend.manager.data.scaling_group.types import ScalingGroupData, ScalingGroupListResult
15-
from ai.backend.manager.models.scaling_group import ScalingGroupForDomainRow, ScalingGroupRow
15+
from ai.backend.manager.models.scaling_group import (
16+
ScalingGroupForDomainRow,
17+
ScalingGroupForKeypairsRow,
18+
ScalingGroupRow,
19+
)
1620
from ai.backend.manager.repositories.base import BatchQuerier
1721
from ai.backend.manager.repositories.base.creator import BulkCreator, Creator
1822
from ai.backend.manager.repositories.base.purger import BatchPurger, Purger
@@ -114,3 +118,27 @@ async def check_scaling_group_domain_association_exists(
114118
scaling_group=scaling_group,
115119
domain=domain,
116120
)
121+
122+
async def associate_scaling_group_with_keypairs(
123+
self,
124+
bulk_creator: BulkCreator[ScalingGroupForKeypairsRow],
125+
) -> None:
126+
"""Associates a scaling group with multiple keypairs."""
127+
await self._db_source.associate_scaling_group_with_keypairs(bulk_creator)
128+
129+
async def disassociate_scaling_group_with_keypairs(
130+
self,
131+
purger: BatchPurger[ScalingGroupForKeypairsRow],
132+
) -> None:
133+
"""Disassociates a scaling group from multiple keypairs."""
134+
await self._db_source.disassociate_scaling_group_with_keypairs(purger)
135+
136+
async def check_scaling_group_keypair_association_exists(
137+
self,
138+
scaling_group_name: str,
139+
access_key: str,
140+
) -> bool:
141+
"""Checks if a scaling group is associated with a keypair."""
142+
return await self._db_source.check_scaling_group_keypair_association_exists(
143+
scaling_group_name, access_key
144+
)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Optional, override
5+
6+
from ai.backend.manager.actions.action import BaseActionResult
7+
from ai.backend.manager.models.scaling_group import ScalingGroupForKeypairsRow
8+
from ai.backend.manager.repositories.base.creator import BulkCreator
9+
10+
from .base import ScalingGroupAction
11+
12+
13+
@dataclass
14+
class AssociateScalingGroupWithKeypairsAction(ScalingGroupAction):
15+
"""Action to associate a scaling group with multiple keypairs."""
16+
17+
bulk_creator: BulkCreator[ScalingGroupForKeypairsRow]
18+
19+
@override
20+
@classmethod
21+
def operation_type(cls) -> str:
22+
return "associate_with_keypairs"
23+
24+
@override
25+
def entity_id(self) -> Optional[str]:
26+
return None
27+
28+
29+
@dataclass
30+
class AssociateScalingGroupWithKeypairsActionResult(BaseActionResult):
31+
"""Result of associating a scaling group with keypairs."""
32+
33+
@override
34+
def entity_id(self) -> Optional[str]:
35+
return None
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import Optional, override
5+
6+
from ai.backend.manager.actions.action import BaseActionResult
7+
from ai.backend.manager.models.scaling_group import ScalingGroupForKeypairsRow
8+
from ai.backend.manager.repositories.base.purger import BatchPurger
9+
10+
from .base import ScalingGroupAction
11+
12+
13+
@dataclass
14+
class DisassociateScalingGroupWithKeypairsAction(ScalingGroupAction):
15+
"""Action to disassociate a scaling group from multiple keypairs."""
16+
17+
purger: BatchPurger[ScalingGroupForKeypairsRow]
18+
19+
@override
20+
@classmethod
21+
def operation_type(cls) -> str:
22+
return "disassociate_with_keypairs"
23+
24+
@override
25+
def entity_id(self) -> Optional[str]:
26+
return None
27+
28+
29+
@dataclass
30+
class DisassociateScalingGroupWithKeypairsActionResult(BaseActionResult):
31+
"""Result of disassociating a scaling group from keypairs."""
32+
33+
@override
34+
def entity_id(self) -> Optional[str]:
35+
return None

src/ai/backend/manager/services/scaling_group/processors.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
AssociateScalingGroupWithDomainsAction,
88
AssociateScalingGroupWithDomainsActionResult,
99
)
10+
from ai.backend.manager.services.scaling_group.actions.associate_with_keypair import (
11+
AssociateScalingGroupWithKeypairsAction,
12+
AssociateScalingGroupWithKeypairsActionResult,
13+
)
1014
from ai.backend.manager.services.scaling_group.actions.create import (
1115
CreateScalingGroupAction,
1216
CreateScalingGroupActionResult,
@@ -15,6 +19,10 @@
1519
DisassociateScalingGroupWithDomainsAction,
1620
DisassociateScalingGroupWithDomainsActionResult,
1721
)
22+
from ai.backend.manager.services.scaling_group.actions.disassociate_with_keypair import (
23+
DisassociateScalingGroupWithKeypairsAction,
24+
DisassociateScalingGroupWithKeypairsActionResult,
25+
)
1826
from ai.backend.manager.services.scaling_group.actions.list_scaling_groups import (
1927
SearchScalingGroupsAction,
2028
SearchScalingGroupsActionResult,
@@ -43,6 +51,12 @@ class ScalingGroupProcessors(AbstractProcessorPackage):
4351
disassociate_scaling_group_with_domains: ActionProcessor[
4452
DisassociateScalingGroupWithDomainsAction, DisassociateScalingGroupWithDomainsActionResult
4553
]
54+
associate_scaling_group_with_keypairs: ActionProcessor[
55+
AssociateScalingGroupWithKeypairsAction, AssociateScalingGroupWithKeypairsActionResult
56+
]
57+
disassociate_scaling_group_with_keypairs: ActionProcessor[
58+
DisassociateScalingGroupWithKeypairsAction, DisassociateScalingGroupWithKeypairsActionResult
59+
]
4660

4761
def __init__(self, service: ScalingGroupService, action_monitors: list[ActionMonitor]) -> None:
4862
self.create_scaling_group = ActionProcessor(service.create_scaling_group, action_monitors)
@@ -55,6 +69,12 @@ def __init__(self, service: ScalingGroupService, action_monitors: list[ActionMon
5569
self.disassociate_scaling_group_with_domains = ActionProcessor(
5670
service.disassociate_scaling_group_with_domains, action_monitors
5771
)
72+
self.associate_scaling_group_with_keypairs = ActionProcessor(
73+
service.associate_scaling_group_with_keypairs, action_monitors
74+
)
75+
self.disassociate_scaling_group_with_keypairs = ActionProcessor(
76+
service.disassociate_scaling_group_with_keypairs, action_monitors
77+
)
5878

5979
@override
6080
def supported_actions(self) -> list[ActionSpec]:
@@ -65,4 +85,6 @@ def supported_actions(self) -> list[ActionSpec]:
6585
SearchScalingGroupsAction.spec(),
6686
AssociateScalingGroupWithDomainsAction.spec(),
6787
DisassociateScalingGroupWithDomainsAction.spec(),
88+
AssociateScalingGroupWithKeypairsAction.spec(),
89+
DisassociateScalingGroupWithKeypairsAction.spec(),
6890
]

src/ai/backend/manager/services/scaling_group/service.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
AssociateScalingGroupWithDomainsAction,
77
AssociateScalingGroupWithDomainsActionResult,
88
)
9+
from ai.backend.manager.services.scaling_group.actions.associate_with_keypair import (
10+
AssociateScalingGroupWithKeypairsAction,
11+
AssociateScalingGroupWithKeypairsActionResult,
12+
)
913
from ai.backend.manager.services.scaling_group.actions.create import (
1014
CreateScalingGroupAction,
1115
CreateScalingGroupActionResult,
@@ -14,6 +18,10 @@
1418
DisassociateScalingGroupWithDomainsAction,
1519
DisassociateScalingGroupWithDomainsActionResult,
1620
)
21+
from ai.backend.manager.services.scaling_group.actions.disassociate_with_keypair import (
22+
DisassociateScalingGroupWithKeypairsAction,
23+
DisassociateScalingGroupWithKeypairsActionResult,
24+
)
1725
from ai.backend.manager.services.scaling_group.actions.list_scaling_groups import (
1826
SearchScalingGroupsAction,
1927
SearchScalingGroupsActionResult,
@@ -85,3 +93,17 @@ async def disassociate_scaling_group_with_domains(
8593
"""Disassociates a scaling group from multiple domains."""
8694
await self._repository.disassociate_scaling_group_with_domains(action.purger)
8795
return DisassociateScalingGroupWithDomainsActionResult()
96+
97+
async def associate_scaling_group_with_keypairs(
98+
self, action: AssociateScalingGroupWithKeypairsAction
99+
) -> AssociateScalingGroupWithKeypairsActionResult:
100+
"""Associates a scaling group with multiple keypairs."""
101+
await self._repository.associate_scaling_group_with_keypairs(action.bulk_creator)
102+
return AssociateScalingGroupWithKeypairsActionResult()
103+
104+
async def disassociate_scaling_group_with_keypairs(
105+
self, action: DisassociateScalingGroupWithKeypairsAction
106+
) -> DisassociateScalingGroupWithKeypairsActionResult:
107+
"""Disassociates a scaling group from multiple keypairs."""
108+
await self._repository.disassociate_scaling_group_with_keypairs(action.purger)
109+
return DisassociateScalingGroupWithKeypairsActionResult()

0 commit comments

Comments
 (0)