Skip to content
This repository was archived by the owner on Aug 2, 2023. It is now read-only.
1 change: 1 addition & 0 deletions changes/442.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add API to get available resources
155 changes: 85 additions & 70 deletions src/ai/backend/manager/api/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,11 @@
from ai.backend.common.types import DefaultForUnspecified, ResourceSlot

from ..models import (
agents, resource_presets,
resource_presets,
domains, groups, kernels, users,
AgentStatus,
association_groups_users,
query_allowed_sgroups,
AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
get_row,
get_scaling_groups_resources,
get_group_resource_status,
RESOURCE_USAGE_KERNEL_STATUSES, LIVE_STATUS,
)
from .auth import auth_required, superadmin_required
Expand Down Expand Up @@ -130,21 +129,7 @@ async def check_presets(request: web.Request, params: Any) -> web.Response:
keypair_remaining = keypair_limits - keypair_occupied

# Check group resource limit and get group_id.
j = sa.join(
groups, association_groups_users,
association_groups_users.c.group_id == groups.c.id,
)
query = (
sa.select([groups.c.id, groups.c.total_resource_slots])
.select_from(j)
.where(
(association_groups_users.c.user_id == request['user']['uuid']) &
(groups.c.name == params['group']) &
(domains.c.name == domain_name)
)
)
result = await conn.execute(query)
row = result.first()
row = await get_row(conn, request, params, domain_name)
group_id = row['id']
group_resource_slots = row['total_resource_slots']
if group_id is None:
Expand Down Expand Up @@ -178,50 +163,12 @@ async def check_presets(request: web.Request, params: Any) -> web.Response:
domain_remaining[slot],
)

# Prepare per scaling group resource.
sgroups = await query_allowed_sgroups(conn, domain_name, group_id, access_key)
sgroup_names = [sg.name for sg in sgroups]
if params['scaling_group'] is not None:
if params['scaling_group'] not in sgroup_names:
raise InvalidAPIParameters('Unknown scaling group')
sgroup_names = [params['scaling_group']]
per_sgroup = {
sgname: {
'using': ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}),
'remaining': ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}),
} for sgname in sgroup_names
}

# Per scaling group resource using from resource occupying kernels.
query = (
sa.select([kernels.c.occupied_slots, kernels.c.scaling_group])
.select_from(kernels)
.where(
(kernels.c.user_uuid == request['user']['uuid']) &
(kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) &
(kernels.c.scaling_group.in_(sgroup_names))
)
)
async for row in (await conn.stream(query)):
per_sgroup[row['scaling_group']]['using'] += row['occupied_slots']

# Per scaling group resource remaining from agents stats.
sgroup_remaining = ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()})
query = (
sa.select([agents.c.available_slots, agents.c.occupied_slots, agents.c.scaling_group])
.select_from(agents)
.where(
(agents.c.status == AgentStatus.ALIVE) &
(agents.c.scaling_group.in_(sgroup_names))
)
# Take resources per sgroup.
per_sgroup, sgroup_remaining, agent_slots = await get_scaling_groups_resources(
conn, request, params,
domain_name, group_id,
access_key, known_slot_types
)
agent_slots = []
async for row in (await conn.stream(query)):
remaining = row['available_slots'] - row['occupied_slots']
remaining += ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()})
sgroup_remaining += remaining
agent_slots.append(remaining)
per_sgroup[row['scaling_group']]['remaining'] += remaining

# Take maximum allocatable resources per sgroup.
for sgname, sgfields in per_sgroup.items():
Expand Down Expand Up @@ -255,13 +202,9 @@ async def check_presets(request: web.Request, params: Any) -> web.Response:
})

# Return group resource status as NaN if not allowed.
group_resource_visibility = \
await root_ctx.shared_config.get_raw('config/api/resources/group_resource_visibility')
group_resource_visibility = t.ToBool().check(group_resource_visibility)
if not group_resource_visibility:
group_limits = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()})
group_occupied = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()})
group_remaining = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()})
group_limits, group_occupied, group_remaining = await get_group_resource_status(
root_ctx, t, known_slot_types
)

resp['keypair_limits'] = keypair_limits.to_json()
resp['keypair_using'] = keypair_occupied.to_json()
Expand All @@ -274,6 +217,77 @@ async def check_presets(request: web.Request, params: Any) -> web.Response:
return web.json_response(resp, status=200)


@atomic
@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
    t.Dict({
        t.Key('scaling_group', default=None): t.Null | t.String,
        t.Key('group', default='default'): t.String,
    }))
async def get_available_resources(request: web.Request, params: Any) -> web.Response:
    """
    Returns the available resources of a specific group (project).

    The JSON response contains ``group_limits``, ``group_using``,
    ``group_remaining``, ``scaling_group_remaining`` and ``scaling_groups``.
    The three ``group_*`` entries are reported as NaN when the
    ``config/api/resources/group_resource_visibility`` setting is disabled.

    :param scaling_group: If not None, get available resources of specific scaling group.
    :param group: Get available resources of specific group (project) and enumerate them.

    :raises InvalidAPIParameters: If the requesting user has no group with the
        given name.
    """
    # Function-scope import so this handler does not depend on a module-level
    # ``Decimal`` import being present.
    from decimal import Decimal

    root_ctx: RootContext = request.app['_root.context']
    try:
        access_key = request['keypair']['access_key']
        domain_name = request['user']['domain_name']
        # TODO: uncomment when we implement scaling group.
        # scaling_group = request.query.get('scaling_group')
        # assert scaling_group is not None, 'scaling_group parameter is missing.'
    except (json.decoder.JSONDecodeError, AssertionError) as e:
        raise InvalidAPIParameters(extra_msg=str(e.args[0]))
    known_slot_types = await root_ctx.shared_config.get_resource_slots()
    log.info('GET_AVAILABLE_RESOURCES (ak:{}, g:{}, sg:{})',
             access_key, params['group'], params['scaling_group'])

    async with root_ctx.db.begin_readonly() as conn:
        # Check group resource limit and get group_id.
        row = await get_row(conn, request, params, domain_name)
        # get_row() yields None when no group matches; check that before
        # subscripting so the client gets a proper API error instead of a
        # TypeError-induced internal server error.
        if row is None or row['id'] is None:
            raise InvalidAPIParameters('Unknown user group')
        group_id = row['id']
        group_resource_policy = {
            'total_resource_slots': row['total_resource_slots'],
            'default_for_unspecified': DefaultForUnspecified.UNLIMITED
        }
        group_limits = ResourceSlot.from_policy(group_resource_policy, known_slot_types)
        group_occupied = await root_ctx.registry.get_group_occupancy(group_id, conn=conn)
        group_remaining = group_limits - group_occupied

        per_sgroup, sgroup_remaining, _ = await get_scaling_groups_resources(
            conn, request, params,
            domain_name, group_id,
            access_key, known_slot_types
        )

    # Serialize the per-scaling-group resource slots for the JSON response.
    serialized_sgroups = {
        sgname: {rtype: slots.to_json() for rtype, slots in sgfields.items()}
        for sgname, sgfields in per_sgroup.items()
    }

    # Return group resource status as NaN if not allowed.
    # The check is done here (instead of overwriting the values with the
    # result of a helper) so that the figures computed above are kept when
    # group resource visibility is enabled.
    group_resource_visibility = await root_ctx.shared_config.get_raw(
        'config/api/resources/group_resource_visibility')
    if not t.ToBool().check(group_resource_visibility):
        group_limits = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()})
        group_occupied = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()})
        group_remaining = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()})

    resp: MutableMapping[str, Any] = {
        'scaling_group_remaining': sgroup_remaining.to_json(),
        'scaling_groups': serialized_sgroups,
        'group_limits': group_limits.to_json(),
        'group_using': group_occupied.to_json(),
        'group_remaining': group_remaining.to_json(),
    }
    return web.json_response(resp, status=200)


@server_status_required(READ_ALLOWED)
@superadmin_required
@atomic
Expand Down Expand Up @@ -748,6 +762,7 @@ def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iter
add_route = app.router.add_route
cors.add(add_route('GET', '/presets', list_presets))
cors.add(add_route('POST', '/check-presets', check_presets))
cors.add(add_route('GET', '/available-resources', get_available_resources))
cors.add(add_route('POST', '/recalculate-usage', recalculate_usage))
cors.add(add_route('GET', '/usage/month', usage_per_month))
cors.add(add_route('GET', '/usage/period', usage_per_period))
Expand Down
106 changes: 105 additions & 1 deletion src/ai/backend/manager/models/resource_preset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,34 @@
import graphene
import sqlalchemy as sa
from sqlalchemy.engine.row import Row
from decimal import Decimal

from ai.backend.common.logging import BraceStyleAdapter
from ai.backend.common.types import ResourceSlot
from ..api.exceptions import (
InvalidAPIParameters,
)
from .base import (
metadata, BigInt, BinarySize, ResourceSlotColumn,
simple_db_mutate,
simple_db_mutate_returning_item,
set_if_set,
batch_result,
)
from .agent import (
agents, AgentStatus,
)
from .user import UserRole

from .group import (
groups,
association_groups_users,
)
from .scaling_group import query_allowed_sgroups
from .kernel import (
AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES,
kernels,
)
from .domain import domains
if TYPE_CHECKING:
from .gql import GraphQueryContext

Expand All @@ -34,6 +50,9 @@
'CreateResourcePreset',
'ModifyResourcePreset',
'DeleteResourcePreset',
'get_row',
'get_scaling_groups_resources',
'get_group_resource_status',
)


Expand Down Expand Up @@ -193,3 +212,88 @@ async def mutate(
.where(resource_presets.c.name == name)
)
return await simple_db_mutate(cls, info.context, delete_query)


async def get_row(conn, request, params, domain_name):
    """
    Look up the group (project) named ``params['group']`` among the groups
    that the requesting user (``request['user']['uuid']``) is associated with.

    The "row" is a row of the ``groups`` table joined with
    ``association_groups_users``, reduced to the ``id`` and
    ``total_resource_slots`` columns.

    :returns: Whatever ``first()`` yields for the query result, i.e. the first
        matching row.
    """
    # NOTE(review): a reviewer asked for a more descriptive function name; the
    # name is kept unchanged here because it is exported and imported by callers.
    membership_join = sa.join(
        groups, association_groups_users,
        association_groups_users.c.group_id == groups.c.id,
    )
    is_member = association_groups_users.c.user_id == request['user']['uuid']
    has_name = groups.c.name == params['group']
    # NOTE(review): ``domains`` is not part of the join above -- confirm that
    # this condition really restricts the group to the given domain.
    in_domain = domains.c.name == domain_name
    group_query = (
        sa.select([groups.c.id, groups.c.total_resource_slots])
        .select_from(membership_join)
        .where(is_member & has_name & in_domain)
    )
    return (await conn.execute(group_query)).first()


async def get_scaling_groups_resources(
    conn, request, params,
    domain_name, group_id,
    access_key, known_slot_types
):
    """
    Collect resource usage and remaining capacity per scaling group.

    The scaling groups considered are those returned by
    ``query_allowed_sgroups()``; when ``params['scaling_group']`` is not None
    the result is narrowed down to that single scaling group.

    :returns: A 3-tuple ``(per_sgroup, sgroup_remaining, agent_slots)`` where
        ``per_sgroup`` maps each scaling group name to its ``'using'`` and
        ``'remaining'`` resource slots, ``sgroup_remaining`` is the sum of the
        remaining slots over all matched agents, and ``agent_slots`` is the
        list of remaining slots of each matched agent.

    :raises InvalidAPIParameters: If the requested scaling group is not one of
        the allowed scaling groups.
    """
    def _zero_slots():
        # A fresh all-zero slot set covering every known slot type.
        return ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()})

    # Prepare per scaling group resource.
    allowed_sgroups = await query_allowed_sgroups(conn, domain_name, group_id, access_key)
    sgroup_names = [sg.name for sg in allowed_sgroups]
    requested_sgroup = params['scaling_group']
    if requested_sgroup is not None:
        if requested_sgroup not in sgroup_names:
            raise InvalidAPIParameters('Unknown scaling group')
        sgroup_names = [requested_sgroup]
    per_sgroup = {}
    for sgname in sgroup_names:
        per_sgroup[sgname] = {
            'using': _zero_slots(),
            'remaining': _zero_slots(),
        }

    # Per scaling group resource using from resource occupying kernels.
    owned_by_user = kernels.c.user_uuid == request['user']['uuid']
    occupying = kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)
    kernel_in_sgroups = kernels.c.scaling_group.in_(sgroup_names)
    kernel_query = (
        sa.select([kernels.c.occupied_slots, kernels.c.scaling_group])
        .select_from(kernels)
        .where(owned_by_user & occupying & kernel_in_sgroups)
    )
    kernel_rows = await conn.stream(kernel_query)
    async for kernel_row in kernel_rows:
        per_sgroup[kernel_row['scaling_group']]['using'] += kernel_row['occupied_slots']

    # Per scaling group resource remaining from agents stats.
    agent_alive = agents.c.status == AgentStatus.ALIVE
    agent_in_sgroups = agents.c.scaling_group.in_(sgroup_names)
    agent_query = (
        sa.select([agents.c.available_slots, agents.c.occupied_slots, agents.c.scaling_group])
        .select_from(agents)
        .where(agent_alive & agent_in_sgroups)
    )
    sgroup_remaining = _zero_slots()
    agent_slots = []
    agent_rows = await conn.stream(agent_query)
    async for agent_row in agent_rows:
        remaining = agent_row['available_slots'] - agent_row['occupied_slots']
        # Presumably pads the result so every known slot type is present --
        # TODO confirm against ResourceSlot arithmetic semantics.
        remaining += _zero_slots()
        sgroup_remaining += remaining
        agent_slots.append(remaining)
        per_sgroup[agent_row['scaling_group']]['remaining'] += remaining

    return per_sgroup, sgroup_remaining, agent_slots


async def get_group_resource_status(
    root_ctx, t, known_slot_types,
    group_limits=None, group_occupied=None, group_remaining=None,
):
    """
    Apply the group resource visibility setting to a group's resource status.

    Reads ``config/api/resources/group_resource_visibility`` from the shared
    config and converts it to a boolean with ``t.ToBool()``.

    :param root_ctx: Context object providing ``shared_config.get_raw()``.
    :param t: The validation module providing ``ToBool``.
    :param known_slot_types: Mapping whose keys are the slot type names.
    :param group_limits: Already computed group limits (optional).
    :param group_occupied: Already computed group occupancy (optional).
    :param group_remaining: Already computed group remainder (optional).

    :returns: A 3-tuple ``(group_limits, group_occupied, group_remaining)``.
        When visibility is enabled the given values are returned unchanged
        (``None`` for any that were not passed); when it is disabled, each
        element is a ``ResourceSlot`` with ``Decimal('NaN')`` for every known
        slot type.
    """
    raw_visibility = await root_ctx.shared_config.get_raw(
        'config/api/resources/group_resource_visibility')
    if t.ToBool().check(raw_visibility):
        # Previously this path fell through to the return statement with the
        # three names never assigned, raising UnboundLocalError.
        return group_limits, group_occupied, group_remaining

    def _nan_slots():
        # A fresh slot set per field so the three results never share state.
        return ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()})

    return _nan_slots(), _nan_slots(), _nan_slots()