Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions dask_cloudprovider/generic/vmcluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from distributed.utils import warn_on_duration, cli_keywords

from dask_cloudprovider.utils.socket import is_socket_open
from dask_cloudprovider.utils.config_helper import serialize_custom_config


class VMInterface(ProcessInterface):
Expand All @@ -31,9 +32,7 @@ def __init__(self, docker_args: str = "", extra_bootstrap: list = None, **kwargs
self.docker_args = docker_args
self.extra_bootstrap = extra_bootstrap
self.auto_shutdown = True
self.set_env = 'env DASK_INTERNAL_INHERIT_CONFIG="{}"'.format(
dask.config.serialize(dask.config.global_config)
)
self.set_env = f'env DASK_INTERNAL_INHERIT_CONFIG="{serialize_custom_config()}"'
self.kwargs = kwargs

async def create_vm(self):
Expand Down
32 changes: 32 additions & 0 deletions dask_cloudprovider/utils/config_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import copy
import dask.config


def prune_defaults(cfg: dict, defaults: dict) -> dict:
"""
Recursively remove any key in cfg whose value exactly equals
the corresponding built-in default.
"""
pruned = {}
for key, val in cfg.items():
if key not in defaults:
pruned[key] = val
else:
default_val = defaults[key]
if isinstance(val, dict) and isinstance(default_val, dict):
nested = prune_defaults(val, default_val)
if nested:
pruned[key] = nested
elif val != default_val:
pruned[key] = val
return pruned


def serialize_custom_config() -> str:
"""
Pull out only the user-overrides from global_config and serialize them.
"""
user_cfg = copy.deepcopy(dask.config.global_config)
defaults = dask.config.merge(*dask.config.defaults)
pruned = prune_defaults(user_cfg, defaults)
return dask.config.serialize(pruned)
48 changes: 48 additions & 0 deletions dask_cloudprovider/utils/tests/test_config_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import dask.config

from dask_cloudprovider.utils.config_helper import (
prune_defaults,
serialize_custom_config,
)


def test_prune_defaults_simple():
# Keys matching defaults get dropped; new keys stay
cfg = {"a": 1, "b": 2, "c": 3}
defaults = {"a": 1, "b": 0}
pruned = prune_defaults(cfg, defaults)
assert pruned == {"b": 2, "c": 3}


def test_prune_defaults_nested():
# Nested dicts: only subkeys that differ survive
cfg = {
"outer": {"keep": 41, "drop": 0},
"solo": 99,
}
defaults = {
"outer": {"keep": 42, "drop": 0},
"solo": 0,
}
pruned = prune_defaults(cfg, defaults)
# 'outer.drop' matches default, 'outer.keep' differs; 'solo' differs
assert pruned == {"outer": {"keep": 41}, "solo": 99}


def test_serialize_custom_config(monkeypatch):
# Arrange a fake global_config and defaults
fake_global = {"x": 10, "y": {"a": 1, "b": 0}}
fake_defaults = {"x": 0, "y": {"a": 1, "b": 0}}

# Monkey-patch dask.config
monkeypatch.setattr(dask.config, "global_config", fake_global)
# defaults should be a sequence of dict(s)
monkeypatch.setattr(dask.config, "defaults", (fake_defaults,))

# Serialize the custom config
serialized = serialize_custom_config()
assert isinstance(serialized, str)

# Assert it's valid JSON and only contains overrides (x and nothing under y)
pruned = dask.config.deserialize(serialized)
assert pruned == {"x": 10}
Loading