Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 56 additions & 19 deletions dask_cloudprovider/aws/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,16 @@ class ECSCluster(SpecCluster, ConfigMixin):
Defaults to `None`, no extra command line arguments.
worker_task_kwargs: dict (optional)
Additional keyword arguments for the workers ECS task.
worker_health_check_port: int (optional)
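Port probed by the worker container health check (run with `curl` inside the container). The worker dashboard is bound to this port unless `worker_extra_args` already sets `--dashboard-address`. Defaults to `8787`.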
worker_health_check_interval: int (optional)
Interval for the worker health check in seconds. Defaults to `30`.
worker_health_check_timeout: int (optional)
Timeout for the worker health check in seconds. Defaults to `5`.
worker_health_check_retries: int (optional)
Number of times a failed worker health check is retried before the container is considered unhealthy. Defaults to `3`.
worker_health_check_start_period: int (optional)
Grace period in seconds after the worker container starts during which failed health checks do not count towards the retry limit. Defaults to `60`.
n_workers: int (optional)
Number of workers to start on cluster creation.

Expand Down Expand Up @@ -732,6 +742,11 @@ def __init__(
worker_extra_args=None,
worker_task_definition_arn=None,
worker_task_kwargs=None,
worker_health_check_port=8787,
worker_health_check_interval=30,
worker_health_check_timeout=5,
worker_health_check_retries=3,
worker_health_check_start_period=60,
n_workers=None,
workers_name_start=0,
workers_name_step=1,
Expand Down Expand Up @@ -785,6 +800,11 @@ def __init__(
)
self._worker_extra_args = worker_extra_args
self._worker_task_kwargs = worker_task_kwargs
self._worker_health_check_port = worker_health_check_port
self._worker_health_check_interval = worker_health_check_interval
self._worker_health_check_timeout = worker_health_check_timeout
self._worker_health_check_retries = worker_health_check_retries
self._worker_health_check_start_period = worker_health_check_start_period
self._n_workers = n_workers
self._workers_name_start = workers_name_start
self._workers_name_step = workers_name_step
Expand Down Expand Up @@ -859,6 +879,11 @@ async def _start(
"worker_gpu", # TODO Detect whether cluster is GPU capable
"worker_mem",
"worker_nthreads",
"worker_health_check_port",
"worker_health_check_interval",
"worker_health_check_timeout",
"worker_health_check_retries",
"worker_health_check_start_period",
"vpc",
]:
self.update_attr_from_config(attr=attr, private=True)
Expand Down Expand Up @@ -1250,6 +1275,28 @@ async def _create_worker_task_definition_arn(self):
resource_requirements.append(
{"type": "GPU", "value": str(self._worker_gpu)}
)

worker_command = [
"dask-cuda-worker" if self._worker_gpu else "dask-worker",
"--nthreads",
"{}".format(
max(int(self._worker_cpu / 1024), 1)
if self._worker_nthreads is None
else self._worker_nthreads
),
"--memory-limit",
"{}MB".format(int(self._worker_mem)),
"--death-timeout",
"60",
]

# Add dashboard address if not already specified in extra_args
if not any("--dashboard-address" in s for s in (self._worker_extra_args or [])):
worker_command.extend(["--dashboard-address", f":{self._worker_health_check_port}"])

if self._worker_extra_args:
worker_command.extend(self._worker_extra_args)

async with self._client("ecs") as ecs:
response = await ecs.register_task_definition(
family="{}-{}".format(self.cluster_name, "worker"),
Expand All @@ -1265,24 +1312,14 @@ async def _create_worker_task_definition_arn(self):
"memoryReservation": self._worker_mem,
"resourceRequirements": resource_requirements,
"essential": True,
"command": [
"dask-cuda-worker" if self._worker_gpu else "dask-worker",
"--nthreads",
"{}".format(
max(int(self._worker_cpu / 1024), 1)
if self._worker_nthreads is None
else self._worker_nthreads
),
"--memory-limit",
"{}MB".format(int(self._worker_mem)),
"--death-timeout",
"60",
]
+ (
list()
if not self._worker_extra_args
else self._worker_extra_args
),
"command": worker_command,
"healthCheck": {
"command": ["CMD-SHELL", f"curl -f http://localhost:{self._worker_health_check_port}/info || exit 1"],
"interval": self._worker_health_check_interval,
"timeout": self._worker_health_check_timeout,
"retries": self._worker_health_check_retries,
"startPeriod": self._worker_health_check_start_period
},
"ulimits": [
{
"name": "nofile",
Expand Down Expand Up @@ -1367,7 +1404,7 @@ def _check_scheduler_port_config(self):
def _check_scheduler_tls_config(self):
scheduler_has_tls_config = any(
map(
lambda arg: type(arg) == "str" and arg.startswith("--tls"),
lambda arg: isinstance(arg, str) and arg.startswith("--tls"),
self._scheduler_extra_args or [],
)
)
Expand Down
Loading