Skip to content

Commit c343d35

Browse files
[Openstack Cloud Provider] Implemented worker_command and worker_threads for OpenStackCluster object (#463)
* added worker_command parameter to override the command that workers run when starting * fixed examples/OpenstackCluster-scorepredict.ipynb * added the new configuration parameters to cloudprovider.yaml * we are not using json.
1 parent db8f20a commit c343d35

File tree

3 files changed

+68
-2
lines changed

3 files changed

+68
-2
lines changed

dask_cloudprovider/cloudprovider.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ cloudprovider:
160160
external_network_id: null # The ID of the external network used for assigning floating IPs. List available external networks using: `openstack network list --external`
161161
create_floating_ip: false # Specifies whether to assign a floating IP to each instance, enabling external access. Set to `True` if external connectivity is needed.
162162
docker_image: "daskdev/dask:latest" # docker image to use
163+
worker_threads: 2 # The number of threads to use on each worker.
164+
worker_command: null # str (optional) The command workers should run when starting. For example, ``dask-cuda-worker`` on GPU-enabled instances.
165+
163166

164167
nebius:
165168
token: null # iam token for interacting with the Nebius AI Cloud

dask_cloudprovider/openstack/instances.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ def __init__(
3232
docker_image: str = None,
3333
env_vars: str = None,
3434
extra_bootstrap: str = None,
35+
worker_threads: int = None,
36+
worker_command: str = None,
3537
**kwargs,
3638
):
3739
super().__init__(**kwargs)
@@ -42,6 +44,8 @@ def __init__(
4244
self.size = size
4345
self.image = image
4446
self.env_vars = env_vars
47+
self.worker_threads = worker_threads
48+
self.worker_command = worker_command
4549
self.bootstrap = True
4650
self.docker_image = docker_image
4751
self.extra_bootstrap = extra_bootstrap
@@ -226,6 +230,44 @@ async def start_scheduler(self):
226230
class OpenStackWorker(WorkerMixin, OpenStackInstance):
227231
"""Worker running on an OpenStack instance."""
228232

233+
def __init__(
234+
self,
235+
scheduler: str,
236+
*args,
237+
worker_module: str = None,
238+
worker_class: str = None,
239+
worker_options: dict = {},
240+
**kwargs,
241+
):
242+
super().__init__(
243+
scheduler=scheduler,
244+
*args,
245+
worker_module=worker_module,
246+
worker_class=worker_class,
247+
worker_options=worker_options,
248+
**kwargs,
249+
)
250+
251+
# Select scheduler address (external or internal)
252+
if self.config.get("create_floating_ip", True):
253+
scheduler_ip = self.cluster.scheduler_external_ip
254+
else:
255+
scheduler_ip = self.cluster.scheduler_internal_ip
256+
scheduler_address = f"{self.cluster.protocol}://{scheduler_ip}:{self.cluster.scheduler_port}"
257+
258+
# If user provides worker_command, override the start of the command
259+
if self.worker_command:
260+
# This is only for custom worker_command overrides
261+
cmd = (
262+
self.worker_command if isinstance(self.worker_command, list)
263+
else self.worker_command.split()
264+
)
265+
self.command = " ".join([self.set_env] + cmd + [scheduler_address])
266+
267+
async def start(self):
268+
self.cluster._log(f"Creating worker instance {self.name}")
269+
await self.create_vm()
270+
self.status = Status.running
229271

230272
class OpenStackCluster(VMCluster):
231273
"""Cluster running on Openstack VM Instances
@@ -298,6 +340,21 @@ class OpenStackCluster(VMCluster):
298340
Params to be passed to the worker class.
299341
See :class:`distributed.worker.Worker` for default worker class.
300342
If you set ``worker_module`` then refer to the docstring for the custom worker class.
343+
worker_threads: int
344+
The number of threads to use on each worker.
345+
worker_command : str or list of str (optional)
346+
The command workers should run when starting. By default this will be
347+
``python -m distributed.cli.dask_spec``, but you can override it—for example, to
348+
``dask-cuda-worker`` on GPU-enabled instances.
349+
350+
Example
351+
-------
352+
353+
worker_command=[
354+
"dask worker",
355+
"--nthreads", "4",
356+
"--memory-limit", "16GB",
357+
]
301358
scheduler_options: dict
302359
Params to be passed to the scheduler class.
303360
See :class:`distributed.scheduler.Scheduler`.
@@ -355,6 +412,8 @@ def __init__(
355412
docker_image: str = None,
356413
debug: bool = False,
357414
bootstrap: bool = True,
415+
worker_threads: int = 2,
416+
worker_command: str = None,
358417
**kwargs,
359418
):
360419
self.config = dask.config.get("cloudprovider.openstack", {})
@@ -364,13 +423,17 @@ def __init__(
364423
self.bootstrap = (
365424
bootstrap if bootstrap is not None else self.config.get("bootstrap")
366425
)
426+
self.worker_threads = worker_threads or self.config.get("worker_threads")
427+
self.worker_command = worker_command or self.config.get("worker_command")
367428
self.options = {
368429
"cluster": self,
369430
"config": self.config,
370431
"region": region if region is not None else self.config.get("region"),
371432
"size": size if size is not None else self.config.get("size"),
372433
"image": image if image is not None else self.config.get("image"),
373434
"docker_image": docker_image or self.config.get("docker_image"),
435+
"worker_command": self.worker_command,
436+
"worker_threads": self.worker_threads,
374437
}
375438
self.scheduler_options = {**self.options}
376439
self.worker_options = {**self.options}

examples/OpenstackCluster-scorepredict.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,14 @@
5959
},
6060
{
6161
"cell_type": "code",
62-
"execution_count": 49,
62+
"execution_count": null,
6363
"id": "5b745083-bb26-4fe9-a4f6-f7d049628a79",
6464
"metadata": {},
6565
"outputs": [],
6666
"source": [
6767
"import dask\n",
6868
"import dask_cloudprovider\n",
69-
"from instances import OpenStackCluster\n",
69+
"from dask_cloudprovider.openstack import OpenStackCluster\n",
7070
"from dask.distributed import Client, progress"
7171
]
7272
},

0 commit comments

Comments
 (0)