Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions pydantic_settings/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ class SettingsConfigDict(ConfigDict, total=False):
json_file_encoding: str | None
yaml_file: PathType | None
yaml_file_encoding: str | None
yaml_config_section: str | None
"""
Specifies the top-level key in a YAML file from which to load the settings.
If provided, the settings will be loaded from the nested section under this key.
This is useful when the YAML file contains multiple configuration sections
and you only want to load a specific subset into your settings model.
"""

pyproject_toml_depth: int
"""
Number of levels **up** from the current working directory to attempt to find a pyproject.toml
Expand Down Expand Up @@ -446,6 +454,7 @@ def _settings_build_values(
json_file_encoding=None,
yaml_file=None,
yaml_file_encoding=None,
yaml_config_section=None,
toml_file=None,
secrets_dir=None,
protected_namespaces=('model_validate', 'model_dump', 'settings_customise_sources'),
Expand Down
14 changes: 14 additions & 0 deletions pydantic_settings/sources/providers/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,28 @@ def __init__(
settings_cls: type[BaseSettings],
yaml_file: PathType | None = DEFAULT_PATH,
yaml_file_encoding: str | None = None,
yaml_config_section: str | None = None,
):
self.yaml_file_path = yaml_file if yaml_file != DEFAULT_PATH else settings_cls.model_config.get('yaml_file')
self.yaml_file_encoding = (
yaml_file_encoding
if yaml_file_encoding is not None
else settings_cls.model_config.get('yaml_file_encoding')
)
self.yaml_config_section = (
yaml_config_section
if yaml_config_section is not None
else settings_cls.model_config.get('yaml_config_section')
)
self.yaml_data = self._read_files(self.yaml_file_path)

if self.yaml_config_section:
try:
self.yaml_data = self.yaml_data[self.yaml_config_section]
except KeyError:
raise KeyError(
f'yaml_config_section key "{self.yaml_config_section}" not found in {self.yaml_file_path}'
)
super().__init__(settings_cls, self.yaml_data)

def _read_file(self, file_path: Path) -> dict[str, Any]:
Expand Down
62 changes: 62 additions & 0 deletions tests/test_source_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,65 @@ def settings_customise_sources(

s = Settings()
assert s.model_dump() == {'yaml3': 3, 'yaml4': 4}


@pytest.mark.skipif(yaml is None, reason='pyYAML is not installed')
def test_yaml_config_section(tmp_path):
p = tmp_path / '.env'
p.write_text(
"""
foobar: "Hello"
nested:
nested_field: "world!"
"""
)

class Settings(BaseSettings):
nested_field: str

model_config = SettingsConfigDict(yaml_file=p, yaml_config_section='nested')

@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (YamlConfigSettingsSource(settings_cls),)

s = Settings()
assert s.nested_field == 'world!'


@pytest.mark.skipif(yaml is None, reason='pyYAML is not installed')
def test_invalid_yaml_config_section(tmp_path):
p = tmp_path / '.env'
p.write_text(
"""
foobar: "Hello"
nested:
nested_field: "world!"
"""
)

class Settings(BaseSettings):
nested_field: str

model_config = SettingsConfigDict(yaml_file=p, yaml_config_section='invalid_key')

@classmethod
def settings_customise_sources(
cls,
settings_cls: type[BaseSettings],
init_settings: PydanticBaseSettingsSource,
env_settings: PydanticBaseSettingsSource,
dotenv_settings: PydanticBaseSettingsSource,
file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
return (YamlConfigSettingsSource(settings_cls),)

with pytest.raises(KeyError, match=r'yaml_config_section key ".*" not found in .+'):
Settings()