Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions mlos_bench/mlos_bench/environments/base_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tunables.tunable_types import TunableValue
from mlos_bench.util import instantiate_from_config, merge_parameters
from mlos_bench.util import instantiate_from_config, merge_parameters, sanitize_config

if TYPE_CHECKING:
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
Expand Down Expand Up @@ -174,7 +174,11 @@ def __init__( # pylint: disable=too-many-arguments
_LOG.debug("Parameters for '%s' :: %s", name, self._params)

if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2))
_LOG.debug(
"Config for: '%s'\n%s",
name,
json.dumps(sanitize_config(self.config), indent=2),
)

def _validate_json_config(self, config: dict, name: str) -> None:
"""Reconstructs a basic json config that this class might have been instantiated
Expand Down
4 changes: 2 additions & 2 deletions mlos_bench/mlos_bench/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from mlos_bench.storage.base_storage import Storage
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tunables.tunable_types import TunableValue
from mlos_bench.util import try_parse_val
from mlos_bench.util import sanitize_config, try_parse_val

_LOG_LEVEL = logging.INFO
_LOG_FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s"
Expand Down Expand Up @@ -478,7 +478,7 @@ def _try_parse_extra_args(cmdline: Iterable[str]) -> dict[str, TunableValue]:
# other CLI options to use as common python/json variable replacements.
config = {k.replace("-", "_"): v for k, v in config.items()}

_LOG.debug("Parsed config: %s", config)
_LOG.debug("Parsed config: %s", sanitize_config(config))
return config

def _load_config(
Expand Down
20 changes: 16 additions & 4 deletions mlos_bench/mlos_bench/services/base_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.services.types.bound_method import BoundMethod
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
from mlos_bench.util import instantiate_from_config
from mlos_bench.util import instantiate_from_config, sanitize_config

_LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -99,9 +99,21 @@ def __init__(
self._config_loader_service = parent

if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Service: %s Config:\n%s", self, json.dumps(self.config, indent=2))
_LOG.debug("Service: %s Globals:\n%s", self, json.dumps(global_config or {}, indent=2))
_LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None)
_LOG.debug(
"Service: %s Config:\n%s",
self,
json.dumps(sanitize_config(self.config), indent=2),
)
_LOG.debug(
"Service: %s Globals:\n%s",
self,
json.dumps(sanitize_config(global_config or {}), indent=2),
)
_LOG.debug(
"Service: %s Parent: %s",
self,
parent.pprint() if parent else None,
)

@staticmethod
def merge_methods(
Expand Down
10 changes: 6 additions & 4 deletions mlos_bench/mlos_bench/services/config_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,11 @@
if any(c in json for c in ("{", "[")):
# If the path contains braces, it is likely already a json string,
# so just parse it.
_LOG.info("Load config from json string: %s", json)
_LOG.info("Load config from json string: %s", sanitize_config(json))
try:
config: Any = json5.loads(json)
except ValueError as ex:
_LOG.error("Failed to parse config from JSON string: %s", json)
_LOG.error("Failed to parse config from JSON string: %s", sanitize_config(json))
raise ValueError(f"Failed to parse config from JSON string: {json}") from ex
else:
json = self.resolve_path(json)
Expand Down Expand Up @@ -225,7 +225,7 @@
# (e.g. Azure ARM templates).
del config["$schema"]
else:
_LOG.warning("Config %s is not validated against a schema.", json)
_LOG.warning("Config %s is not validated against a schema.", sanitize_config(json))
return config # type: ignore[no-any-return]

def prepare_class_load(
Expand Down Expand Up @@ -707,7 +707,9 @@
--------
mlos_bench.services : Examples of service configurations.
"""
_LOG.info("Load services: %s parent: %s", jsons, parent.__class__.__name__)
_LOG.info(
"Load services: %s parent: %s", sanitize_config(jsons), parent.__class__.__name__
)
service = Service({}, global_config, parent)
for json in jsons:
config = self.load_config(json, ConfigSchema.SERVICE)
Expand Down
6 changes: 3 additions & 3 deletions mlos_bench/mlos_bench/storage/base_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from mlos_bench.services.base_service import Service
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import get_git_info
from mlos_bench.util import get_git_info, sanitize_config

_LOG = logging.getLogger(__name__)

Expand All @@ -62,7 +62,7 @@ def __init__(
config : dict
Free-format key/value pairs of configuration parameters.
"""
_LOG.debug("Storage config: %s", config)
_LOG.debug("Storage config: %s", sanitize_config(config))
self._validate_json_config(config)
self._service = service
self._config = config.copy()
Expand Down Expand Up @@ -431,7 +431,7 @@ def new_trial(
_config = DictTemplater(config).expand_vars()
assert isinstance(_config, dict)
except ValueError as e:
_LOG.error("Non-serializable config: %s", config, exc_info=e)
_LOG.error("Non-serializable config: %s", sanitize_config(config), exc_info=e)
raise e
return self._new_trial(tunables, ts_start, config)

Expand Down
45 changes: 45 additions & 0 deletions mlos_bench/mlos_bench/tests/test_sanitize_confs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def test_sanitize_config_simple() -> None:
"other": 42,
}
sanitized = sanitize_config(config)
assert isinstance(sanitized, dict)
assert sanitized["username"] == "user1"
assert sanitized["password"] == "[REDACTED]"
assert sanitized["token"] == "[REDACTED]"
Expand All @@ -39,6 +40,7 @@ def test_sanitize_config_nested() -> None:
"api_key": "key",
}
sanitized = sanitize_config(config)
assert isinstance(sanitized, dict)
assert sanitized["outer"]["password"] == "[REDACTED]"
assert sanitized["outer"]["inner"]["token"] == "[REDACTED]"
assert sanitized["outer"]["inner"]["foo"] == "bar"
Expand All @@ -61,7 +63,50 @@ def test_sanitize_config_mixed_types() -> None:
"api_key": {"nested": "val"},
}
sanitized = sanitize_config(config)
assert isinstance(sanitized, dict)
assert sanitized["password"] == "[REDACTED]"
assert sanitized["token"] == "[REDACTED]"
assert sanitized["secret"] == "[REDACTED]"
assert sanitized["api_key"] == "[REDACTED]"


def test_sanitize_config_empty() -> None:
"""Test sanitization of an empty configuration."""
config: dict = {}
sanitized = sanitize_config(config)
assert sanitized == config # Should remain empty dictionary


def test_sanitize_array() -> None:
"""Test sanitization of an array with sensitive keys."""
config = [
{"username": "user1", "password": "pass1"},
{"username": "user2", "password": "pass2"},
]
sanitized = sanitize_config(config)
assert isinstance(sanitized, list)
assert len(sanitized) == 2
assert sanitized[0]["username"] == "user1"
assert sanitized[0]["password"] == "[REDACTED]"
assert sanitized[1]["username"] == "user2"
assert sanitized[1]["password"] == "[REDACTED]"


def test_sanitize_config_with_non_string_values() -> None:
"""Test sanitization with non-string values."""
config = {
"int_value": 42,
"float_value": 3.14,
"bool_value": True,
"none_value": None,
"list_value": [1, "password", 3],
"dict_value": {"key": "value"},
}
sanitized = sanitize_config(config)
assert isinstance(sanitized, dict)
assert sanitized["int_value"] == 42
assert sanitized["float_value"] == 3.14
assert sanitized["bool_value"] is True
assert sanitized["none_value"] is None
assert sanitized["list_value"] == [1, "password", 3] # don't redact raw strings
assert sanitized["dict_value"] == {"key": "value"}
51 changes: 38 additions & 13 deletions mlos_bench/mlos_bench/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def datetime_parser(
return new_datetime_col


def sanitize_config(config: dict[str, Any]) -> dict[str, Any]:
def sanitize_config(config: dict[str, Any] | list[Any] | Any) -> dict[str, Any] | list[Any] | Any:
"""
Sanitize a configuration dictionary by obfuscating potentially sensitive keys.

Expand All @@ -480,16 +480,41 @@ def sanitize_config(config: dict[str, Any]) -> dict[str, Any]:
"""
sanitize_keys = {"password", "secret", "token", "api_key"}

def recursive_sanitize(conf: dict[str, Any]) -> dict[str, Any]:
# Try and parse the config as a JSON string first, if it's a string.
was_json = False
if isinstance(config, str):
try:
config = json.loads(config)
was_json = True
except json.JSONDecodeError:
# If it fails to parse, return the original string.
return config

def recursive_sanitize(
conf: dict[str, Any] | list[Any] | str,
) -> dict[str, Any] | list[Any] | str:
"""Recursively sanitize a dictionary."""
sanitized = {}
for k, v in conf.items():
if k in sanitize_keys:
sanitized[k] = "[REDACTED]"
elif isinstance(v, dict):
sanitized[k] = recursive_sanitize(v) # type: ignore[assignment]
else:
sanitized[k] = v
return sanitized

return recursive_sanitize(config)
if isinstance(conf, list):
return [recursive_sanitize(item) for item in conf]
if isinstance(conf, dict):
sanitized = {}
for k, v in conf.items():
if k in sanitize_keys:
sanitized[k] = "[REDACTED]"
elif isinstance(v, dict):
sanitized[k] = recursive_sanitize(v) # type: ignore[assignment]
elif isinstance(v, list):
sanitized[k] = [
recursive_sanitize(item) for item in v # type: ignore[assignment]
]
else:
sanitized[k] = v
return sanitized
# else, return un altered value (e.g., int, float, str)
return conf

sanitized = recursive_sanitize(config)
if was_json:
# If the original config was a JSON string, return it as a JSON string.
return json.dumps(sanitized, indent=2)
return sanitized
Loading