diff --git a/docs/index.md b/docs/index.md index b78cfbd7..e5b39fa1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1970,6 +1970,45 @@ class AzureKeyVaultSettings(BaseSettings): ) ``` +### Snake case conversion + +The Azure Key Vault source accepts a `snake_case_convertion` option, disabled by default, to convert Key Vault secret names by mapping them to Python's snake_case field names, without the need to use aliases. + +```py +import os + +from azure.identity import DefaultAzureCredential + +from pydantic_settings import ( + AzureKeyVaultSettingsSource, + BaseSettings, + PydanticBaseSettingsSource, +) + + +class AzureKeyVaultSettings(BaseSettings): + my_setting: str + + @classmethod + def settings_customise_sources( + cls, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> tuple[PydanticBaseSettingsSource, ...]: + az_key_vault_settings = AzureKeyVaultSettingsSource( + settings_cls, + os.environ['AZURE_KEY_VAULT_URL'], + DefaultAzureCredential(), + snake_case_conversion=True, + ) + return (az_key_vault_settings,) +``` + +This setup will load Azure Key Vault secrets (e.g., `MySetting`, `mySetting`, `my-secret` or `MY-SECRET`), mapping them to the snake case version (`my_setting` in this case). + ### Dash to underscore mapping The Azure Key Vault source accepts a `dash_to_underscore` option, disabled by default, to support Key Vault kebab-case secret names by mapping them to Python's snake_case field names. When enabled, dashes (`-`) in secret names are mapped to underscores (`_`) in field names during validation. diff --git a/pydantic_settings/sources/providers/azure.py b/pydantic_settings/sources/providers/azure.py index 04f0bee5..c0c95064 100644 --- a/pydantic_settings/sources/providers/azure.py +++ b/pydantic_settings/sources/providers/azure.py @@ -5,6 +5,7 @@ from collections.abc import Iterator, Mapping from typing import TYPE_CHECKING, Optional +from pydantic.alias_generators import to_snake from pydantic.fields import FieldInfo from .env import EnvSettingsSource @@ -45,26 +46,42 @@ def __init__( self, secret_client: SecretClient, case_sensitive: bool, + snake_case_conversion: bool, ) -> None: self._loaded_secrets = {} self._secret_client = secret_client self._case_sensitive = case_sensitive + self._snake_case_conversion = snake_case_conversion self._secret_map: dict[str, str] = self._load_remote() def _load_remote(self) -> dict[str, str]: secret_names: Iterator[str] = ( secret.name for secret in self._secret_client.list_properties_of_secrets() if secret.name and secret.enabled ) + + if self._snake_case_conversion: + return {to_snake(name): name for name in secret_names} + if self._case_sensitive: return {name: name for name in secret_names} + return {name.lower(): name for name in secret_names} def __getitem__(self, key: str) -> str | None: - if not self._case_sensitive: - key = key.lower() - if key not in self._loaded_secrets and key in self._secret_map: - self._loaded_secrets[key] = self._secret_client.get_secret(self._secret_map[key]).value - return self._loaded_secrets[key] + new_key = key + + if self._snake_case_conversion: + new_key = to_snake(key) + elif not self._case_sensitive: + new_key = key.lower() + + if new_key not in self._loaded_secrets: + if new_key in self._secret_map: + self._loaded_secrets[new_key] = self._secret_client.get_secret(self._secret_map[new_key]).value + else: + raise KeyError(key) + + return self._loaded_secrets[new_key] def __len__(self) -> int: return len(self._secret_map) @@ -84,6 +101,7 @@ def __init__( credential: TokenCredential, dash_to_underscore: bool = False, case_sensitive: bool | None = None, + snake_case_conversion: bool = False, env_prefix: str | None = None, env_parse_none_str: str | None = None, env_parse_enums: bool | None = None, @@ -92,11 +110,12 @@ def __init__( self._url = url self._credential = credential self._dash_to_underscore = dash_to_underscore + self._snake_case_conversion = snake_case_conversion super().__init__( settings_cls, - case_sensitive=case_sensitive, + case_sensitive=False if snake_case_conversion else case_sensitive, env_prefix=env_prefix, - env_nested_delimiter='--', + env_nested_delimiter='__' if snake_case_conversion else '--', env_ignore_empty=False, env_parse_none_str=env_parse_none_str, env_parse_enums=env_parse_enums, @@ -104,11 +123,19 @@ def __init__( def _load_env_vars(self) -> Mapping[str, Optional[str]]: secret_client = SecretClient(vault_url=self._url, credential=self._credential) - return AzureKeyVaultMapping(secret_client, self.case_sensitive) + return AzureKeyVaultMapping( + secret_client=secret_client, + case_sensitive=self.case_sensitive, + snake_case_conversion=self._snake_case_conversion, + ) def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]: + if self._snake_case_conversion: + return list((x[0], x[0], x[2]) for x in super()._extract_field_info(field, field_name)) + if self._dash_to_underscore: return list((x[0], x[1].replace('_', '-'), x[2]) for x in super()._extract_field_info(field, field_name)) + return super()._extract_field_info(field, field_name) def __repr__(self) -> str: diff --git a/tests/test_source_azure_key_vault.py b/tests/test_source_azure_key_vault.py index 89f48add..4a602ad8 100644 --- a/tests/test_source_azure_key_vault.py +++ b/tests/test_source_azure_key_vault.py @@ -205,3 +205,63 @@ def settings_customise_sources( assert settings.my_field == expected_secret_value assert settings.alias_field == expected_secret_value + + def test_snake_case_conversion(self, mocker: MockerFixture) -> None: + """Test that secret names are mapped to snake case in field names.""" + + class NestedModel(BaseModel): + nested_field: str + + class AzureKeyVaultSettings(BaseSettings): + my_field_from_kebab_case: str + my_field_from_pascal_case: str + my_field_from_camel_case: str + alias_field: str = Field(alias='Secret-Alias') + alias_field_2: str = Field(alias='another-SECRET-AliaS') + nested_model: NestedModel + + @classmethod + def settings_customise_sources( + cls, + settings_cls: type[BaseSettings], + init_settings: PydanticBaseSettingsSource, + env_settings: PydanticBaseSettingsSource, + dotenv_settings: PydanticBaseSettingsSource, + file_secret_settings: PydanticBaseSettingsSource, + ) -> tuple[PydanticBaseSettingsSource, ...]: + return ( + AzureKeyVaultSettingsSource( + settings_cls, + 'https://my-resource.vault.azure.net/', + DefaultAzureCredential(), + snake_case_conversion=True, + ), + ) + + expected_secrets = [ + type('', (), {'name': 'my-field-from-kebab-case', 'enabled': True}), + type('', (), {'name': 'MyFieldFromPascalCase', 'enabled': True}), + type('', (), {'name': 'myFieldFromCamelCase', 'enabled': True}), + type('', (), {'name': 'Secret-Alias', 'enabled': True}), + type('', (), {'name': 'another-SECRET-AliaS', 'enabled': True}), + type('', (), {'name': 'NestedModel--NestedField', 'enabled': True}), + ] + expected_secret_value = 'SecretValue' + + mocker.patch( + f'{AzureKeyVaultSettingsSource.__module__}.{SecretClient.list_properties_of_secrets.__qualname__}', + return_value=expected_secrets, + ) + mocker.patch( + f'{AzureKeyVaultSettingsSource.__module__}.{SecretClient.get_secret.__qualname__}', + return_value=KeyVaultSecret(SecretProperties(), expected_secret_value), + ) + + settings = AzureKeyVaultSettings() + + assert settings.my_field_from_kebab_case == expected_secret_value + assert settings.my_field_from_pascal_case == expected_secret_value + assert settings.my_field_from_camel_case == expected_secret_value + assert settings.alias_field == expected_secret_value + assert settings.alias_field_2 == expected_secret_value + assert settings.nested_model.nested_field == expected_secret_value