@@ -129,6 +129,7 @@ def main(
def summary(
user_id: Annotated[int, typer.Option(help="filters by the user ID")] = 0,
wallet_id: Annotated[int, typer.Option(help="filters by the wallet ID")] = 0,
as_json: Annotated[bool, typer.Option(help="outputs as json")] = False,
) -> None:
"""Show a summary of the current situation of autoscaled EC2 instances.

@@ -140,7 +141,9 @@ def summary(

"""

if not asyncio.run(api.summary(state, user_id or None, wallet_id or None)):
if not asyncio.run(
api.summary(state, user_id or None, wallet_id or None, output_json=as_json)
):
raise typer.Exit(1)


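The new `--as-json` flag changes only the output format; the exit-code contract is unchanged. A minimal sketch of exercising it with Typer's test runner, assuming the app above is exported as `app` from a module named `autoscaled_monitor.cli` (hypothetical path) and ignoring whatever global options the `main()` callback may require:

    from typer.testing import CliRunner

    from autoscaled_monitor.cli import app  # assumed module path

    runner = CliRunner()

    # --as-json switches the report to machine-readable output; the command
    # still exits with code 1 when api.summary() reports services in error.
    result = runner.invoke(app, ["summary", "--user-id", "42", "--as-json"])
    print(result.exit_code, result.output[:200])
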
@@ -152,7 +155,7 @@ def cancel_jobs(
typer.Option(help="the wallet ID"),
] = None,
*,
force: Annotated[
abort_in_db: Annotated[
bool,
typer.Option(
help="will also force the job to abort in the database (use only if job is in WAITING FOR CLUSTER/WAITING FOR RESOURCE)"
@@ -166,23 +169,26 @@ def cancel_jobs(
Keyword Arguments:
user_id -- the user ID
wallet_id -- the wallet ID
abort_in_db -- will also force the job to abort in the database (use only if job is in WAITING FOR CLUSTER/WAITING FOR RESOURCE)
"""
asyncio.run(api.cancel_jobs(state, user_id, wallet_id, force=force))
asyncio.run(api.cancel_jobs(state, user_id, wallet_id, abort_in_db=abort_in_db))


@app.command()
def trigger_cluster_termination(
user_id: Annotated[int, typer.Option(help="the user ID")],
wallet_id: Annotated[int, typer.Option(help="the wallet ID")],
force: Annotated[bool, typer.Option(help="will not ask for confirmation")] = False,
) -> None:
"""this will set the Heartbeat tag on the primary machine to 1 hour, thus ensuring the
clusters-keeper will properly terminate that cluster.

Keyword Arguments:
user_id -- the user ID
wallet_id -- the wallet ID
force -- will not ask for confirmation (VERY RISKY! USE WITH CAUTION!)
"""
asyncio.run(api.trigger_cluster_termination(state, user_id, wallet_id))
asyncio.run(api.trigger_cluster_termination(state, user_id, wallet_id, force=force))


@app.command()
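Worth noting for callers: Typer derives option names from parameter names, so renaming `force` to `abort_in_db` changes the CLI surface of `cancel-jobs` from `--force` to `--abort-in-db`, while `trigger-cluster-termination` gains a separate `--force` that skips the confirmation prompt. A standalone toy showing the derivation (not the project's code):

    import typer
    from typing import Annotated

    toy = typer.Typer()

    @toy.command()
    def cancel_jobs(
        abort_in_db: Annotated[bool, typer.Option(help="abort in DB too")] = False,
    ) -> None:
        # Typer exposes this parameter as --abort-in-db / --no-abort-in-db
        typer.echo(f"{abort_in_db=}")

    if __name__ == "__main__":
        toy()
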
scripts/maintenance/computational-clusters/autoscaled_monitor/core.py (160 additions & 60 deletions)
@@ -413,7 +413,61 @@ async def _parse_dynamic_instances(
return dynamic_instances


async def summary(state: AppState, user_id: int | None, wallet_id: int | None) -> bool:
def _print_summary_as_json(
dynamic_instances: list[DynamicInstance],
computational_clusters: list[ComputationalCluster],
) -> None:
result = {
"dynamic_instances": [
{
"name": instance.name,
"ec2_instance_id": instance.ec2_instance.instance_id,
"running_services": [
{
"user_id": service.user_id,
"project_id": service.project_id,
"node_id": service.node_id,
"service_name": service.service_name,
"service_version": service.service_version,
"created_at": service.created_at.isoformat(),
"needs_manual_intervention": service.needs_manual_intervention,
}
for service in instance.running_services
],
"disk_space": instance.disk_space.human_readable(),
}
for instance in dynamic_instances
],
"computational_clusters": [
{
"primary": {
"name": cluster.primary.name,
"ec2_instance_id": cluster.primary.ec2_instance.instance_id,
"user_id": cluster.primary.user_id,
"wallet_id": cluster.primary.wallet_id,
"disk_space": cluster.primary.disk_space.human_readable(),
"last_heartbeat": cluster.primary.last_heartbeat.isoformat(),
},
"workers": [
{
"name": worker.name,
"ec2_instance_id": worker.ec2_instance.instance_id,
"disk_space": worker.disk_space.human_readable(),
}
for worker in cluster.workers
],
"datasets": cluster.datasets,
"tasks": cluster.task_states_to_tasks,
}
for cluster in computational_clusters
],
}
rich.print_json(json.dumps(result))


async def summary(
state: AppState, user_id: int | None, wallet_id: int | None, *, output_json: bool
) -> bool:
# get all the running instances
assert state.ec2_resource_autoscaling
dynamic_instances = await ec2.list_dynamic_instances_from_ec2(
@@ -422,19 +476,6 @@ async def summary(state: AppState, user_id: int | None, wallet_id: int | None) -> bool:
dynamic_autoscaled_instances = await _parse_dynamic_instances(
state, dynamic_instances, state.ssh_key_path, user_id, wallet_id
)
_print_dynamic_instances(
dynamic_autoscaled_instances,
state.environment,
state.ec2_resource_autoscaling.meta.client.meta.region_name,
)

time_threshold = arrow.utcnow().shift(minutes=-30).datetime

dynamic_services_in_error = any(
service.needs_manual_intervention and service.created_at < time_threshold
for instance in dynamic_autoscaled_instances
for service in instance.running_services
)

assert state.ec2_resource_clusters_keeper
computational_instances = await ec2.list_computational_instances_from_ec2(
@@ -443,10 +484,27 @@ async def summary(state: AppState, user_id: int | None, wallet_id: int | None) -> bool:
computational_clusters = await _parse_computational_clusters(
state, computational_instances, state.ssh_key_path, user_id, wallet_id
)
_print_computational_clusters(
computational_clusters,
state.environment,
state.ec2_resource_clusters_keeper.meta.client.meta.region_name,

if output_json:
_print_summary_as_json(dynamic_autoscaled_instances, computational_clusters)

if not output_json:
_print_dynamic_instances(
dynamic_autoscaled_instances,
state.environment,
state.ec2_resource_autoscaling.meta.client.meta.region_name,
)
_print_computational_clusters(
computational_clusters,
state.environment,
state.ec2_resource_clusters_keeper.meta.client.meta.region_name,
)

time_threshold = arrow.utcnow().shift(minutes=-30).datetime
dynamic_services_in_error = any(
service.needs_manual_intervention and service.created_at < time_threshold
for instance in dynamic_autoscaled_instances
for service in instance.running_services
)

return not dynamic_services_in_error
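The JSON document's field names follow `_print_summary_as_json` above, which makes the summary scriptable. A sketch of a consumer; the `python -m autoscaled_monitor` invocation is a guess at the entry point, so adapt it to how the script is actually launched:

    import json
    import subprocess

    out = subprocess.run(
        ["python", "-m", "autoscaled_monitor", "summary", "--as-json"],  # hypothetical
        capture_output=True,
        text=True,
        check=False,  # exit code 1 signals services needing manual intervention
    )
    summary = json.loads(out.stdout)
    for instance in summary["dynamic_instances"]:
        stuck = [
            s
            for s in instance["running_services"]
            if s["needs_manual_intervention"]
        ]
        if stuck:
            print(f"{instance['name']} ({instance['ec2_instance_id']}): {len(stuck)} stuck")
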
@@ -504,29 +562,48 @@ async def _list_computational_clusters(
)


async def cancel_jobs( # noqa: C901, PLR0912
state: AppState, user_id: int, wallet_id: int | None, *, force: bool
async def _cancel_all_jobs(
state: AppState,
the_cluster: ComputationalCluster,
*,
task_to_dask_job: list[tuple[ComputationalTask | None, DaskTask | None]],
abort_in_db: bool,
) -> None:
# get the theory
computational_tasks = await db.list_computational_tasks_from_db(state, user_id)
rich.print("cancelling all tasks")
for comp_task, dask_task in task_to_dask_job:
if dask_task is not None and dask_task.state != "unknown":
await dask.trigger_job_cancellation_in_scheduler(
state,
the_cluster,
dask_task.job_id,
)
if comp_task is None:
# we need to clear it of the cluster
await dask.remove_job_from_scheduler(
state,
the_cluster,
dask_task.job_id,
)
if comp_task is not None and abort_in_db:
await db.abort_job_in_db(state, comp_task.project_id, comp_task.node_id)

# get the reality
computational_clusters = await _list_computational_clusters(
state, user_id, wallet_id
)
job_id_to_dask_state: dict[TaskId, TaskState] = {}
if computational_clusters:
assert (
len(computational_clusters) == 1
), "too many clusters found! TIP: fix this code or something weird is playing out"
rich.print("cancelled all tasks")

the_cluster = computational_clusters[0]
rich.print(f"{the_cluster.task_states_to_tasks=}")

for job_state, job_ids in the_cluster.task_states_to_tasks.items():
for job_id in job_ids:
job_id_to_dask_state[job_id] = job_state
async def _get_job_id_to_dask_state_from_cluster(
cluster: ComputationalCluster,
) -> dict[TaskId, TaskState]:
job_id_to_dask_state: dict[TaskId, TaskState] = {}
for job_state, job_ids in cluster.task_states_to_tasks.items():
for job_id in job_ids:
job_id_to_dask_state[job_id] = job_state
return job_id_to_dask_state


async def _get_db_task_to_dask_job(
computational_tasks: list[ComputationalTask],
job_id_to_dask_state: dict[TaskId, TaskState],
) -> list[tuple[ComputationalTask | None, DaskTask | None]]:
task_to_dask_job: list[tuple[ComputationalTask | None, DaskTask | None]] = []
for task in computational_tasks:
dask_task = None
@@ -539,6 +616,32 @@ async def cancel_jobs( # noqa: C901, PLR0912
# keep the jobs still in the cluster
for job_id, dask_state in job_id_to_dask_state.items():
task_to_dask_job.append((None, DaskTask(job_id=job_id, state=dask_state)))
return task_to_dask_job


async def cancel_jobs( # noqa: C901, PLR0912
state: AppState, user_id: int, wallet_id: int | None, *, abort_in_db: bool
) -> None:
# get the theory
computational_tasks = await db.list_computational_tasks_from_db(state, user_id)

# get the reality
computational_clusters = await _list_computational_clusters(
state, user_id, wallet_id
)

if computational_clusters:
assert (
len(computational_clusters) == 1
), "too many clusters found! TIP: fix this code or something weird is playing out"

the_cluster = computational_clusters[0]
rich.print(f"{the_cluster.task_states_to_tasks=}")

job_id_to_dask_state = await _get_job_id_to_dask_state_from_cluster(the_cluster)
task_to_dask_job: list[tuple[ComputationalTask | None, DaskTask | None]] = (
await _get_db_task_to_dask_job(computational_tasks, job_id_to_dask_state)
)

if not task_to_dask_job:
rich.print("[red]nothing found![/red]")
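The refactor extracts helpers without changing behaviour: the database is "the theory", the dask scheduler "the reality", and `_get_db_task_to_dask_job` pairs them, keeping scheduler-only jobs with no DB counterpart. A toy illustration with plain strings (the exact matching step is collapsed in this diff, so the job-id lookup below is an assumption):

    db_job_ids = ["j1", "j2"]  # the theory: tasks recorded in the database
    cluster_states = {"j2": "running", "j3": "queued"}  # the reality: the scheduler

    pairs: list[tuple[str | None, tuple[str, str] | None]] = []
    for job_id in db_job_ids:
        state = cluster_states.pop(job_id, None)
        pairs.append((job_id, (job_id, state) if state else None))
    # jobs the scheduler still knows but the DB does not are kept with no DB side
    for job_id, state in cluster_states.items():
        pairs.append((None, (job_id, state)))

    print(pairs)
    # [('j1', None), ('j2', ('j2', 'running')), (None, ('j3', 'queued'))]
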
@@ -554,27 +657,12 @@ async def cancel_jobs( # noqa: C901, PLR0912
if response == "none":
rich.print("[yellow]not cancelling anything[/yellow]")
elif response == "all":
rich.print("cancelling all tasks")
for comp_task, dask_task in task_to_dask_job:
if dask_task is not None and dask_task.state != "unknown":
await dask.trigger_job_cancellation_in_scheduler(
state,
the_cluster,
dask_task.job_id,
)
if comp_task is None:
# we need to clear it of the cluster
await dask.remove_job_from_scheduler(
state,
the_cluster,
dask_task.job_id,
)
if comp_task is not None and force:
await db.abort_job_in_db(
state, comp_task.project_id, comp_task.node_id
)

rich.print("cancelled all tasks")
await _cancel_all_jobs(
state,
the_cluster,
task_to_dask_job=task_to_dask_job,
abort_in_db=abort_in_db,
)
else:
try:
# Split the response and handle ranges
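The selection parser behind the comment above is collapsed in this diff; a plausible reconstruction of "split the response and handle ranges" (an assumption, not the project's code) would be:

    def parse_selection(response: str) -> set[int]:
        # accepts e.g. "1-3,5" and expands ranges into individual indices
        selected: set[int] = set()
        for part in response.split(","):
            if "-" in part:
                start, end = (int(x) for x in part.split("-"))
                selected.update(range(start, end + 1))
            else:
                selected.add(int(part))
        return selected

    assert parse_selection("1-3,5") == {1, 2, 3, 5}
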
@@ -597,7 +685,7 @@ async def cancel_jobs( # noqa: C901, PLR0912
state, the_cluster, dask_task.job_id
)

if comp_task is not None and force:
if comp_task is not None and abort_in_db:
await db.abort_job_in_db(
state, comp_task.project_id, comp_task.node_id
)
@@ -616,7 +704,7 @@


async def trigger_cluster_termination(
state: AppState, user_id: int, wallet_id: int
state: AppState, user_id: int, wallet_id: int, *, force: bool
) -> None:
assert state.ec2_resource_clusters_keeper
computational_instances = await ec2.list_computational_instances_from_ec2(
@@ -635,8 +723,20 @@ async def trigger_cluster_termination(
state.environment,
state.ec2_resource_clusters_keeper.meta.client.meta.region_name,
)
if typer.confirm("Are you sure you want to trigger termination of that cluster?"):
if (force is True) or typer.confirm(
"Are you sure you want to trigger termination of that cluster?"
):
the_cluster = computational_clusters[0]

computational_tasks = await db.list_computational_tasks_from_db(state, user_id)
job_id_to_dask_state = await _get_job_id_to_dask_state_from_cluster(the_cluster)
task_to_dask_job: list[tuple[ComputationalTask | None, DaskTask | None]] = (
await _get_db_task_to_dask_job(computational_tasks, job_id_to_dask_state)
)
await _cancel_all_jobs(
state, the_cluster, task_to_dask_job=task_to_dask_job, abort_in_db=force
)

new_heartbeat_tag: TagTypeDef = {
"Key": "last_heartbeat",
"Value": f"{arrow.utcnow().datetime - datetime.timedelta(hours=1)}",
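For context on the termination mechanism: back-dating the `last_heartbeat` tag by one hour makes clusters-keeper consider the primary stale and terminate the cluster. A minimal boto3 sketch of the same tag write (the instance id is illustrative):

    import datetime

    import arrow
    import boto3

    ec2 = boto3.resource("ec2")
    primary = ec2.Instance("i-0123456789abcdef0")  # illustrative primary machine
    primary.create_tags(
        Tags=[
            {
                "Key": "last_heartbeat",
                # one hour in the past, so the keeper sees the cluster as stale
                "Value": f"{arrow.utcnow().datetime - datetime.timedelta(hours=1)}",
            }
        ]
    )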