|
15 | 15 | import graphene_federation |
16 | 16 | import sqlalchemy as sa |
17 | 17 | from graphene.types.datetime import DateTime as GQLDateTime |
| 18 | +from graphql import Undefined |
18 | 19 | from sqlalchemy.engine.row import Row |
19 | 20 | from sqlalchemy.orm import load_only |
20 | 21 |
|
|
36 | 37 | from ai.backend.manager.models.user import UserRole |
37 | 38 | from ai.backend.manager.repositories.base.creator import Creator |
38 | 39 | from ai.backend.manager.repositories.base.purger import Purger |
| 40 | +from ai.backend.manager.repositories.base.updater import Updater |
39 | 41 | from ai.backend.manager.repositories.scaling_group.creators import ScalingGroupCreatorSpec |
| 42 | +from ai.backend.manager.repositories.scaling_group.updaters import ( |
| 43 | + ScalingGroupDriverConfigUpdaterSpec, |
| 44 | + ScalingGroupMetadataUpdaterSpec, |
| 45 | + ScalingGroupNetworkConfigUpdaterSpec, |
| 46 | + ScalingGroupSchedulerConfigUpdaterSpec, |
| 47 | + ScalingGroupStatusUpdaterSpec, |
| 48 | + ScalingGroupUpdaterSpec, |
| 49 | +) |
40 | 50 | from ai.backend.manager.services.scaling_group.actions.create import ( |
41 | 51 | CreateScalingGroupAction, |
42 | 52 | ) |
| 53 | +from ai.backend.manager.services.scaling_group.actions.modify import ( |
| 54 | + ModifyScalingGroupAction, |
| 55 | +) |
43 | 56 | from ai.backend.manager.services.scaling_group.actions.purge_scaling_group import ( |
44 | 57 | PurgeScalingGroupAction, |
45 | 58 | ) |
| 59 | +from ai.backend.manager.types import OptionalState, TriState |
46 | 60 |
|
47 | 61 | from .base import ( |
48 | 62 | batch_multiresult, |
49 | 63 | batch_multiresult_in_scalar_stream, |
50 | 64 | batch_result, |
51 | | - set_if_set, |
52 | 65 | simple_db_mutate, |
53 | 66 | ) |
54 | 67 | from .gql_relay import ( |
@@ -606,6 +619,41 @@ class ModifyScalingGroupInput(graphene.InputObjectType): |
606 | 619 | scheduler_opts = graphene.JSONString(required=False) |
607 | 620 | use_host_network = graphene.Boolean(required=False) |
608 | 621 |
|
| 622 | + def to_updater(self, name: str) -> Updater[ScalingGroupRow]: |
| 623 | + """Convert GraphQL input to Updater for scaling group modification.""" |
| 624 | + status_spec = ScalingGroupStatusUpdaterSpec( |
| 625 | + is_active=OptionalState.from_graphql(self.is_active), |
| 626 | + is_public=OptionalState.from_graphql(self.is_public), |
| 627 | + ) |
| 628 | + metadata_spec = ScalingGroupMetadataUpdaterSpec( |
| 629 | + description=TriState.from_graphql(self.description), |
| 630 | + ) |
| 631 | + network_spec = ScalingGroupNetworkConfigUpdaterSpec( |
| 632 | + wsproxy_addr=TriState.from_graphql(self.wsproxy_addr), |
| 633 | + wsproxy_api_token=TriState.from_graphql(self.wsproxy_api_token), |
| 634 | + use_host_network=OptionalState.from_graphql(self.use_host_network), |
| 635 | + ) |
| 636 | + driver_spec = ScalingGroupDriverConfigUpdaterSpec( |
| 637 | + driver=OptionalState.from_graphql(self.driver), |
| 638 | + driver_opts=OptionalState.from_graphql(self.driver_opts), |
| 639 | + ) |
| 640 | + scheduler_spec = ScalingGroupSchedulerConfigUpdaterSpec( |
| 641 | + scheduler=OptionalState.from_graphql(self.scheduler), |
| 642 | + scheduler_opts=OptionalState.from_graphql( |
| 643 | + ScalingGroupOpts.from_json(self.scheduler_opts) |
| 644 | + if self.scheduler_opts is not None and self.scheduler_opts is not Undefined |
| 645 | + else Undefined |
| 646 | + ), |
| 647 | + ) |
| 648 | + spec = ScalingGroupUpdaterSpec( |
| 649 | + status=status_spec, |
| 650 | + metadata=metadata_spec, |
| 651 | + network=network_spec, |
| 652 | + driver=driver_spec, |
| 653 | + scheduler=scheduler_spec, |
| 654 | + ) |
| 655 | + return Updater(spec=spec, pk_value=name) |
| 656 | + |
609 | 657 |
|
610 | 658 | class CreateScalingGroup(graphene.Mutation): |
611 | 659 | allowed_roles = (UserRole.SUPERADMIN,) |
@@ -683,21 +731,11 @@ async def mutate( |
683 | 731 | name: str, |
684 | 732 | props: ModifyScalingGroupInput, |
685 | 733 | ) -> ModifyScalingGroup: |
686 | | - data: dict[str, Any] = {} |
687 | | - set_if_set(props, data, "description") |
688 | | - set_if_set(props, data, "is_active") |
689 | | - set_if_set(props, data, "is_public") |
690 | | - set_if_set(props, data, "wsproxy_addr") |
691 | | - set_if_set(props, data, "wsproxy_api_token") |
692 | | - set_if_set(props, data, "driver") |
693 | | - set_if_set(props, data, "driver_opts") |
694 | | - set_if_set(props, data, "scheduler") |
695 | | - set_if_set( |
696 | | - props, data, "scheduler_opts", clean_func=lambda v: ScalingGroupOpts.from_json(v) |
| 734 | + graph_ctx: GraphQueryContext = info.context |
| 735 | + await graph_ctx.processors.scaling_group.modify_scaling_group.wait_for_complete( |
| 736 | + ModifyScalingGroupAction(updater=props.to_updater(name)) |
697 | 737 | ) |
698 | | - set_if_set(props, data, "use_host_network") |
699 | | - update_query = sa.update(scaling_groups).values(data).where(scaling_groups.c.name == name) |
700 | | - return await simple_db_mutate(cls, info.context, update_query) |
| 738 | + return cls(ok=True, msg="success") |
701 | 739 |
|
702 | 740 |
|
703 | 741 | class DeleteScalingGroup(graphene.Mutation): |
|
0 commit comments