Skip to content
This repository was archived by the owner on Aug 2, 2023. It is now read-only.
1 change: 1 addition & 0 deletions changes/442.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add API to get available resources
109 changes: 109 additions & 0 deletions src/ai/backend/manager/api/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,114 @@ async def check_presets(request: web.Request, params: Any) -> web.Response:
return web.json_response(resp, status=200)


@atomic
@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
t.Dict({
t.Key('scaling_group', default=None): t.Null | t.String,
t.Key('group', default='default'): t.String,
}))
async def get_available_resources(request: web.Request, params: Any) -> web.Response:
"""
Returns the list of specific group's available resources.

:param scaling_group: If not None, get available resources of specific scaling group.
:param group: Get available resources of specific group (project) and enumerate them.
"""
root_ctx: RootContext = request.app['_root.context']
try:
access_key = request['keypair']['access_key']
domain_name = request['user']['domain_name']
# TODO: uncomment when we implement scaling group.
# scaling_group = request.query.get('scaling_group')
# assert scaling_group is not None, 'scaling_group parameter is missing.'
except (json.decoder.JSONDecodeError, AssertionError) as e:
raise InvalidAPIParameters(extra_msg=str(e.args[0]))
known_slot_types = await root_ctx.shared_config.get_resource_slots()
resp: MutableMapping[str, Any] = {
'scaling_group_remaining': None,
'scaling_groups': None,
}
log.info('GET_AVAILABLE_RESOURCES (ak:{}, g:{}, sg:{})',
request['keypair']['access_key'], params['group'], params['scaling_group'])

async with root_ctx.db.begin_readonly() as conn:
# Check group resource limit and get group_id.
j = sa.join(
groups, association_groups_users,
association_groups_users.c.group_id == groups.c.id,
)
query = (
sa.select([groups.c.id, groups.c.total_resource_slots])
.select_from(j)
.where(
(association_groups_users.c.user_id == request['user']['uuid']) &
(groups.c.name == params['group']) &
(domains.c.name == domain_name)
)
)
result = await conn.execute(query)
row = result.first()
group_id = row['id']
if group_id is None:
raise InvalidAPIParameters('Unknown user group')

# Prepare per scaling group resource.
sgroups = await query_allowed_sgroups(conn, domain_name, group_id, access_key)
sgroup_names = [sg.name for sg in sgroups]
if params['scaling_group'] is not None:
if params['scaling_group'] not in sgroup_names:
raise InvalidAPIParameters('Unknown scaling group')
sgroup_names = [params['scaling_group']]
per_sgroup = {
sgname: {
'using': ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}),
'remaining': ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}),
} for sgname in sgroup_names
}

# Per scaling group resource using from resource occupying kernels.
query = (
sa.select([kernels.c.occupied_slots, kernels.c.scaling_group])
.select_from(kernels)
.where(
(kernels.c.user_uuid == request['user']['uuid']) &
(kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) &
(kernels.c.scaling_group.in_(sgroup_names))
)
)
async for row in (await conn.stream(query)):
per_sgroup[row['scaling_group']]['using'] += row['occupied_slots']

# Per scaling group resource remaining from agents stats.
sgroup_remaining = ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()})
query = (
sa.select([agents.c.available_slots, agents.c.occupied_slots, agents.c.scaling_group])
.select_from(agents)
.where(
(agents.c.status == AgentStatus.ALIVE) &
(agents.c.scaling_group.in_(sgroup_names))
)
)
agent_slots = []
async for row in (await conn.stream(query)):
remaining = row['available_slots'] - row['occupied_slots']
remaining += ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()})
sgroup_remaining += remaining
agent_slots.append(remaining)
per_sgroup[row['scaling_group']]['remaining'] += remaining

# Take maximum allocatable resources per sgroup.
for sgname, sgfields in per_sgroup.items():
for rtype, slots in sgfields.items():
per_sgroup[sgname][rtype] = slots.to_json() # type: ignore # it's serialization

resp['scaling_group_remaining'] = sgroup_remaining.to_json()
resp['scaling_groups'] = per_sgroup
return web.json_response(resp, status=200)


@server_status_required(READ_ALLOWED)
@superadmin_required
@atomic
Expand Down Expand Up @@ -748,6 +856,7 @@ def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iter
add_route = app.router.add_route
cors.add(add_route('GET', '/presets', list_presets))
cors.add(add_route('POST', '/check-presets', check_presets))
cors.add(add_route('GET', '/get-resources', get_available_resources))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's rename the URL to /available-resources since the HTTP method name "GET" already indicates its a query API.

cors.add(add_route('POST', '/recalculate-usage', recalculate_usage))
cors.add(add_route('GET', '/usage/month', usage_per_month))
cors.add(add_route('GET', '/usage/period', usage_per_period))
Expand Down