1 change: 1 addition & 0 deletions .env-devel
@@ -50,6 +50,7 @@ CLUSTERS_KEEPER_COMPUTATIONAL_BACKEND_DOCKER_IMAGE_TAG=master-github-latest
CLUSTERS_KEEPER_DASK_NTHREADS=0
CLUSTERS_KEEPER_DASK_WORKER_SATURATION=inf
CLUSTERS_KEEPER_EC2_ACCESS=null
CLUSTERS_KEEPER_SSM_ACCESS=null
CLUSTERS_KEEPER_EC2_INSTANCES_PREFIX=""
CLUSTERS_KEEPER_LOGLEVEL=WARNING
CLUSTERS_KEEPER_MAX_MISSED_HEARTBEATS_BEFORE_CLUSTER_TERMINATION=5
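The new `CLUSTERS_KEEPER_SSM_ACCESS=null` entry follows the same convention as `CLUSTERS_KEEPER_EC2_ACCESS`: the variable maps to a nested settings model, so pydantic decodes the environment value as JSON and the literal `null` keeps SSM access disabled. A minimal sketch of that behaviour, assuming pydantic v1 and a placeholder model instead of the real `SSMSettings`:

```python
import os

from pydantic import BaseModel, BaseSettings


class _DemoSSMSettings(BaseModel):
    SSM_ENDPOINT: str | None = None


class _DemoAppSettings(BaseSettings):
    # nested-settings fields are decoded from the environment as JSON
    CLUSTERS_KEEPER_SSM_ACCESS: _DemoSSMSettings | None = None


os.environ["CLUSTERS_KEEPER_SSM_ACCESS"] = "null"
assert _DemoAppSettings().CLUSTERS_KEEPER_SSM_ACCESS is None  # SSM access stays disabled
```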
3 changes: 3 additions & 0 deletions packages/models-library/src/models_library/clusters.py
@@ -96,6 +96,9 @@ class Config(BaseAuthentication.Config):
class NoAuthentication(BaseAuthentication):
type: Literal["none"] = "none"

class Config(BaseAuthentication.Config):
schema_extra: ClassVar[dict[str, Any]] = {"examples": [{"type": "none"}]}


class TLSAuthentication(BaseAuthentication):
type: Literal["tls"] = "tls"
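The `schema_extra` example added to `NoAuthentication` matches what the other authentication models already expose, so it can be round-tripped through the model like the existing ones. A hypothetical check (not part of this diff):

```python
from models_library.clusters import NoAuthentication


def test_no_authentication_schema_example_is_valid() -> None:
    for example in NoAuthentication.Config.schema_extra["examples"]:
        model = NoAuthentication.parse_obj(example)  # pydantic v1 API
        assert model.type == "none"
```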
@@ -21,7 +21,7 @@
@router.get("/", include_in_schema=True, response_class=PlainTextResponse)
async def health_check():
# NOTE: sync url in docker/healthcheck.py with this entrypoint!
return f"{__name__}.health_check@{datetime.datetime.now(datetime.timezone.utc).isoformat()}"
return f"{__name__}.health_check@{datetime.datetime.now(datetime.UTC).isoformat()}"


class _ComponentStatus(BaseModel):
@@ -33,25 +33,34 @@ class _StatusGet(BaseModel):
rabbitmq: _ComponentStatus
ec2: _ComponentStatus
redis_client_sdk: _ComponentStatus
ssm: _ComponentStatus


@router.get("/status", include_in_schema=True, response_model=_StatusGet)
async def get_status(app: Annotated[FastAPI, Depends(get_app)]) -> _StatusGet:
return _StatusGet(
rabbitmq=_ComponentStatus(
is_enabled=is_rabbitmq_enabled(app),
is_responsive=await get_rabbitmq_client(app).ping()
if is_rabbitmq_enabled(app)
else False,
is_responsive=(
await get_rabbitmq_client(app).ping()
if is_rabbitmq_enabled(app)
else False
),
),
ec2=_ComponentStatus(
is_enabled=bool(app.state.ec2_client),
is_responsive=await app.state.ec2_client.ping()
if app.state.ec2_client
else False,
is_responsive=(
await app.state.ec2_client.ping() if app.state.ec2_client else False
),
),
redis_client_sdk=_ComponentStatus(
is_enabled=bool(app.state.redis_client_sdk),
is_responsive=await get_redis_client(app).ping(),
),
ssm=_ComponentStatus(
is_enabled=(app.state.ssm_client is not None),
is_responsive=(
await app.state.ssm_client.ping() if app.state.ssm_client else False
),
),
)
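With the new `ssm` entry, the `/status` payload reports four components instead of three. A usage sketch of the extended endpoint; the host and port are assumptions for illustration only:

```python
import asyncio

import httpx


async def show_status() -> None:
    # base URL is a placeholder, not taken from this diff
    async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
        response = await client.get("/status")
        response.raise_for_status()
        status = response.json()
        # expected keys: rabbitmq, ec2, redis_client_sdk and the new ssm entry
        print(status["ssm"])  # e.g. {'is_enabled': True, 'is_responsive': True}


asyncio.run(show_status())
```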
@@ -0,0 +1,15 @@
from typing import Final

from aws_library.ec2._models import AWSTagKey, AWSTagValue
from pydantic import parse_obj_as

DOCKER_STACK_DEPLOY_COMMAND_NAME: Final[str] = "private cluster docker deploy"
DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY: Final[AWSTagKey] = parse_obj_as(
AWSTagKey, "io.simcore.clusters-keeper.private_cluster_docker_deploy"
)

USER_ID_TAG_KEY: Final[AWSTagKey] = parse_obj_as(AWSTagKey, "user_id")
WALLET_ID_TAG_KEY: Final[AWSTagKey] = parse_obj_as(AWSTagKey, "wallet_id")
ROLE_TAG_KEY: Final[AWSTagKey] = parse_obj_as(AWSTagKey, "role")
WORKER_ROLE_TAG_VALUE: Final[AWSTagValue] = parse_obj_as(AWSTagValue, "worker")
MANAGER_ROLE_TAG_VALUE: Final[AWSTagValue] = parse_obj_as(AWSTagValue, "manager")
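These constants centralize the tag keys and values that previously appeared as inline literals in `create_cluster` (see the removed `AWSTagKey("user_id")` block further down). A small sketch of how they can be combined into a worker tag set; the import path of the new constants module is assumed, since the file path is not shown in this excerpt:

```python
from aws_library.ec2._models import AWSTagKey, AWSTagValue
from pydantic import parse_obj_as

# assumed module path for the new constants file
from simcore_service_clusters_keeper.constants import (
    ROLE_TAG_KEY,
    USER_ID_TAG_KEY,
    WALLET_ID_TAG_KEY,
    WORKER_ROLE_TAG_VALUE,
)


def worker_tags(user_id: int, wallet_id: int | None) -> dict[AWSTagKey, AWSTagValue]:
    # builds the tag set attached to worker EC2 instances
    return {
        USER_ID_TAG_KEY: parse_obj_as(AWSTagValue, f"{user_id}"),
        WALLET_ID_TAG_KEY: parse_obj_as(AWSTagValue, f"{wallet_id}"),
        ROLE_TAG_KEY: WORKER_ROLE_TAG_VALUE,
    }
```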
@@ -19,6 +19,7 @@
from ..modules.ec2 import setup as setup_ec2
from ..modules.rabbitmq import setup as setup_rabbitmq
from ..modules.redis import setup as setup_redis
from ..modules.ssm import setup as setup_ssm
from ..rpc.rpc_routes import setup_rpc_routes
from .settings import ApplicationSettings

@@ -55,6 +56,7 @@ def create_app(settings: ApplicationSettings) -> FastAPI:
setup_rabbitmq(app)
setup_rpc_routes(app)
setup_ec2(app)
setup_ssm(app)
setup_redis(app)
setup_clusters_management(app)

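`setup_ssm` is wired in alongside the other modules; its implementation is not part of this excerpt, but the `/status` handler above implies a client kept on `app.state.ssm_client` that stays `None` when `CLUSTERS_KEEPER_SSM_ACCESS` is disabled. A rough sketch of that startup/shutdown pattern, with a stand-in client class instead of the real wrapper:

```python
from dataclasses import dataclass

from fastapi import FastAPI


@dataclass
class _StubSSMClient:
    async def ping(self) -> bool:
        return True

    async def close(self) -> None:
        ...


def setup_ssm_sketch(app: FastAPI) -> None:
    async def on_startup() -> None:
        settings = getattr(app.state, "settings", None)
        ssm_settings = getattr(settings, "CLUSTERS_KEEPER_SSM_ACCESS", None)
        # stays None when SSM access is disabled, mirroring app.state.ec2_client
        app.state.ssm_client = _StubSSMClient() if ssm_settings else None

    async def on_shutdown() -> None:
        if app.state.ssm_client:
            await app.state.ssm_client.close()

    app.add_event_handler("startup", on_startup)
    app.add_event_handler("shutdown", on_shutdown)
```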
@@ -25,6 +25,7 @@
from settings_library.ec2 import EC2Settings
from settings_library.rabbit import RabbitSettings
from settings_library.redis import RedisSettings
from settings_library.ssm import SSMSettings
from settings_library.tracing import TracingSettings
from settings_library.utils_logging import MixinLoggingSettings
from types_aiobotocore_ec2.literals import InstanceTypeType
@@ -50,6 +51,21 @@ class Config(EC2Settings.Config):
}


class ClustersKeeperSSMSettings(SSMSettings):
class Config(SSMSettings.Config):
env_prefix = CLUSTERS_KEEPER_ENV_PREFIX

schema_extra: ClassVar[dict[str, Any]] = { # type: ignore[misc]
"examples": [
{
f"{CLUSTERS_KEEPER_ENV_PREFIX}{key}": var
for key, var in example.items()
}
for example in SSMSettings.Config.schema_extra["examples"]
],
}


class WorkersEC2InstancesSettings(BaseCustomSettings):
WORKERS_EC2_INSTANCES_ALLOWED_TYPES: dict[str, EC2InstanceBootSpecific] = Field(
...,
@@ -183,6 +199,12 @@ class PrimaryEC2InstancesSettings(BaseCustomSettings):
"that take longer than this time will be terminated as sometimes it happens that EC2 machine fail on start.",
)

PRIMARY_EC2_INSTANCES_DOCKER_DEFAULT_ADDRESS_POOL: str = Field(
default="172.20.0.0/14",
description="defines the docker swarm default address pool in CIDR format "
"(see https://docs.docker.com/reference/cli/docker/swarm/init/)",
)

@validator("PRIMARY_EC2_INSTANCES_ALLOWED_TYPES")
@classmethod
def check_valid_instance_names(
@@ -250,6 +272,10 @@ class ApplicationSettings(BaseCustomSettings, MixinLoggingSettings):
auto_default_from_env=True
)

CLUSTERS_KEEPER_SSM_ACCESS: ClustersKeeperSSMSettings | None = Field(
auto_default_from_env=True
)

CLUSTERS_KEEPER_PRIMARY_EC2_INSTANCES: PrimaryEC2InstancesSettings | None = Field(
auto_default_from_env=True
)
@@ -285,9 +311,11 @@ class ApplicationSettings(BaseCustomSettings, MixinLoggingSettings):
"(default to seconds, or see https://pydantic-docs.helpmanual.io/usage/types/#datetime-types for string formating)",
)

CLUSTERS_KEEPER_MAX_MISSED_HEARTBEATS_BEFORE_CLUSTER_TERMINATION: NonNegativeInt = Field(
default=5,
description="Max number of missed heartbeats before a cluster is terminated",
CLUSTERS_KEEPER_MAX_MISSED_HEARTBEATS_BEFORE_CLUSTER_TERMINATION: NonNegativeInt = (
Field(
default=5,
description="Max number of missed heartbeats before a cluster is terminated",
)
)

CLUSTERS_KEEPER_COMPUTATIONAL_BACKEND_DOCKER_IMAGE_TAG: str = Field(
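`ClustersKeeperSSMSettings` reuses the EC2 pattern above: a subclass whose `Config.env_prefix` rewrites every `SSM_*` variable to `CLUSTERS_KEEPER_SSM_*`, with the schema examples re-prefixed the same way. A minimal illustration of the prefixing, assuming pydantic v1 and an illustrative field rather than the real `SSMSettings` ones:

```python
import os

from pydantic import BaseSettings

CLUSTERS_KEEPER_ENV_PREFIX = "CLUSTERS_KEEPER_"


class _DemoSSMSettings(BaseSettings):
    SSM_ENDPOINT: str = "https://ssm.us-east-1.amazonaws.com"


class _DemoClustersKeeperSSMSettings(_DemoSSMSettings):
    class Config(_DemoSSMSettings.Config):
        env_prefix = CLUSTERS_KEEPER_ENV_PREFIX


os.environ["CLUSTERS_KEEPER_SSM_ENDPOINT"] = "http://localhost:4566"
# the prefixed subclass reads CLUSTERS_KEEPER_SSM_ENDPOINT; the base class would read SSM_ENDPOINT
assert _DemoClustersKeeperSSMSettings().SSM_ENDPOINT == "http://localhost:4566"
```

The new `PRIMARY_EC2_INSTANCES_DOCKER_DEFAULT_ADDRESS_POOL` value is presumably what ends up in `docker swarm init --default-addr-pool` on the primary machine, which is the option the linked Docker documentation describes.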
@@ -49,7 +49,7 @@ async def _get_primary_ec2_params(
ec2_instance_types: list[
EC2InstanceType
] = await ec2_client.get_ec2_instance_capabilities(
instance_type_names=[ec2_type_name]
instance_type_names={ec2_type_name}
)
assert ec2_instance_types # nosec
assert len(ec2_instance_types) == 1 # nosec
@@ -72,15 +72,7 @@ async def create_cluster(
tags=creation_ec2_tags(app_settings, user_id=user_id, wallet_id=wallet_id),
startup_script=create_startup_script(
app_settings,
cluster_machines_name_prefix=get_cluster_name(
app_settings, user_id=user_id, wallet_id=wallet_id, is_manager=False
),
ec2_boot_specific=ec2_instance_boot_specs,
additional_custom_tags={
AWSTagKey("user_id"): AWSTagValue(f"{user_id}"),
AWSTagKey("wallet_id"): AWSTagValue(f"{wallet_id}"),
AWSTagKey("role"): AWSTagValue("worker"),
},
),
ami_id=ec2_instance_boot_specs.ami_id,
key_name=app_settings.CLUSTERS_KEEPER_PRIMARY_EC2_INSTANCES.PRIMARY_EC2_INSTANCES_KEY_NAME,
@@ -5,22 +5,40 @@

import arrow
from aws_library.ec2 import AWSTagKey, EC2InstanceData
from aws_library.ec2._models import AWSTagValue
from fastapi import FastAPI
from models_library.users import UserID
from models_library.wallets import WalletID
from pydantic import parse_obj_as
from servicelib.logging_utils import log_catch

from servicelib.utils import limited_gather

from ..constants import (
DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY,
DOCKER_STACK_DEPLOY_COMMAND_NAME,
ROLE_TAG_KEY,
USER_ID_TAG_KEY,
WALLET_ID_TAG_KEY,
WORKER_ROLE_TAG_VALUE,
)
from ..core.settings import get_application_settings
from ..modules.clusters import (
delete_clusters,
get_all_clusters,
get_cluster_workers,
set_instance_heartbeat,
)
from ..utils.clusters import create_deploy_cluster_stack_script
from ..utils.dask import get_scheduler_auth, get_scheduler_url
from ..utils.ec2 import HEARTBEAT_TAG_KEY
from ..utils.ec2 import (
HEARTBEAT_TAG_KEY,
get_cluster_name,
user_id_from_instance_tags,
wallet_id_from_instance_tags,
)
from .dask import is_scheduler_busy, ping_scheduler
from .ec2 import get_ec2_client
from .ssm import get_ssm_client

_logger = logging.getLogger(__name__)

@@ -42,8 +60,8 @@ def _get_instance_last_heartbeat(instance: EC2InstanceData) -> datetime.datetime
async def _get_all_associated_worker_instances(
app: FastAPI,
primary_instances: Iterable[EC2InstanceData],
) -> list[EC2InstanceData]:
worker_instances = []
) -> set[EC2InstanceData]:
worker_instances: set[EC2InstanceData] = set()
for instance in primary_instances:
assert "user_id" in instance.tags # nosec
user_id = UserID(instance.tags[_USER_ID_TAG_KEY])
@@ -55,20 +73,20 @@ async def _get_all_associated_worker_instances(
else None
)

worker_instances.extend(
worker_instances.update(
await get_cluster_workers(app, user_id=user_id, wallet_id=wallet_id)
)
return worker_instances


async def _find_terminateable_instances(
app: FastAPI, instances: Iterable[EC2InstanceData]
) -> list[EC2InstanceData]:
) -> set[EC2InstanceData]:
app_settings = get_application_settings(app)
assert app_settings.CLUSTERS_KEEPER_PRIMARY_EC2_INSTANCES # nosec

# get the corresponding ec2 instance data
terminateable_instances: list[EC2InstanceData] = []
terminateable_instances: set[EC2InstanceData] = set()

time_to_wait_before_termination = (
app_settings.CLUSTERS_KEEPER_MAX_MISSED_HEARTBEATS_BEFORE_CLUSTER_TERMINATION
@@ -82,7 +100,7 @@ async def _find_terminateable_instances(
elapsed_time_since_heartbeat = arrow.utcnow().datetime - last_heartbeat
allowed_time_to_wait = time_to_wait_before_termination
if elapsed_time_since_heartbeat >= allowed_time_to_wait:
terminateable_instances.append(instance)
terminateable_instances.add(instance)
else:
_logger.info(
"%s has still %ss before being terminateable",
@@ -93,14 +111,14 @@ async def _find_terminateable_instances(
elapsed_time_since_startup = arrow.utcnow().datetime - instance.launch_time
allowed_time_to_wait = startup_delay
if elapsed_time_since_startup >= allowed_time_to_wait:
terminateable_instances.append(instance)
terminateable_instances.add(instance)

# get all terminateable instances' associated worker instances
worker_instances = await _get_all_associated_worker_instances(
app, terminateable_instances
)

return terminateable_instances + worker_instances
return terminateable_instances.union(worker_instances)


async def check_clusters(app: FastAPI) -> None:
@@ -112,6 +130,7 @@ async def check_clusters(app: FastAPI) -> None:
if await ping_scheduler(get_scheduler_url(instance), get_scheduler_auth(app))
}

# set instance heartbeat if scheduler is busy
for instance in connected_intances:
with log_catch(_logger, reraise=False):
# NOTE: some connected instance could in theory break between these 2 calls, therefore this is silenced and will
@@ -124,6 +143,7 @@ async def check_clusters(app: FastAPI) -> None:
f"{instance.id=} for {instance.tags=}",
)
await set_instance_heartbeat(app, instance=instance)
# clean any cluster that is not doing anything
if terminateable_instances := await _find_terminateable_instances(
app, connected_intances
):
@@ -138,7 +158,7 @@ async def check_clusters(app: FastAPI) -> None:
for instance in disconnected_instances
if _get_instance_last_heartbeat(instance) is None
}

# remove instances that were starting for too long
if terminateable_instances := await _find_terminateable_instances(
app, starting_instances
):
@@ -149,7 +169,72 @@ async def check_clusters(app: FastAPI) -> None:
)
await delete_clusters(app, instances=terminateable_instances)

# the other instances are broken (they were at some point connected but now not anymore)
# NOTE: transmit command to start docker swarm/stack if needed
# once the instance is connected to the SSM server,
# use ssm client to send the command to these instances,
# we send a command that contains:
# the docker-compose file in binary,
# the call to init the docker swarm and the call to deploy the stack
instances_in_need_of_deployment = {
i
for i in starting_instances - terminateable_instances
if DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY not in i.tags
}

if instances_in_need_of_deployment:
app_settings = get_application_settings(app)
ssm_client = get_ssm_client(app)
ec2_client = get_ec2_client(app)
instances_in_need_of_deployment_ssm_connection_state = await limited_gather(
*[
ssm_client.is_instance_connected_to_ssm_server(i.id)
for i in instances_in_need_of_deployment
],
reraise=False,
log=_logger,
limit=20,
)
ec2_connected_to_ssm_server = [
i
for i, c in zip(
instances_in_need_of_deployment,
instances_in_need_of_deployment_ssm_connection_state,
strict=True,
)
if c is True
]
started_instances_ready_for_command = ec2_connected_to_ssm_server
if started_instances_ready_for_command:
# we need to send 1 command per machine here, as the user_id/wallet_id changes
for i in started_instances_ready_for_command:
ssm_command = await ssm_client.send_command(
[i.id],
command=create_deploy_cluster_stack_script(
app_settings,
cluster_machines_name_prefix=get_cluster_name(
app_settings,
user_id=user_id_from_instance_tags(i.tags),
wallet_id=wallet_id_from_instance_tags(i.tags),
is_manager=False,
),
additional_custom_tags={
USER_ID_TAG_KEY: i.tags[USER_ID_TAG_KEY],
WALLET_ID_TAG_KEY: i.tags[WALLET_ID_TAG_KEY],
ROLE_TAG_KEY: WORKER_ROLE_TAG_VALUE,
},
),
command_name=DOCKER_STACK_DEPLOY_COMMAND_NAME,
)
await ec2_client.set_instances_tags(
started_instances_ready_for_command,
tags={
DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY: AWSTagValue(
ssm_command.command_id
),
},
)

# the remaining instances are broken (they were at some point connected but now not anymore)
broken_instances = disconnected_instances - starting_instances
if terminateable_instances := await _find_terminateable_instances(
app, broken_instances
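The deployment branch in `check_clusters` relies on two SSM operations: checking that an instance's SSM agent has registered with the SSM service, and sending the generated `docker swarm init`/`docker stack deploy` script as a run command. The async SSM client behind `get_ssm_client` (with `is_instance_connected_to_ssm_server` and `send_command`) is not shown in this excerpt; the sketch below illustrates the underlying AWS calls with plain boto3 (the region and the use of the stock `AWS-RunShellScript` document are assumptions of the sketch):

```python
import boto3

ssm = boto3.client("ssm", region_name="us-east-1")  # region is a placeholder


def is_instance_connected(instance_id: str) -> bool:
    # an instance only appears here once its SSM agent has registered with the service
    info = ssm.describe_instance_information(
        Filters=[{"Key": "InstanceIds", "Values": [instance_id]}]
    )
    return any(
        entry["PingStatus"] == "Online" for entry in info["InstanceInformationList"]
    )


def send_shell_script(instance_id: str, script: str, command_name: str) -> str:
    # AWS-RunShellScript is the stock SSM document for running shell commands
    response = ssm.send_command(
        InstanceIds=[instance_id],
        DocumentName="AWS-RunShellScript",
        Comment=command_name,
        Parameters={"commands": script.splitlines()},
    )
    return response["Command"]["CommandId"]
```

Tagging the instance with the returned command id (via `DOCKER_STACK_DEPLOY_COMMAND_EC2_TAG_KEY`) is what keeps the same machine from being picked up again on the next `check_clusters` iteration.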