Skip to content

Commit bf2e8b4

Browse files
authored
refactor(BA-3516): Integrate ModifyScalingGroup action with existing GQL resolver (#7524)
1 parent 9050736 commit bf2e8b4

File tree

3 files changed

+55
-15
lines changed

3 files changed

+55
-15
lines changed

changes/7524.enhance.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Integrate `ModifyScalingGroup` action with existing GQL resolver

src/ai/backend/manager/api/gql_legacy/scaling_group.py

Lines changed: 53 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import graphene_federation
1616
import sqlalchemy as sa
1717
from graphene.types.datetime import DateTime as GQLDateTime
18+
from graphql import Undefined
1819
from sqlalchemy.engine.row import Row
1920
from sqlalchemy.orm import load_only
2021

@@ -36,19 +37,31 @@
3637
from ai.backend.manager.models.user import UserRole
3738
from ai.backend.manager.repositories.base.creator import Creator
3839
from ai.backend.manager.repositories.base.purger import Purger
40+
from ai.backend.manager.repositories.base.updater import Updater
3941
from ai.backend.manager.repositories.scaling_group.creators import ScalingGroupCreatorSpec
42+
from ai.backend.manager.repositories.scaling_group.updaters import (
43+
ScalingGroupDriverConfigUpdaterSpec,
44+
ScalingGroupMetadataUpdaterSpec,
45+
ScalingGroupNetworkConfigUpdaterSpec,
46+
ScalingGroupSchedulerConfigUpdaterSpec,
47+
ScalingGroupStatusUpdaterSpec,
48+
ScalingGroupUpdaterSpec,
49+
)
4050
from ai.backend.manager.services.scaling_group.actions.create import (
4151
CreateScalingGroupAction,
4252
)
53+
from ai.backend.manager.services.scaling_group.actions.modify import (
54+
ModifyScalingGroupAction,
55+
)
4356
from ai.backend.manager.services.scaling_group.actions.purge_scaling_group import (
4457
PurgeScalingGroupAction,
4558
)
59+
from ai.backend.manager.types import OptionalState, TriState
4660

4761
from .base import (
4862
batch_multiresult,
4963
batch_multiresult_in_scalar_stream,
5064
batch_result,
51-
set_if_set,
5265
simple_db_mutate,
5366
)
5467
from .gql_relay import (
@@ -606,6 +619,41 @@ class ModifyScalingGroupInput(graphene.InputObjectType):
606619
scheduler_opts = graphene.JSONString(required=False)
607620
use_host_network = graphene.Boolean(required=False)
608621

622+
def to_updater(self, name: str) -> Updater[ScalingGroupRow]:
623+
"""Convert GraphQL input to Updater for scaling group modification."""
624+
status_spec = ScalingGroupStatusUpdaterSpec(
625+
is_active=OptionalState.from_graphql(self.is_active),
626+
is_public=OptionalState.from_graphql(self.is_public),
627+
)
628+
metadata_spec = ScalingGroupMetadataUpdaterSpec(
629+
description=TriState.from_graphql(self.description),
630+
)
631+
network_spec = ScalingGroupNetworkConfigUpdaterSpec(
632+
wsproxy_addr=TriState.from_graphql(self.wsproxy_addr),
633+
wsproxy_api_token=TriState.from_graphql(self.wsproxy_api_token),
634+
use_host_network=OptionalState.from_graphql(self.use_host_network),
635+
)
636+
driver_spec = ScalingGroupDriverConfigUpdaterSpec(
637+
driver=OptionalState.from_graphql(self.driver),
638+
driver_opts=OptionalState.from_graphql(self.driver_opts),
639+
)
640+
scheduler_spec = ScalingGroupSchedulerConfigUpdaterSpec(
641+
scheduler=OptionalState.from_graphql(self.scheduler),
642+
scheduler_opts=OptionalState.from_graphql(
643+
ScalingGroupOpts.from_json(self.scheduler_opts)
644+
if self.scheduler_opts is not None and self.scheduler_opts is not Undefined
645+
else Undefined
646+
),
647+
)
648+
spec = ScalingGroupUpdaterSpec(
649+
status=status_spec,
650+
metadata=metadata_spec,
651+
network=network_spec,
652+
driver=driver_spec,
653+
scheduler=scheduler_spec,
654+
)
655+
return Updater(spec=spec, pk_value=name)
656+
609657

610658
class CreateScalingGroup(graphene.Mutation):
611659
allowed_roles = (UserRole.SUPERADMIN,)
@@ -683,21 +731,11 @@ async def mutate(
683731
name: str,
684732
props: ModifyScalingGroupInput,
685733
) -> ModifyScalingGroup:
686-
data: dict[str, Any] = {}
687-
set_if_set(props, data, "description")
688-
set_if_set(props, data, "is_active")
689-
set_if_set(props, data, "is_public")
690-
set_if_set(props, data, "wsproxy_addr")
691-
set_if_set(props, data, "wsproxy_api_token")
692-
set_if_set(props, data, "driver")
693-
set_if_set(props, data, "driver_opts")
694-
set_if_set(props, data, "scheduler")
695-
set_if_set(
696-
props, data, "scheduler_opts", clean_func=lambda v: ScalingGroupOpts.from_json(v)
734+
graph_ctx: GraphQueryContext = info.context
735+
await graph_ctx.processors.scaling_group.modify_scaling_group.wait_for_complete(
736+
ModifyScalingGroupAction(updater=props.to_updater(name))
697737
)
698-
set_if_set(props, data, "use_host_network")
699-
update_query = sa.update(scaling_groups).values(data).where(scaling_groups.c.name == name)
700-
return await simple_db_mutate(cls, info.context, update_query)
738+
return cls(ok=True, msg="success")
701739

702740

703741
class DeleteScalingGroup(graphene.Mutation):

src/ai/backend/manager/repositories/scaling_group/repository.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ async def purge_scaling_group(
8484
"""
8585
return await self._db_source.purge_scaling_group(purger)
8686

87+
@scaling_group_repository_resilience.apply()
8788
async def update_scaling_group(
8889
self,
8990
updater: Updater[ScalingGroupRow],

0 commit comments

Comments
 (0)