diff --git a/pydantic_settings/main.py b/pydantic_settings/main.py index 7a38ea08..66cfbe90 100644 --- a/pydantic_settings/main.py +++ b/pydantic_settings/main.py @@ -62,6 +62,14 @@ class SettingsConfigDict(ConfigDict, total=False): json_file_encoding: str | None yaml_file: PathType | None yaml_file_encoding: str | None + yaml_config_section: str | None + """ + Specifies the top-level key in a YAML file from which to load the settings. + If provided, the settings will be loaded from the nested section under this key. + This is useful when the YAML file contains multiple configuration sections + and you only want to load a specific subset into your settings model. + """ + pyproject_toml_depth: int """ Number of levels **up** from the current working directory to attempt to find a pyproject.toml @@ -446,6 +454,7 @@ def _settings_build_values( json_file_encoding=None, yaml_file=None, yaml_file_encoding=None, + yaml_config_section=None, toml_file=None, secrets_dir=None, protected_namespaces=('model_validate', 'model_dump', 'settings_customise_sources'), diff --git a/pydantic_settings/sources/providers/yaml.py b/pydantic_settings/sources/providers/yaml.py index 2f936f73..82778b4f 100644 --- a/pydantic_settings/sources/providers/yaml.py +++ b/pydantic_settings/sources/providers/yaml.py @@ -39,6 +39,7 @@ def __init__( settings_cls: type[BaseSettings], yaml_file: PathType | None = DEFAULT_PATH, yaml_file_encoding: str | None = None, + yaml_config_section: str | None = None, ): self.yaml_file_path = yaml_file if yaml_file != DEFAULT_PATH else settings_cls.model_config.get('yaml_file') self.yaml_file_encoding = ( @@ -46,7 +47,20 @@ def __init__( if yaml_file_encoding is not None else settings_cls.model_config.get('yaml_file_encoding') ) + self.yaml_config_section = ( + yaml_config_section + if yaml_config_section is not None + else settings_cls.model_config.get('yaml_config_section') + ) self.yaml_data = self._read_files(self.yaml_file_path) + + if self.yaml_config_section: + try: + self.yaml_data = self.yaml_data[self.yaml_config_section] + except KeyError: + raise KeyError( + f'yaml_config_section key "{self.yaml_config_section}" not found in {self.yaml_file_path}' + ) super().__init__(settings_cls, self.yaml_data) def _read_file(self, file_path: Path) -> dict[str, Any]: diff --git a/tests/test_source_yaml.py b/tests/test_source_yaml.py index b6ea1b72..fdedc39a 100644 --- a/tests/test_source_yaml.py +++ b/tests/test_source_yaml.py @@ -166,3 +166,65 @@ def settings_customise_sources( s = Settings() assert s.model_dump() == {'yaml3': 3, 'yaml4': 4} + + +@pytest.mark.skipif(yaml is None, reason='pyYAML is not installed') +def test_yaml_config_section(tmp_path): + p = tmp_path / '.env' + p.write_text( + """ + foobar: "Hello" + nested: + nested_field: "world!" + """ + ) + + class Settings(BaseSettings): + nested_field: str + + model_config = SettingsConfigDict(yaml_file=p, yaml_config_section='nested') + + @classmethod + def settings_customise_sources( + cls, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> tuple[PydanticBaseSettingsSource, ...]: + return (YamlConfigSettingsSource(settings_cls),) + + s = Settings() + assert s.nested_field == 'world!' + + +@pytest.mark.skipif(yaml is None, reason='pyYAML is not installed') +def test_invalid_yaml_config_section(tmp_path): + p = tmp_path / '.env' + p.write_text( + """ + foobar: "Hello" + nested: + nested_field: "world!" + """ + ) + + class Settings(BaseSettings): + nested_field: str + + model_config = SettingsConfigDict(yaml_file=p, yaml_config_section='invalid_key') + + @classmethod + def settings_customise_sources( + cls, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> tuple[PydanticBaseSettingsSource, ...]: + return (YamlConfigSettingsSource(settings_cls),) + + with pytest.raises(KeyError, match='yaml_config_section key "invalid_key" not found in .+'): + Settings()