diff --git a/dask_cloudprovider/aws/ecs.py b/dask_cloudprovider/aws/ecs.py index 79e7bb0d..a3c96fc9 100644 --- a/dask_cloudprovider/aws/ecs.py +++ b/dask_cloudprovider/aws/ecs.py @@ -558,6 +558,16 @@ class ECSCluster(SpecCluster, ConfigMixin): Defaults to `None`, no extra command line arguments. worker_task_kwargs: dict (optional) Additional keyword arguments for the workers ECS task. + worker_health_check_port: int (optional) + Port for the worker health check. Defaults to `8787`. + worker_health_check_interval: int (optional) + Interval for the worker health check in seconds. Defaults to `30`. + worker_health_check_timeout: int (optional) + Timeout for the worker health check in seconds. Defaults to `5`. + worker_health_check_retries: int (optional) + Number of retries for the worker health check. Defaults to `3`. + worker_health_check_start_period: int (optional) + Start period for the worker health check in seconds. Defaults to `60`. n_workers: int (optional) Number of workers to start on cluster creation. @@ -732,6 +742,11 @@ def __init__( worker_extra_args=None, worker_task_definition_arn=None, worker_task_kwargs=None, + worker_health_check_port=8787, + worker_health_check_interval=30, + worker_health_check_timeout=5, + worker_health_check_retries=3, + worker_health_check_start_period=60, n_workers=None, workers_name_start=0, workers_name_step=1, @@ -785,6 +800,11 @@ def __init__( ) self._worker_extra_args = worker_extra_args self._worker_task_kwargs = worker_task_kwargs + self._worker_health_check_port = worker_health_check_port + self._worker_health_check_interval = worker_health_check_interval + self._worker_health_check_timeout = worker_health_check_timeout + self._worker_health_check_retries = worker_health_check_retries + self._worker_health_check_start_period = worker_health_check_start_period self._n_workers = n_workers self._workers_name_start = workers_name_start self._workers_name_step = workers_name_step @@ -859,6 +879,11 @@ async def _start( "worker_gpu", # TODO Detect whether cluster is GPU capable "worker_mem", "worker_nthreads", + "worker_health_check_port", + "worker_health_check_interval", + "worker_health_check_timeout", + "worker_health_check_retries", + "worker_health_check_start_period", "vpc", ]: self.update_attr_from_config(attr=attr, private=True) @@ -1250,6 +1275,28 @@ async def _create_worker_task_definition_arn(self): resource_requirements.append( {"type": "GPU", "value": str(self._worker_gpu)} ) + + worker_command = [ + "dask-cuda-worker" if self._worker_gpu else "dask-worker", + "--nthreads", + "{}".format( + max(int(self._worker_cpu / 1024), 1) + if self._worker_nthreads is None + else self._worker_nthreads + ), + "--memory-limit", + "{}MB".format(int(self._worker_mem)), + "--death-timeout", + "60", + ] + + # Add dashboard address if not already specified in extra_args + if not any("--dashboard-address" in s for s in (self._worker_extra_args or [])): + worker_command.extend(["--dashboard-address", f":{self._worker_health_check_port}"]) + + if self._worker_extra_args: + worker_command.extend(self._worker_extra_args) + async with self._client("ecs") as ecs: response = await ecs.register_task_definition( family="{}-{}".format(self.cluster_name, "worker"), @@ -1265,24 +1312,14 @@ async def _create_worker_task_definition_arn(self): "memoryReservation": self._worker_mem, "resourceRequirements": resource_requirements, "essential": True, - "command": [ - "dask-cuda-worker" if self._worker_gpu else "dask-worker", - "--nthreads", - "{}".format( - max(int(self._worker_cpu / 1024), 1) - if self._worker_nthreads is None - else self._worker_nthreads - ), - "--memory-limit", - "{}MB".format(int(self._worker_mem)), - "--death-timeout", - "60", - ] - + ( - list() - if not self._worker_extra_args - else self._worker_extra_args - ), + "command": worker_command, + "healthCheck": { + "command": ["CMD-SHELL", f"curl -f http://localhost:{self._worker_health_check_port}/info || exit 1"], + "interval": self._worker_health_check_interval, + "timeout": self._worker_health_check_timeout, + "retries": self._worker_health_check_retries, + "startPeriod": self._worker_health_check_start_period + }, "ulimits": [ { "name": "nofile", @@ -1367,7 +1404,7 @@ def _check_scheduler_port_config(self): def _check_scheduler_tls_config(self): scheduler_has_tls_config = any( map( - lambda arg: type(arg) == "str" and arg.startswith("--tls"), + lambda arg: isinstance(arg, str) and arg.startswith("--tls"), self._scheduler_extra_args or [], ) )