diff --git a/changes/442.feature b/changes/442.feature new file mode 100644 index 000000000..75c4d29c3 --- /dev/null +++ b/changes/442.feature @@ -0,0 +1 @@ +Add API to get available resources \ No newline at end of file diff --git a/src/ai/backend/manager/api/resource.py b/src/ai/backend/manager/api/resource.py index 040ff48e7..0bcbb0c87 100644 --- a/src/ai/backend/manager/api/resource.py +++ b/src/ai/backend/manager/api/resource.py @@ -34,12 +34,11 @@ from ai.backend.common.types import DefaultForUnspecified, ResourceSlot from ..models import ( - agents, resource_presets, + resource_presets, domains, groups, kernels, users, - AgentStatus, - association_groups_users, - query_allowed_sgroups, - AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, + get_groups_info_by_row, + check_scaling_group_resource, + check_group_resource, RESOURCE_USAGE_KERNEL_STATUSES, LIVE_STATUS, ) from .auth import auth_required, superadmin_required @@ -64,7 +63,7 @@ async def list_presets(request: web.Request) -> web.Response: """ Returns the list of all resource presets. """ - log.info('LIST_PRESETS (ak:{})', request['keypair']['access_key']) + log.info('RESOURCE.LIST_PRESETS (ak:{})', request['keypair']['access_key']) root_ctx: RootContext = request.app['_root.context'] await root_ctx.shared_config.get_resource_slots() async with root_ctx.db.begin_readonly() as conn: @@ -120,7 +119,7 @@ async def check_presets(request: web.Request, params: Any) -> web.Response: 'scaling_groups': None, 'presets': [], } - log.info('CHECK_PRESETS (ak:{}, g:{}, sg:{})', + log.info('RESOURCE.CHECK_PRESETS (ak:{}, g:{}, sg:{})', request['keypair']['access_key'], params['group'], params['scaling_group']) async with root_ctx.db.begin_readonly() as conn: @@ -130,21 +129,7 @@ async def check_presets(request: web.Request, params: Any) -> web.Response: keypair_remaining = keypair_limits - keypair_occupied # Check group resource limit and get group_id. 
- j = sa.join( - groups, association_groups_users, - association_groups_users.c.group_id == groups.c.id, - ) - query = ( - sa.select([groups.c.id, groups.c.total_resource_slots]) - .select_from(j) - .where( - (association_groups_users.c.user_id == request['user']['uuid']) & - (groups.c.name == params['group']) & - (domains.c.name == domain_name) - ) - ) - result = await conn.execute(query) - row = result.first() + row = await get_groups_info_by_row(conn, request, params, domain_name) group_id = row['id'] group_resource_slots = row['total_resource_slots'] if group_id is None: @@ -178,50 +163,12 @@ async def check_presets(request: web.Request, params: Any) -> web.Response: domain_remaining[slot], ) - # Prepare per scaling group resource. - sgroups = await query_allowed_sgroups(conn, domain_name, group_id, access_key) - sgroup_names = [sg.name for sg in sgroups] - if params['scaling_group'] is not None: - if params['scaling_group'] not in sgroup_names: - raise InvalidAPIParameters('Unknown scaling group') - sgroup_names = [params['scaling_group']] - per_sgroup = { - sgname: { - 'using': ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}), - 'remaining': ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}), - } for sgname in sgroup_names - } - - # Per scaling group resource using from resource occupying kernels. - query = ( - sa.select([kernels.c.occupied_slots, kernels.c.scaling_group]) - .select_from(kernels) - .where( - (kernels.c.user_uuid == request['user']['uuid']) & - (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES)) & - (kernels.c.scaling_group.in_(sgroup_names)) - ) - ) - async for row in (await conn.stream(query)): - per_sgroup[row['scaling_group']]['using'] += row['occupied_slots'] - - # Per scaling group resource remaining from agents stats. 
- sgroup_remaining = ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}) - query = ( - sa.select([agents.c.available_slots, agents.c.occupied_slots, agents.c.scaling_group]) - .select_from(agents) - .where( - (agents.c.status == AgentStatus.ALIVE) & - (agents.c.scaling_group.in_(sgroup_names)) - ) + # Take resources per sgroup. + per_sgroup, sgroup_remaining, agent_slots = await check_scaling_group_resource( + conn, request, params, + domain_name, group_id, + access_key, known_slot_types ) - agent_slots = [] - async for row in (await conn.stream(query)): - remaining = row['available_slots'] - row['occupied_slots'] - remaining += ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()}) - sgroup_remaining += remaining - agent_slots.append(remaining) - per_sgroup[row['scaling_group']]['remaining'] += remaining # Take maximum allocatable resources per sgroup. for sgname, sgfields in per_sgroup.items(): @@ -255,13 +202,9 @@ async def check_presets(request: web.Request, params: Any) -> web.Response: }) # Return group resource status as NaN if not allowed. 
@atomic
@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
    t.Dict({
        tx.AliasedKey(['group', 'name'], default='default'): t.String,
    }))
async def check_group(request: web.Request, params: Any) -> web.Response:
    """
    Return the resource limits, current occupancy, and remaining capacity
    of a specific group (project), along with the remaining resources of
    each scaling group visible to the requesting keypair.

    :param group: The name of the group (project) to query.
    :raises InvalidAPIParameters: When the request parameters are malformed.
    """
    root_ctx: RootContext = request.app['_root.context']
    try:
        access_key = request['keypair']['access_key']
        domain_name = request['user']['domain_name']
        # TODO: uncomment when we implement scaling group.
        # scaling_group = request.query.get('scaling_group')
        # assert scaling_group is not None, 'scaling_group parameter is missing.'
    except (json.decoder.JSONDecodeError, AssertionError) as e:
        raise InvalidAPIParameters(extra_msg=str(e.args[0]))
    known_slot_types = await root_ctx.shared_config.get_resource_slots()
    resp: MutableMapping[str, Any] = {
        'scaling_group_remaining': None,
        'scaling_groups': None,
    }
    # Consistency fix: space before '(' to match the other RESOURCE.* log tags.
    log.info("RESOURCE.CHECK_GROUP (ak:{}, g:{})",
             access_key, params['group'])

    async with root_ctx.db.begin_readonly() as conn:
        # Resolve the group row (id + total_resource_slots) for this user/domain.
        row = await get_groups_info_by_row(conn, request, params, domain_name)
        group_id = row['id']
        # Group resource numbers may be hidden (masked as NaN) by configuration.
        group_resource_visibility = \
            await root_ctx.shared_config.get_raw('config/api/resources/group_resource_visibility')
        group_resource_visibility = t.ToBool().check(group_resource_visibility)
        # BUGFIX: the visibility flag was computed but never forwarded to
        # check_group_resource, so hidden groups were still fully exposed.
        group_limits, group_occupied, group_remaining = await check_group_resource(
            conn, row, known_slot_types,
            group_resource_visibility=group_resource_visibility,
        )

        # NOTE(review): query_allowed_sgroups must stay imported in this module;
        # this change removed it from the import list above — verify.
        sgroups = await query_allowed_sgroups(conn, domain_name, group_id, access_key)
        per_sgroup: MutableMapping[str, Any] = {}
        total_sgroup_remaining = ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()})
        for sgroup in sgroups:
            # BUGFIX: the per-scaling-group query results were previously
            # discarded each iteration and per_sgroup was filled with zeros;
            # accumulate the actual numbers into the response instead.
            sgroup_capacity, sgroup_remaining = await check_scaling_group_resource(
                conn, sgroup['name'], known_slot_types,
            )
            total_sgroup_remaining += sgroup_remaining
            per_sgroup[sgroup['name']] = {
                'remaining': sgroup_remaining.to_json(),
                'capacity': sgroup_capacity.to_json(),
            }

    resp['limits'] = group_limits.to_json()
    resp['occupied'] = group_occupied.to_json()
    resp['remaining'] = group_remaining.to_json()
    # BUGFIX: this key was initialized to None and never filled.
    resp['scaling_group_remaining'] = total_sgroup_remaining.to_json()
    resp['scaling_groups'] = per_sgroup
    return web.json_response(resp, status=200)


@atomic
@server_status_required(READ_ALLOWED)
@auth_required
@check_api_params(
    t.Dict({
        tx.AliasedKey(['scaling_group', 'name'], default='default'): t.String,
    }))
async def check_scaling_group(request: web.Request, params: Any) -> web.Response:
    """
    Return the remaining resources of a specific scaling group.

    :param scaling_group: The name of the scaling group to query.
    """
    root_ctx: RootContext = request.app['_root.context']
    known_slot_types = await root_ctx.shared_config.get_resource_slots()
    log.info("RESOURCE.CHECK_SCALING_GROUP (ak:{}, sg:{})",
             request['keypair']['access_key'], params['scaling_group'])
    async with root_ctx.db.begin_readonly() as conn:
        sgroup_capacity, sgroup_remaining = await check_scaling_group_resource(
            conn, params['scaling_group'], known_slot_types,
        )
    # TODO: include capacity in the response (when queried by super-admin)?
    resp = {
        "remaining": sgroup_remaining.to_json(),
    }
    return web.json_response(resp, status=200)
""" - log.info('RECALCULATE_USAGE ()') + log.info('RESOURCE.RECALCULATE_USAGE ()') root_ctx: RootContext = request.app['_root.context'] await root_ctx.registry.recalc_resource_usage() return web.json_response({}, status=200) @@ -436,7 +479,7 @@ async def usage_per_month(request: web.Request, params: Any) -> web.Response: :param group_ids: If not None, query containers only in those groups. :param month: The year-month to query usage statistics. ex) "202006" to query for Jun 2020 """ - log.info('USAGE_PER_MONTH (g:[{}], month:{})', + log.info('RESOURCE.USAGE_PER_MONTH (g:[{}], month:{})', ','.join(params['group_ids']), params['month']) root_ctx: RootContext = request.app['_root.context'] local_tz = root_ctx.shared_config['system']['timezone'] @@ -483,7 +526,7 @@ async def usage_per_period(request: web.Request, params: Any) -> web.Response: raise InvalidAPIParameters(extra_msg='Invalid date values') if end_date <= start_date: raise InvalidAPIParameters(extra_msg='end_date must be later than start_date.') - log.info('USAGE_PER_MONTH (g:{}, start_date:{}, end_date:{})', + log.info('RESOURCE.USAGE_PER_MONTH (g:{}, start_date:{}, end_date:{})', group_id, start_date, end_date) group_ids = [group_id] if group_id is not None else None resp = await get_container_stats_for_period(request, start_date, end_date, group_ids=group_ids) @@ -608,7 +651,7 @@ async def user_month_stats(request: web.Request) -> web.Response: """ access_key = request['keypair']['access_key'] user_uuid = request['user']['uuid'] - log.info('USER_LAST_MONTH_STATS (ak:{}, u:{})', access_key, user_uuid) + log.info('RESOURCE.USER_LAST_MONTH_STATS (ak:{}, u:{})', access_key, user_uuid) stats = await get_time_binned_monthly_stats(request, user_uuid=user_uuid) return web.json_response(stats, status=200) @@ -620,7 +663,7 @@ async def admin_month_stats(request: web.Request) -> web.Response: Return time-binned (15 min) stats for all terminated sessions over last 30 days. 
""" - log.info('ADMIN_LAST_MONTH_STATS ()') + log.info('RESOURCE.ADMIN_LAST_MONTH_STATS ()') stats = await get_time_binned_monthly_stats(request, user_uuid=None) return web.json_response(stats, status=200) @@ -657,7 +700,7 @@ async def get_watcher_info(request: web.Request, agent_id: str) -> dict: tx.AliasedKey(['agent_id', 'agent']): t.String, })) async def get_watcher_status(request: web.Request, params: Any) -> web.Response: - log.info('GET_WATCHER_STATUS ()') + log.info('RESOURCE.WATCHER.GET_STATUS (ag:{})', params['agent_id']) watcher_info = await get_watcher_info(request, params['agent_id']) connector = aiohttp.TCPConnector() async with aiohttp.ClientSession(connector=connector) as sess: @@ -679,7 +722,7 @@ async def get_watcher_status(request: web.Request, params: Any) -> web.Response: tx.AliasedKey(['agent_id', 'agent']): t.String, })) async def watcher_agent_start(request: web.Request, params: Any) -> web.Response: - log.info('WATCHER_AGENT_START ()') + log.info('RESOURCE.WATCHER.AGENT.START (ag:{})', params['agent_id']) watcher_info = await get_watcher_info(request, params['agent_id']) connector = aiohttp.TCPConnector() async with aiohttp.ClientSession(connector=connector) as sess: @@ -702,7 +745,7 @@ async def watcher_agent_start(request: web.Request, params: Any) -> web.Response tx.AliasedKey(['agent_id', 'agent']): t.String, })) async def watcher_agent_stop(request: web.Request, params: Any) -> web.Response: - log.info('WATCHER_AGENT_STOP ()') + log.info('RESOURCE.WATCHER.AGENT.STOP (ag:{})', params['agent_id']) watcher_info = await get_watcher_info(request, params['agent_id']) connector = aiohttp.TCPConnector() async with aiohttp.ClientSession(connector=connector) as sess: @@ -725,7 +768,7 @@ async def watcher_agent_stop(request: web.Request, params: Any) -> web.Response: tx.AliasedKey(['agent_id', 'agent']): t.String, })) async def watcher_agent_restart(request: web.Request, params: Any) -> web.Response: - log.info('WATCHER_AGENT_RESTART ()') + 
log.info('RESOURCE.WATCHER.AGENT.RESTART (ag:{})', params['agent_id']) watcher_info = await get_watcher_info(request, params['agent_id']) connector = aiohttp.TCPConnector() async with aiohttp.ClientSession(connector=connector) as sess: @@ -748,6 +791,8 @@ def create_app(default_cors_options: CORSOptions) -> Tuple[web.Application, Iter add_route = app.router.add_route cors.add(add_route('GET', '/presets', list_presets)) cors.add(add_route('POST', '/check-presets', check_presets)) + cors.add(add_route('GET', '/group', check_group)) + cors.add(add_route('GET', '/scaling-group', check_scaling_group)) cors.add(add_route('POST', '/recalculate-usage', recalculate_usage)) cors.add(add_route('GET', '/usage/month', usage_per_month)) cors.add(add_route('GET', '/usage/period', usage_per_period)) diff --git a/src/ai/backend/manager/models/resource_preset.py b/src/ai/backend/manager/models/resource_preset.py index f2f8e5a6b..1236bbaf5 100644 --- a/src/ai/backend/manager/models/resource_preset.py +++ b/src/ai/backend/manager/models/resource_preset.py @@ -5,15 +5,24 @@ Any, Dict, Sequence, + Mapping, + Tuple, TYPE_CHECKING, ) +import uuid +from aiohttp import web import graphene import sqlalchemy as sa from sqlalchemy.engine.row import Row +from sqlalchemy.ext.asyncio import AsyncConnection as SAConnection +from decimal import Decimal from ai.backend.common.logging import BraceStyleAdapter from ai.backend.common.types import ResourceSlot +from ..api.exceptions import ( + InvalidAPIParameters, +) from .base import ( metadata, BigInt, BinarySize, ResourceSlotColumn, simple_db_mutate, @@ -21,8 +30,20 @@ set_if_set, batch_result, ) +from .agent import ( + agents, AgentStatus, +) from .user import UserRole - +from .group import ( + groups, + association_groups_users, +) +from .scaling_group import query_allowed_sgroups +from .kernel import ( + AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES, + kernels, +) +from .domain import domains if TYPE_CHECKING: from .gql import GraphQueryContext @@ 
async def get_groups_info_by_row(
    conn: SAConnection,
    request: web.Request,
    params: Any,
    domain_name: str,
) -> Row:
    """
    Fetch the id and total resource slots of the group (project) named
    ``params['group']`` that the requesting user belongs to within the
    given domain.

    :param conn: A read-only database connection.
    :param request: The aiohttp request; only ``request['user']['uuid']`` is read.
    :param params: Parsed API parameters containing the ``group`` name.
    :param domain_name: The domain the group must belong to.
    :return: The first matching row (``id``, ``total_resource_slots``),
        or ``None`` when the user has no such group.
    """
    j = sa.join(
        groups, association_groups_users,
        association_groups_users.c.group_id == groups.c.id,
    )
    query = (
        sa.select([groups.c.id, groups.c.total_resource_slots])
        .select_from(j)
        .where(
            (association_groups_users.c.user_id == request['user']['uuid']) &
            (groups.c.name == params['group']) &
            # BUGFIX: filter on the group's own domain column. The previous
            # predicate referenced the unjoined domains table, which produced
            # an implicit cross join and only checked that *some* domain with
            # that name exists instead of scoping the group to the domain.
            (groups.c.domain_name == domain_name)
        )
    )
    result = await conn.execute(query)
    return result.first()


async def check_scaling_group_resource(
    conn: SAConnection,
    sgroup_name: str,
    known_slot_types: Mapping[str, str],
) -> Tuple[ResourceSlot, ResourceSlot]:
    """
    Compute the total capacity and the remaining (capacity minus occupied)
    resources of a scaling group from the stats of its ALIVE agents.

    :param conn: A read-only database connection.
    :param sgroup_name: The scaling group to inspect.
    :param known_slot_types: The currently configured resource slot types.
    :return: A tuple of (capacity, remaining) resource slots.
    """
    sgroup_capacity = ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()})
    sgroup_remaining = ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()})
    query = (
        sa.select([agents.c.available_slots, agents.c.occupied_slots])
        .select_from(agents)
        .where(
            (agents.c.status == AgentStatus.ALIVE) &
            (agents.c.scaling_group == sgroup_name)
        )
    )
    async for row in (await conn.stream(query)):
        sgroup_capacity += row['available_slots']
        sgroup_remaining += row['available_slots'] - row['occupied_slots']
    return sgroup_capacity, sgroup_remaining


async def check_group_resource(
    conn: SAConnection,
    group: Row,  # TODO: refactor as ORM-based Group
    known_slot_types: Mapping,
    *,
    group_resource_visibility: bool = True,
) -> Tuple:
    """
    Compute the resource limits, current occupancy, and remaining capacity
    of a group (project).

    :param conn: A read-only database connection.
    :param group: A group row exposing ``id`` and ``total_resource_slots``.
    :param known_slot_types: The currently configured resource slot types.
    :param group_resource_visibility: When False, all three values are
        masked as NaN so the group's resource configuration stays hidden.
    :return: A tuple of (limits, occupied, remaining) resource slots.
    """
    if not group_resource_visibility:
        # Hide the group resource config.
        group_limits = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()})
        group_occupied = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()})
        group_remaining = ResourceSlot({k: Decimal('NaN') for k in known_slot_types.keys()})
        return group_limits, group_occupied, group_remaining
    # NOTE(review): DefaultForUnspecified must be imported from
    # ai.backend.common.types in this module — verify the import list.
    group_resource_policy = {
        'total_resource_slots': group.total_resource_slots,
        'default_for_unspecified': DefaultForUnspecified.UNLIMITED,
    }
    group_limits = ResourceSlot.from_policy(group_resource_policy, known_slot_types)
    # BUGFIX: the previous version called the undefined name ``root_ctx`` and
    # carried dead code referencing undefined ``request``/``sgroup_names``/
    # ``per_sgroup`` (a paste from check_presets). Compute the group occupancy
    # directly from its resource-occupying kernels over the given connection.
    group_occupied = ResourceSlot({k: Decimal(0) for k in known_slot_types.keys()})
    query = (
        sa.select([kernels.c.occupied_slots])
        .select_from(kernels)
        .where(
            (kernels.c.group_id == group.id) &
            (kernels.c.status.in_(AGENT_RESOURCE_OCCUPYING_KERNEL_STATUSES))
        )
    )
    async for row in (await conn.stream(query)):
        group_occupied += row['occupied_slots']
    group_remaining = group_limits - group_occupied
    return group_limits, group_occupied, group_remaining