Skip to content

Commit 4ad1eeb

Browse files
committed
feat: Add check_scaling_group_user_group_association_exists into repository layer
1 parent 5deaa53 commit 4ad1eeb

File tree

3 files changed

+62
-30
lines changed

3 files changed

+62
-30
lines changed

src/ai/backend/manager/repositories/scaling_group/db_source/db_source.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import uuid
56
from typing import TYPE_CHECKING, cast
67

78
import sqlalchemy as sa
@@ -209,3 +210,23 @@ async def disassociate_scaling_group_with_user_group(
209210
"""Disassociates a single scaling group from a user group (project)."""
210211
async with self._db.begin_session() as session:
211212
await execute_batch_purger(session, purger)
213+
214+
async def check_scaling_group_user_group_association_exists(
215+
self,
216+
scaling_group: str,
217+
user_group: uuid.UUID,
218+
) -> bool:
219+
"""Checks if a scaling group is associated with a user group (project)."""
220+
async with self._db.begin_readonly_session() as session:
221+
query = (
222+
sa.select(sa.func.count())
223+
.select_from(ScalingGroupForProjectRow)
224+
.where(
225+
sa.and_(
226+
ScalingGroupForProjectRow.scaling_group == scaling_group,
227+
ScalingGroupForProjectRow.group == user_group,
228+
)
229+
)
230+
)
231+
result = await session.scalar(query)
232+
return (result or 0) > 0

src/ai/backend/manager/repositories/scaling_group/repository.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import uuid
34
from typing import TYPE_CHECKING
45

56
from ai.backend.common.metrics.metric import DomainType, LayerType
@@ -132,3 +133,14 @@ async def disassociate_scaling_group_with_user_group(
132133
) -> None:
133134
"""Disassociates a single scaling group from a user group (project)."""
134135
await self._db_source.disassociate_scaling_group_with_user_group(purger)
136+
137+
async def check_scaling_group_user_group_association_exists(
138+
self,
139+
scaling_group: str,
140+
user_group: uuid.UUID,
141+
) -> bool:
142+
"""Checks if a scaling group is associated with a user group (project)."""
143+
return await self._db_source.check_scaling_group_user_group_association_exists(
144+
scaling_group=scaling_group,
145+
user_group=user_group,
146+
)

tests/unit/manager/repositories/scaling_group/test_scaling_group_repository.py

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -946,7 +946,6 @@ async def test_associate_scaling_group_with_user_group_success(
946946
scaling_group_repository: ScalingGroupRepository,
947947
sample_scaling_group_for_purge: str,
948948
test_user_domain_group: tuple[uuid.UUID, str, uuid.UUID],
949-
db_with_cleanup: ExtendedAsyncSAEngine,
950949
) -> None:
951950
"""Test associating a scaling group with a user group (project)."""
952951
# Given: A scaling group and a project (group)
@@ -962,40 +961,43 @@ async def test_associate_scaling_group_with_user_group_success(
962961
)
963962
await scaling_group_repository.associate_scaling_group_with_user_group(creator)
964963

965-
# Then: Association should exist in the database
966-
async with db_with_cleanup.begin_readonly_session() as db_sess:
967-
query = sa.select(ScalingGroupForProjectRow).where(
968-
sa.and_(
969-
ScalingGroupForProjectRow.scaling_group == sgroup_name,
970-
ScalingGroupForProjectRow.group == project_id,
971-
)
964+
# Then: Association should exist
965+
association_exists = (
966+
await scaling_group_repository.check_scaling_group_user_group_association_exists(
967+
scaling_group=sgroup_name,
968+
user_group=project_id,
972969
)
973-
result = await db_sess.execute(query)
974-
row = result.scalar_one_or_none()
975-
assert row is not None
976-
assert row.scaling_group == sgroup_name
977-
assert row.group == project_id
970+
)
971+
assert association_exists is True
978972

979973
async def test_disassociate_scaling_group_with_user_group_success(
980974
self,
981975
scaling_group_repository: ScalingGroupRepository,
982976
sample_scaling_group_for_purge: str,
983977
test_user_domain_group: tuple[uuid.UUID, str, uuid.UUID],
984-
db_with_cleanup: ExtendedAsyncSAEngine,
985978
) -> None:
986979
"""Test disassociating a scaling group from a user group (project)."""
987980
# Given: A scaling group associated with a project
988981
sgroup_name = sample_scaling_group_for_purge
989982
_, _, project_id = test_user_domain_group
990983

991-
# First, associate the scaling group with the project
992-
async with db_with_cleanup.begin_session() as db_sess:
993-
association = ScalingGroupForProjectRow(
984+
# First, associate the scaling group with the project using repository
985+
creator = Creator(
986+
spec=ScalingGroupForProjectCreatorSpec(
994987
scaling_group=sgroup_name,
995-
group=project_id,
988+
project=project_id,
996989
)
997-
db_sess.add(association)
998-
await db_sess.flush()
990+
)
991+
await scaling_group_repository.associate_scaling_group_with_user_group(creator)
992+
993+
# Verify association exists
994+
association_exists = (
995+
await scaling_group_repository.check_scaling_group_user_group_association_exists(
996+
scaling_group=sgroup_name,
997+
user_group=project_id,
998+
)
999+
)
1000+
assert association_exists is True
9991001

10001002
# When: Disassociate the scaling group from the project
10011003
purger = create_scaling_group_for_project_purger(
@@ -1004,17 +1006,14 @@ async def test_disassociate_scaling_group_with_user_group_success(
10041006
)
10051007
await scaling_group_repository.disassociate_scaling_group_with_user_group(purger)
10061008

1007-
# Then: Association should no longer exist in the database
1008-
async with db_with_cleanup.begin_readonly_session() as db_sess:
1009-
query = sa.select(ScalingGroupForProjectRow).where(
1010-
sa.and_(
1011-
ScalingGroupForProjectRow.scaling_group == sgroup_name,
1012-
ScalingGroupForProjectRow.group == project_id,
1013-
)
1009+
# Then: Association should no longer exist
1010+
association_exists = (
1011+
await scaling_group_repository.check_scaling_group_user_group_association_exists(
1012+
scaling_group=sgroup_name,
1013+
user_group=project_id,
10141014
)
1015-
result = await db_sess.execute(query)
1016-
row = result.scalar_one_or_none()
1017-
assert row is None
1015+
)
1016+
assert association_exists is False
10181017

10191018
async def test_disassociate_nonexistent_scaling_group_with_user_group(
10201019
self,

0 commit comments

Comments
 (0)