diff --git a/hlink/configs/load_config.py b/hlink/configs/load_config.py
index 46b565a..d7baba8 100755
--- a/hlink/configs/load_config.py
+++ b/hlink/configs/load_config.py
@@ -7,11 +7,14 @@ from typing import Any
 
 import json
 import toml
+import tomli
 
 from hlink.errors import UsageError
 
 
-def load_conf_file(conf_name: str) -> tuple[Path, dict[str, Any]]:
+def load_conf_file(
+    conf_name: str, *, use_legacy_toml_parser: bool = False
+) -> tuple[Path, dict[str, Any]]:
     """Flexibly load a config file.
 
     Given a path `conf_name`, look for a file at that path. If that file
@@ -20,8 +23,18 @@ def load_conf_file(conf_name: str) -> tuple[Path, dict[str, Any]]:
     name with a '.toml' extension added and load it if it exists. Then do the
     same for a file with a '.json' extension added.
 
+    `use_legacy_toml_parser` tells this function to use the legacy TOML library
+    which hlink used to use instead of the current default. This is provided
+    for backwards compatibility. Some previously written config files may
+    depend on bugs in the legacy TOML library, making it hard to migrate to the
+    new TOML v1.0 compliant parser. It is strongly recommended that new code
+    and config files use the default parser. Old code and config files should
+    also try to migrate to the default parser when possible.
+
     Args:
         conf_name: the file to look for
+        use_legacy_toml_parser: (Not Recommended) Use the legacy, buggy TOML
+            parser instead of the default parser.
 
     Returns:
         a tuple (absolute path to the config file, contents of the config file)
@@ -40,9 +53,19 @@ def load_conf_file(conf_name: str) -> tuple[Path, dict[str, Any]]:
 
     for file in existing_files:
         if file.suffix == ".toml":
-            with open(file) as f:
-                conf = toml.load(f)
-            return file.absolute(), conf
+            # Legacy support for using the "toml" library instead of "tomli".
+            #
+            # Eventually we should remove use_legacy_toml_parser and just use
+            # tomli or Python's standard library tomllib, which is available in
+            # Python 3.11+.
+            if use_legacy_toml_parser:
+                with open(file) as f:
+                    conf = toml.load(f)
+                return file.absolute(), conf
+            else:
+                with open(file, "rb") as f:
+                    conf = tomli.load(f)
+                return file.absolute(), conf
 
         if file.suffix == ".json":
             with open(file) as f:
diff --git a/hlink/tests/config_loader_test.py b/hlink/tests/config_loader_test.py
index b14e0b4..58ab53d 100644
--- a/hlink/tests/config_loader_test.py
+++ b/hlink/tests/config_loader_test.py
@@ -50,3 +50,18 @@ def test_load_conf_file_unrecognized_extension(tmp_path: Path) -> None:
         match="The file .+ exists, but it doesn't have a '.toml' or '.json' extension",
     ):
         load_conf_file(str(conf_file))
+
+
+def test_load_conf_file_json_legacy_parser(conf_dir_path: str) -> None:
+    """
+    The use_legacy_toml_parser argument does not affect json parsing.
+    """
+    conf_file = Path(conf_dir_path) / "test.json"
+    _, conf = load_conf_file(str(conf_file), use_legacy_toml_parser=True)
+    assert conf["id_column"] == "id"
+
+
+def test_load_conf_file_toml_legacy_parser(conf_dir_path: str) -> None:
+    conf_file = Path(conf_dir_path) / "test1.toml"
+    _, conf = load_conf_file(str(conf_file), use_legacy_toml_parser=True)
+    assert conf["id_column"] == "id-toml"
diff --git a/pyproject.toml b/pyproject.toml
index 3364294..deab3e9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
     "pyspark~=3.5.0",
     "scikit-learn>=1.1.0",
     "toml>=0.10.0",
+    "tomli>=2.0",
 ]
 
 [project.optional-dependencies]