Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions mlos_bench/mlos_bench/environments/base_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from mlos_bench.services.base_service import Service
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tunables.tunable_types import TunableValue
from mlos_bench.util import instantiate_from_config, merge_parameters
from mlos_bench.util import instantiate_from_config, merge_parameters, sanitize_config

if TYPE_CHECKING:
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
Expand Down Expand Up @@ -174,7 +174,11 @@ def __init__( # pylint: disable=too-many-arguments
_LOG.debug("Parameters for '%s' :: %s", name, self._params)

if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2))
_LOG.debug(
"Config for: '%s'\n%s",
name,
json.dumps(sanitize_config(self.config), indent=2),
)

def _validate_json_config(self, config: dict, name: str) -> None:
"""Reconstructs a basic json config that this class might have been instantiated
Expand Down
5 changes: 3 additions & 2 deletions mlos_bench/mlos_bench/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from mlos_bench.storage.base_storage import Storage
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.tunables.tunable_types import TunableValue
from mlos_bench.util import try_parse_val
from mlos_bench.util import sanitize_config, try_parse_val

_LOG_LEVEL = logging.INFO
_LOG_FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s"
Expand Down Expand Up @@ -478,7 +478,8 @@ def _try_parse_extra_args(cmdline: Iterable[str]) -> dict[str, TunableValue]:
# other CLI options to use as common python/json variable replacements.
config = {k.replace("-", "_"): v for k, v in config.items()}

_LOG.debug("Parsed config: %s", config)
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Parsed config: %s", sanitize_config(config))
return config

def _load_config(
Expand Down
20 changes: 16 additions & 4 deletions mlos_bench/mlos_bench/services/base_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.services.types.bound_method import BoundMethod
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
from mlos_bench.util import instantiate_from_config
from mlos_bench.util import instantiate_from_config, sanitize_config

_LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -99,9 +99,21 @@ def __init__(
self._config_loader_service = parent

if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Service: %s Config:\n%s", self, json.dumps(self.config, indent=2))
_LOG.debug("Service: %s Globals:\n%s", self, json.dumps(global_config or {}, indent=2))
_LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None)
_LOG.debug(
"Service: %s Config:\n%s",
self,
json.dumps(sanitize_config(self.config), indent=2),
)
_LOG.debug(
"Service: %s Globals:\n%s",
self,
json.dumps(sanitize_config(global_config or {}), indent=2),
)
_LOG.debug(
"Service: %s Parent: %s",
self,
parent.pprint() if parent else None,
)

@staticmethod
def merge_methods(
Expand Down
14 changes: 10 additions & 4 deletions mlos_bench/mlos_bench/services/config_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,12 @@ def load_config(
if any(c in json for c in ("{", "[")):
# If the path contains braces, it is likely already a json string,
# so just parse it.
_LOG.info("Load config from json string: %s", json)
if _LOG.isEnabledFor(logging.INFO):
_LOG.info("Load config from json string: %s", sanitize_config(json))
try:
config: Any = json5.loads(json)
except ValueError as ex:
_LOG.error("Failed to parse config from JSON string: %s", json)
_LOG.error("Failed to parse config from JSON string: %s", sanitize_config(json))
raise ValueError(f"Failed to parse config from JSON string: {json}") from ex
else:
json = self.resolve_path(json)
Expand Down Expand Up @@ -225,7 +226,7 @@ def load_config(
# (e.g. Azure ARM templates).
del config["$schema"]
else:
_LOG.warning("Config %s is not validated against a schema.", json)
_LOG.warning("Config %s is not validated against a schema.", sanitize_config(json))
return config # type: ignore[no-any-return]

def prepare_class_load(
Expand Down Expand Up @@ -707,7 +708,12 @@ def load_services(
--------
mlos_bench.services : Examples of service configurations.
"""
_LOG.info("Load services: %s parent: %s", jsons, parent.__class__.__name__)
if _LOG.isEnabledFor(logging.INFO):
_LOG.info(
"Load services: %s parent: %s",
sanitize_config(jsons),
parent.__class__.__name__,
)
service = Service({}, global_config, parent)
for json in jsons:
config = self.load_config(json, ConfigSchema.SERVICE)
Expand Down
7 changes: 4 additions & 3 deletions mlos_bench/mlos_bench/storage/base_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from mlos_bench.services.base_service import Service
from mlos_bench.storage.base_experiment_data import ExperimentData
from mlos_bench.tunables.tunable_groups import TunableGroups
from mlos_bench.util import get_git_info
from mlos_bench.util import get_git_info, sanitize_config

_LOG = logging.getLogger(__name__)

Expand All @@ -62,7 +62,8 @@ def __init__(
config : dict
Free-format key/value pairs of configuration parameters.
"""
_LOG.debug("Storage config: %s", config)
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Storage config: %s", sanitize_config(config))
self._validate_json_config(config)
self._service = service
self._config = config.copy()
Expand Down Expand Up @@ -431,7 +432,7 @@ def new_trial(
_config = DictTemplater(config).expand_vars()
assert isinstance(_config, dict)
except ValueError as e:
_LOG.error("Non-serializable config: %s", config, exc_info=e)
_LOG.error("Non-serializable config: %s", sanitize_config(config), exc_info=e)
raise e
return self._new_trial(tunables, ts_start, config)

Expand Down
100 changes: 100 additions & 0 deletions mlos_bench/mlos_bench/tests/test_sanitize_confs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

Tests cover obfuscation of sensitive keys and recursive sanitization.
"""
import json5

from mlos_bench.util import sanitize_config


Expand All @@ -21,6 +23,7 @@ def test_sanitize_config_simple() -> None:
"other": 42,
}
sanitized = sanitize_config(config)
assert isinstance(sanitized, dict)
assert sanitized["username"] == "user1"
assert sanitized["password"] == "[REDACTED]"
assert sanitized["token"] == "[REDACTED]"
Expand All @@ -39,6 +42,7 @@ def test_sanitize_config_nested() -> None:
"api_key": "key",
}
sanitized = sanitize_config(config)
assert isinstance(sanitized, dict)
assert sanitized["outer"]["password"] == "[REDACTED]"
assert sanitized["outer"]["inner"]["token"] == "[REDACTED]"
assert sanitized["outer"]["inner"]["foo"] == "bar"
Expand All @@ -61,7 +65,103 @@ def test_sanitize_config_mixed_types() -> None:
"api_key": {"nested": "val"},
}
sanitized = sanitize_config(config)
assert isinstance(sanitized, dict)
assert sanitized["password"] == "[REDACTED]"
assert sanitized["token"] == "[REDACTED]"
assert sanitized["secret"] == "[REDACTED]"
assert sanitized["api_key"] == "[REDACTED]"


def test_sanitize_config_empty() -> None:
    """Check that an empty configuration comes back as an empty dictionary."""
    empty: dict = {}
    result = sanitize_config(empty)
    # Nothing to redact, so the output must compare equal to the input.
    assert result == empty


def test_sanitize_array() -> None:
    """Check that dictionaries nested in a top-level list are sanitized."""
    entries = [
        {"username": "user1", "password": "pass1"},
        {"username": "user2", "password": "pass2"},
    ]
    result = sanitize_config(entries)
    assert isinstance(result, list)
    assert len(result) == 2
    # Each entry keeps its username but has its password redacted.
    for item, expected_user in zip(result, ("user1", "user2")):
        assert item["username"] == expected_user
        assert item["password"] == "[REDACTED]"


def test_sanitize_config_with_non_string_values() -> None:
    """Test sanitization with non-string values stored under non-sensitive keys."""
    config = {
        "int_value": 42,
        "float_value": 3.14,
        "bool_value": True,
        "none_value": None,
        "list_value": [1, "password", 3],
        "dict_value": {"key": "value"},
    }
    sanitized = sanitize_config(config)
    assert isinstance(sanitized, dict)
    # Scalars under non-sensitive keys pass through untouched.
    assert sanitized["int_value"] == 42
    assert sanitized["float_value"] == 3.14
    assert sanitized["bool_value"] is True
    assert sanitized["none_value"] is None
    # List items are sanitized recursively, and a bare string that is exactly
    # a sensitive key name is replaced, so "password" is redacted here while
    # the non-string items are preserved.
    assert sanitized["list_value"] == [1, "[REDACTED]", 3]
    assert sanitized["dict_value"] == {"key": "value"}


def test_sanitize_config_json_string() -> None:
    """Check that a JSON string is parsed, sanitized, and returned as a string."""
    payload = json5.dumps(
        {
            "username": "user1",
            "password": "mypassword",
            "token": "abc123",
            "nested": {"api_key": "key", "other": 1},
            "list": [{"secret": "shh"}, {"foo": "bar"}],
        }
    )
    result = sanitize_config(payload)
    # String in, string out.
    assert isinstance(result, str)
    parsed = json5.loads(result)
    assert isinstance(parsed, dict)
    # Top-level keys.
    assert parsed["username"] == "user1"
    assert parsed["password"] == "[REDACTED]"
    assert parsed["token"] == "[REDACTED]"
    # Nested dictionary.
    assert parsed["nested"]["api_key"] == "[REDACTED]"
    assert parsed["nested"]["other"] == 1
    # Dictionaries inside a list.
    assert parsed["list"][0]["secret"] == "[REDACTED]"
    assert parsed["list"][1]["foo"] == "bar"


def test_sanitize_config_invalid_json_string() -> None:
    """Check that a string that fails to parse as JSON is returned unchanged."""
    broken = '{"username": "user1", "password": "pw"'  # no closing brace
    result = sanitize_config(broken)
    assert result == broken


def test_sanitize_config_json5_string() -> None:
    """Check that a JSON5-only string (trailing comma) is parsed and sanitized."""
    json5_text = '{"username": "user1", "password": "pw", }'  # trailing comma
    result = sanitize_config(json5_text)
    # The input parses as JSON5, so the sanitized output is again a string.
    assert isinstance(result, str)
    parsed = json5.loads(result)
    assert isinstance(parsed, dict)
    assert len(parsed) == 2
    assert parsed["username"] == "user1"
    assert parsed["password"] == "[REDACTED]"


def test_sanitize_config_json_string_no_sensitive_keys() -> None:
    """Check that a JSON string without sensitive keys round-trips unchanged."""
    original = {"foo": 1, "bar": {"baz": 2}}
    result = sanitize_config(json5.dumps(original))
    assert isinstance(result, str)
    # Parsing the returned string must reproduce the original structure.
    round_tripped = json5.loads(result)
    assert round_tripped == original
87 changes: 68 additions & 19 deletions mlos_bench/mlos_bench/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union

import json5
import pandas
import pytz

Expand Down Expand Up @@ -581,32 +582,80 @@ def datetime_parser(
return new_datetime_col


def sanitize_config(config: dict[str, Any]) -> dict[str, Any]:
"""
Sanitize a configuration dictionary by obfuscating potentially sensitive keys.
_SANITIZE_KEYS = {
"password",
"secret",
"token",
"api_key",
}

Parameters
----------
config : dict
Configuration dictionary to sanitize.

Returns
-------
dict
Sanitized configuration dictionary.
"""
sanitize_keys = {"password", "secret", "token", "api_key"}

def recursive_sanitize(conf: dict[str, Any]) -> dict[str, Any]:
"""Recursively sanitize a dictionary."""
def _recursive_sanitize(
conf: dict[str, Any] | list[Any] | str,
) -> dict[str, Any] | list[Any] | str:
"""Recursively sanitize a dictionary."""
if isinstance(conf, str) and conf in _SANITIZE_KEYS:
return "[REDACTED]"
if isinstance(conf, list):
return [_recursive_sanitize(item) for item in conf]
if isinstance(conf, dict):
sanitized = {}
for k, v in conf.items():
if k in sanitize_keys:
if k in _SANITIZE_KEYS:
sanitized[k] = "[REDACTED]"
elif isinstance(v, dict):
sanitized[k] = recursive_sanitize(v) # type: ignore[assignment]
sanitized[k] = _recursive_sanitize(v) # type: ignore[assignment]
elif isinstance(v, list):
sanitized[k] = [
_recursive_sanitize(item) for item in v # type: ignore[assignment]
]
else:
sanitized[k] = v
return sanitized
# else, return un altered value (e.g., int, float, str)
return conf


def sanitize_config(config: dict[str, Any] | list[Any] | Any) -> dict[str, Any] | list[Any] | Any:
"""
Attempts to sanitize a configuration dictionary by obfuscating potentially sensitive
keys.

return recursive_sanitize(config)
Notes
-----
Mostly used to make CodeQL scans happy by redacting sensitive information
(e.g., passwords, tokens, API keys) in the configuration.

Will also attempt to parse the input as a JSON string if it is a string,
and return a JSON string if the original input was a JSON string.
Therefore this function is somewhat expensive so logging should be blocked with
``if _LOG.isEnabledFor(logging.INFO):`` checks (or similar) before calling it.

Finally, it will also replace bare strings that match the sensitive keys
with "[REDACTED]" to avoid leaking sensitive information in the logs, though
this is obviously a less effective approach and may hinder useful debugging.

Parameters
----------
config : dict | list | Any
Configuration dictionary to sanitize.

Returns
-------
dict | list | Any
Sanitized configuration dictionary.
"""
# Try and parse the config as a JSON string first, if it's a string.
was_json = False
if isinstance(config, str) and config:
try:
config = json5.loads(config)
was_json = True
except (json.JSONDecodeError, ValueError):
# If it fails to parse, use the original string.
pass
sanitized = _recursive_sanitize(config)
if was_json:
# If the original config was a JSON string, return it as a JSON string.
return json.dumps(sanitized, indent=2)
return sanitized
Loading