diff --git a/mlos_bench/mlos_bench/environments/base_environment.py b/mlos_bench/mlos_bench/environments/base_environment.py index 094085c78b..232016497c 100644 --- a/mlos_bench/mlos_bench/environments/base_environment.py +++ b/mlos_bench/mlos_bench/environments/base_environment.py @@ -21,7 +21,7 @@ from mlos_bench.services.base_service import Service from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.tunables.tunable_types import TunableValue -from mlos_bench.util import instantiate_from_config, merge_parameters +from mlos_bench.util import instantiate_from_config, merge_parameters, sanitize_config if TYPE_CHECKING: from mlos_bench.services.types.config_loader_type import SupportsConfigLoading @@ -174,7 +174,11 @@ def __init__( # pylint: disable=too-many-arguments _LOG.debug("Parameters for '%s' :: %s", name, self._params) if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Config for: '%s'\n%s", name, json.dumps(self.config, indent=2)) + _LOG.debug( + "Config for: '%s'\n%s", + name, + json.dumps(sanitize_config(self.config), indent=2), + ) def _validate_json_config(self, config: dict, name: str) -> None: """Reconstructs a basic json config that this class might have been instantiated diff --git a/mlos_bench/mlos_bench/launcher.py b/mlos_bench/mlos_bench/launcher.py index c728ed7fb2..0b69a14950 100644 --- a/mlos_bench/mlos_bench/launcher.py +++ b/mlos_bench/mlos_bench/launcher.py @@ -31,7 +31,7 @@ from mlos_bench.storage.base_storage import Storage from mlos_bench.tunables.tunable_groups import TunableGroups from mlos_bench.tunables.tunable_types import TunableValue -from mlos_bench.util import try_parse_val +from mlos_bench.util import sanitize_config, try_parse_val _LOG_LEVEL = logging.INFO _LOG_FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s %(levelname)s %(message)s" @@ -478,7 +478,8 @@ def _try_parse_extra_args(cmdline: Iterable[str]) -> dict[str, TunableValue]: # other CLI options to use as common python/json variable replacements. config = {k.replace("-", "_"): v for k, v in config.items()} - _LOG.debug("Parsed config: %s", config) + if _LOG.isEnabledFor(logging.DEBUG): + _LOG.debug("Parsed config: %s", sanitize_config(config)) return config def _load_config( diff --git a/mlos_bench/mlos_bench/services/base_service.py b/mlos_bench/mlos_bench/services/base_service.py index 41eebfbb98..fdc4eb7f79 100644 --- a/mlos_bench/mlos_bench/services/base_service.py +++ b/mlos_bench/mlos_bench/services/base_service.py @@ -16,7 +16,7 @@ from mlos_bench.config.schemas import ConfigSchema from mlos_bench.services.types.bound_method import BoundMethod from mlos_bench.services.types.config_loader_type import SupportsConfigLoading -from mlos_bench.util import instantiate_from_config +from mlos_bench.util import instantiate_from_config, sanitize_config _LOG = logging.getLogger(__name__) @@ -99,9 +99,21 @@ def __init__( self._config_loader_service = parent if _LOG.isEnabledFor(logging.DEBUG): - _LOG.debug("Service: %s Config:\n%s", self, json.dumps(self.config, indent=2)) - _LOG.debug("Service: %s Globals:\n%s", self, json.dumps(global_config or {}, indent=2)) - _LOG.debug("Service: %s Parent: %s", self, parent.pprint() if parent else None) + _LOG.debug( + "Service: %s Config:\n%s", + self, + json.dumps(sanitize_config(self.config), indent=2), + ) + _LOG.debug( + "Service: %s Globals:\n%s", + self, + json.dumps(sanitize_config(global_config or {}), indent=2), + ) + _LOG.debug( + "Service: %s Parent: %s", + self, + parent.pprint() if parent else None, + ) @staticmethod def merge_methods( diff --git a/mlos_bench/mlos_bench/services/config_persistence.py b/mlos_bench/mlos_bench/services/config_persistence.py index adbe76d60a..6da47d11de 100644 --- a/mlos_bench/mlos_bench/services/config_persistence.py +++ b/mlos_bench/mlos_bench/services/config_persistence.py @@ -191,11 +191,12 @@ def load_config( if any(c in json for c in ("{", "[")): # If the path contains braces, it is likely already a json string, # so just parse it. - _LOG.info("Load config from json string: %s", json) + if _LOG.isEnabledFor(logging.INFO): + _LOG.info("Load config from json string: %s", sanitize_config(json)) try: config: Any = json5.loads(json) except ValueError as ex: - _LOG.error("Failed to parse config from JSON string: %s", json) + _LOG.error("Failed to parse config from JSON string: %s", sanitize_config(json)) raise ValueError(f"Failed to parse config from JSON string: {json}") from ex else: json = self.resolve_path(json) @@ -225,7 +226,7 @@ def load_config( # (e.g. Azure ARM templates). del config["$schema"] else: - _LOG.warning("Config %s is not validated against a schema.", json) + _LOG.warning("Config %s is not validated against a schema.", sanitize_config(json)) return config # type: ignore[no-any-return] def prepare_class_load( @@ -707,7 +708,12 @@ def load_services( -------- mlos_bench.services : Examples of service configurations. """ - _LOG.info("Load services: %s parent: %s", jsons, parent.__class__.__name__) + if _LOG.isEnabledFor(logging.INFO): + _LOG.info( + "Load services: %s parent: %s", + sanitize_config(jsons), + parent.__class__.__name__, + ) service = Service({}, global_config, parent) for json in jsons: config = self.load_config(json, ConfigSchema.SERVICE) diff --git a/mlos_bench/mlos_bench/storage/base_storage.py b/mlos_bench/mlos_bench/storage/base_storage.py index de81bb94d7..6a8bcd0b66 100644 --- a/mlos_bench/mlos_bench/storage/base_storage.py +++ b/mlos_bench/mlos_bench/storage/base_storage.py @@ -38,7 +38,7 @@ from mlos_bench.services.base_service import Service from mlos_bench.storage.base_experiment_data import ExperimentData from mlos_bench.tunables.tunable_groups import TunableGroups -from mlos_bench.util import get_git_info +from mlos_bench.util import get_git_info, sanitize_config _LOG = logging.getLogger(__name__) @@ -62,7 +62,8 @@ def __init__( config : dict Free-format key/value pairs of configuration parameters. """ - _LOG.debug("Storage config: %s", config) + if _LOG.isEnabledFor(logging.DEBUG): + _LOG.debug("Storage config: %s", sanitize_config(config)) self._validate_json_config(config) self._service = service self._config = config.copy() @@ -431,7 +432,7 @@ def new_trial( _config = DictTemplater(config).expand_vars() assert isinstance(_config, dict) except ValueError as e: - _LOG.error("Non-serializable config: %s", config, exc_info=e) + _LOG.error("Non-serializable config: %s", sanitize_config(config), exc_info=e) raise e return self._new_trial(tunables, ts_start, config) diff --git a/mlos_bench/mlos_bench/tests/test_sanitize_confs.py b/mlos_bench/mlos_bench/tests/test_sanitize_confs.py index a589d2124a..6ea9ed0f80 100644 --- a/mlos_bench/mlos_bench/tests/test_sanitize_confs.py +++ b/mlos_bench/mlos_bench/tests/test_sanitize_confs.py @@ -7,6 +7,8 @@ Tests cover obfuscation of sensitive keys and recursive sanitization. """ +import json5 + from mlos_bench.util import sanitize_config @@ -21,6 +23,7 @@ def test_sanitize_config_simple() -> None: "other": 42, } sanitized = sanitize_config(config) + assert isinstance(sanitized, dict) assert sanitized["username"] == "user1" assert sanitized["password"] == "[REDACTED]" assert sanitized["token"] == "[REDACTED]" @@ -39,6 +42,7 @@ def test_sanitize_config_nested() -> None: "api_key": "key", } sanitized = sanitize_config(config) + assert isinstance(sanitized, dict) assert sanitized["outer"]["password"] == "[REDACTED]" assert sanitized["outer"]["inner"]["token"] == "[REDACTED]" assert sanitized["outer"]["inner"]["foo"] == "bar" @@ -61,7 +65,103 @@ def test_sanitize_config_mixed_types() -> None: "api_key": {"nested": "val"}, } sanitized = sanitize_config(config) + assert isinstance(sanitized, dict) assert sanitized["password"] == "[REDACTED]" assert sanitized["token"] == "[REDACTED]" assert sanitized["secret"] == "[REDACTED]" assert sanitized["api_key"] == "[REDACTED]" + + +def test_sanitize_config_empty() -> None: + """Test sanitization of an empty configuration.""" + config: dict = {} + sanitized = sanitize_config(config) + assert sanitized == config # Should remain empty dictionary + + +def test_sanitize_array() -> None: + """Test sanitization of an array with sensitive keys.""" + config = [ + {"username": "user1", "password": "pass1"}, + {"username": "user2", "password": "pass2"}, + ] + sanitized = sanitize_config(config) + assert isinstance(sanitized, list) + assert len(sanitized) == 2 + assert sanitized[0]["username"] == "user1" + assert sanitized[0]["password"] == "[REDACTED]" + assert sanitized[1]["username"] == "user2" + assert sanitized[1]["password"] == "[REDACTED]" + + +def test_sanitize_config_with_non_string_values() -> None: + """Test sanitization with non-string values.""" + config = { + "int_value": 42, + "float_value": 3.14, + "bool_value": True, + "none_value": None, + "list_value": [1, "password", 3], + "dict_value": {"key": "value"}, + } + sanitized = sanitize_config(config) + assert isinstance(sanitized, dict) + assert sanitized["int_value"] == 42 + assert sanitized["float_value"] == 3.14 + assert sanitized["bool_value"] is True + assert sanitized["none_value"] is None + assert sanitized["list_value"] == [1, "password", 3] # don't redact raw strings + assert sanitized["dict_value"] == {"key": "value"} + + +def test_sanitize_config_json_string() -> None: + """Test sanitization when input is a JSON string.""" + config = { + "username": "user1", + "password": "mypassword", + "token": "abc123", + "nested": {"api_key": "key", "other": 1}, + "list": [{"secret": "shh"}, {"foo": "bar"}], + } + config_json = json5.dumps(config) + sanitized = sanitize_config(config_json) + # Should return a JSON string + assert isinstance(sanitized, str) + sanitized_dict = json5.loads(sanitized) + assert isinstance(sanitized_dict, dict) + assert sanitized_dict["username"] == "user1" + assert sanitized_dict["password"] == "[REDACTED]" + assert sanitized_dict["token"] == "[REDACTED]" + assert sanitized_dict["nested"]["api_key"] == "[REDACTED]" + assert sanitized_dict["nested"]["other"] == 1 + assert sanitized_dict["list"][0]["secret"] == "[REDACTED]" + assert sanitized_dict["list"][1]["foo"] == "bar" + + +def test_sanitize_config_invalid_json_string() -> None: + """Test sanitization with an invalid JSON string input.""" + invalid_json = '{"username": "user1", "password": "pw"' # missing closing brace + assert sanitize_config(invalid_json) == invalid_json + + +def test_sanitize_config_json5_string() -> None: + """Test sanitization with an invalid JSON5 string input.""" + invalid_json = '{"username": "user1", "password": "pw", }' # with trailing comma + # Should return processed json as string + sanitized = sanitize_config(invalid_json) + assert isinstance(sanitized, str) + sanitize_dict = json5.loads(sanitized) + assert isinstance(sanitize_dict, dict) + assert len(sanitize_dict) == 2 + assert sanitize_dict["username"] == "user1" + assert sanitize_dict["password"] == "[REDACTED]" + + +def test_sanitize_config_json_string_no_sensitive_keys() -> None: + """Test sanitization of a JSON string with no sensitive keys.""" + config = {"foo": 1, "bar": {"baz": 2}} + config_json = json5.dumps(config) + sanitized = sanitize_config(config_json) + assert isinstance(sanitized, str) + sanitized_dict = json5.loads(sanitized) + assert sanitized_dict == config diff --git a/mlos_bench/mlos_bench/util.py b/mlos_bench/mlos_bench/util.py index fe09a013fd..8afea9f010 100644 --- a/mlos_bench/mlos_bench/util.py +++ b/mlos_bench/mlos_bench/util.py @@ -15,6 +15,7 @@ from datetime import datetime from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union +import json5 import pandas import pytz @@ -581,32 +582,80 @@ def datetime_parser( return new_datetime_col -def sanitize_config(config: dict[str, Any]) -> dict[str, Any]: - """ - Sanitize a configuration dictionary by obfuscating potentially sensitive keys. +_SANITIZE_KEYS = { + "password", + "secret", + "token", + "api_key", +} - Parameters - ---------- - config : dict - Configuration dictionary to sanitize. - Returns - ------- - dict - Sanitized configuration dictionary. - """ - sanitize_keys = {"password", "secret", "token", "api_key"} - - def recursive_sanitize(conf: dict[str, Any]) -> dict[str, Any]: - """Recursively sanitize a dictionary.""" +def _recursive_sanitize( + conf: dict[str, Any] | list[Any] | str, +) -> dict[str, Any] | list[Any] | str: + """Recursively sanitize a dictionary.""" + if isinstance(conf, str) and conf in _SANITIZE_KEYS: + return "[REDACTED]" + if isinstance(conf, list): + return [_recursive_sanitize(item) for item in conf] + if isinstance(conf, dict): sanitized = {} for k, v in conf.items(): - if k in sanitize_keys: + if k in _SANITIZE_KEYS: sanitized[k] = "[REDACTED]" elif isinstance(v, dict): - sanitized[k] = recursive_sanitize(v) # type: ignore[assignment] + sanitized[k] = _recursive_sanitize(v) # type: ignore[assignment] + elif isinstance(v, list): + sanitized[k] = [ + _recursive_sanitize(item) for item in v # type: ignore[assignment] + ] else: sanitized[k] = v return sanitized + # else, return un altered value (e.g., int, float, str) + return conf + + +def sanitize_config(config: dict[str, Any] | list[Any] | Any) -> dict[str, Any] | list[Any] | Any: + """ + Attempts to sanitize a configuration dictionary by obfuscating potentially sensitive + keys. - return recursive_sanitize(config) + Notes + ----- + Mostly used to make CodeQL scans happy by redacting sensitive information + (e.g., passwords, tokens, API keys) in the configuration. + + Will also attempt to parse the input as a JSON string if it is a string, + and return a JSON string if the original input was a JSON string. + Therefore this function is somewhat expensive so logging should be blocked with + ``if _LOG.isEnabledFor(logging.INFO):`` checks (or similar) before calling it. + + Finally, it will also replace bare strings that match the sensitive keys + with "[REDACTED]" to avoid leaking sensitive information in the logs, though + this is obviously a less effective approach and may hinder useful debugging. + + Parameters + ---------- + config : dict | list | Any + Configuration dictionary to sanitize. + + Returns + ------- + dict | list | Any + Sanitized configuration dictionary. + """ + # Try and parse the config as a JSON string first, if it's a string. + was_json = False + if isinstance(config, str) and config: + try: + config = json5.loads(config) + was_json = True + except (json.JSONDecodeError, ValueError): + # If it fails to parse, use the original string. + pass + sanitized = _recursive_sanitize(config) + if was_json: + # If the original config was a JSON string, return it as a JSON string. + return json.dumps(sanitized, indent=2) + return sanitized