Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 15 additions & 19 deletions pydantic_settings/sources/providers/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections.abc import Iterator, Mapping
from typing import TYPE_CHECKING, Optional

from pydantic.alias_generators import to_snake
from pydantic.fields import FieldInfo

from .env import EnvSettingsSource
Expand Down Expand Up @@ -44,27 +45,27 @@ class AzureKeyVaultMapping(Mapping[str, Optional[str]]):
def __init__(
self,
secret_client: SecretClient,
case_sensitive: bool,
) -> None:
self._loaded_secrets = {}
self._secret_client = secret_client
self._case_sensitive = case_sensitive
self._secret_map: dict[str, str] = self._load_remote()

def _load_remote(self) -> dict[str, str]:
secret_names: Iterator[str] = (
secret.name for secret in self._secret_client.list_properties_of_secrets() if secret.name and secret.enabled
)
if self._case_sensitive:
return {name: name for name in secret_names}
return {name.lower(): name for name in secret_names}
return {to_snake(name): name for name in secret_names}

def __getitem__(self, key: str) -> str | None:
if not self._case_sensitive:
key = key.lower()
if key not in self._loaded_secrets and key in self._secret_map:
self._loaded_secrets[key] = self._secret_client.get_secret(self._secret_map[key]).value
return self._loaded_secrets[key]
key_snake = to_snake(key)

if key_snake not in self._loaded_secrets and key_snake in self._secret_map:
self._loaded_secrets[key_snake] = self._secret_client.get_secret(self._secret_map[key_snake]).value

try:
return self._loaded_secrets[key_snake]
except Exception:
raise KeyError(key)

def __len__(self) -> int:
return len(self._secret_map)
Expand All @@ -82,34 +83,29 @@ def __init__(
settings_cls: type[BaseSettings],
url: str,
credential: TokenCredential,
dash_to_underscore: bool = False,
case_sensitive: bool | None = None,
env_prefix: str | None = None,
env_parse_none_str: str | None = None,
env_parse_enums: bool | None = None,
) -> None:
import_azure_key_vault()
self._url = url
self._credential = credential
self._dash_to_underscore = dash_to_underscore
super().__init__(
settings_cls,
case_sensitive=case_sensitive,
case_sensitive=False,
env_prefix=env_prefix,
env_nested_delimiter='--',
env_nested_delimiter='__',
env_ignore_empty=False,
env_parse_none_str=env_parse_none_str,
env_parse_enums=env_parse_enums,
)

def _load_env_vars(self) -> Mapping[str, Optional[str]]:
secret_client = SecretClient(vault_url=self._url, credential=self._credential)
return AzureKeyVaultMapping(secret_client, self.case_sensitive)
return AzureKeyVaultMapping(secret_client)

def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
if self._dash_to_underscore:
return list((x[0], x[1].replace('_', '-'), x[2]) for x in super()._extract_field_info(field, field_name))
return super()._extract_field_info(field, field_name)
return list((x[0], x[0], x[2]) for x in super()._extract_field_info(field, field_name))

def __repr__(self) -> str:
return f'{self.__class__.__name__}(url={self._url!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
Expand Down
37 changes: 21 additions & 16 deletions tests/test_source_azure_key_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,13 @@ def test___call__(self, mocker: MockerFixture) -> None:
"""Test __call__."""

class SqlServer(BaseModel):
password: str = Field(..., alias='Password')
password: str

class AzureKeyVaultSettings(BaseSettings):
"""AzureKeyVault settings."""

SqlServerUser: str
sql_server_user: str = Field(..., alias='SqlServerUser')
sql_server: SqlServer = Field(..., alias='SqlServer')
sql_server_user: str
sql_server: SqlServer

expected_secrets = [
type('', (), {'name': 'SqlServerUser', 'enabled': True}),
Expand All @@ -74,14 +73,14 @@ class AzureKeyVaultSettings(BaseSettings):

settings = obj()

assert settings['SqlServerUser'] == expected_secret_value
assert settings['SqlServer']['Password'] == expected_secret_value
assert settings['sql_server_user'] == expected_secret_value
assert settings['sql_server']['password'] == expected_secret_value

def test_do_not_load_disabled_secrets(self, mocker: MockerFixture) -> None:
class AzureKeyVaultSettings(BaseSettings):
"""AzureKeyVault settings."""

SqlServerPassword: str
sql_server_password: str
DisabledSqlServerPassword: str

disabled_secret_name = 'SqlServerPassword'
Expand All @@ -108,14 +107,13 @@ def test_azure_key_vault_settings_source(self, mocker: MockerFixture) -> None:
"""Test AzureKeyVaultSettingsSource."""

class SqlServer(BaseModel):
password: str = Field(..., alias='Password')
password: str

class AzureKeyVaultSettings(BaseSettings):
"""AzureKeyVault settings."""

SqlServerUser: str
sql_server_user: str = Field(..., alias='SqlServerUser')
sql_server: SqlServer = Field(..., alias='SqlServer')
sql_server_user: str
sql_server: SqlServer

@classmethod
def settings_customise_sources(
Expand Down Expand Up @@ -148,7 +146,6 @@ def settings_customise_sources(

settings = AzureKeyVaultSettings() # type: ignore

assert settings.SqlServerUser == expected_secret_value
assert settings.sql_server_user == expected_secret_value
assert settings.sql_server.password == expected_secret_value

Expand All @@ -161,12 +158,17 @@ def _raise_resource_not_found_when_getting_parent_secret_name(self, secret_name:

return key_vault_secret

def test_dash_to_underscore_translation(self, mocker: MockerFixture) -> None:
"""Test that dashes in secret names are mapped to underscores in field names."""
def test_snake_case_translation(self, mocker: MockerFixture) -> None:
"""Test that secret names are mapped to snake case in field names."""

class NestedModel(BaseModel):
nested_field: str

class AzureKeyVaultSettings(BaseSettings):
my_field: str
alias_field: str = Field(..., alias='Secret-Alias')
alias_field: str = Field(alias='Secret-Alias')
alias_field_2: str = Field(alias='another-SECRET-AliaS')
nested_model: NestedModel

@classmethod
def settings_customise_sources(
Expand All @@ -182,13 +184,14 @@ def settings_customise_sources(
settings_cls,
'https://my-resource.vault.azure.net/',
DefaultAzureCredential(),
dash_to_underscore=True,
),
)

expected_secrets = [
type('', (), {'name': 'my-field', 'enabled': True}),
type('', (), {'name': 'Secret-Alias', 'enabled': True}),
type('', (), {'name': 'another-SECRET-AliaS', 'enabled': True}),
type('', (), {'name': 'NestedModel--NestedField', 'enabled': True}),
]
expected_secret_value = 'SecretValue'

Expand All @@ -205,3 +208,5 @@ def settings_customise_sources(

assert settings.my_field == expected_secret_value
assert settings.alias_field == expected_secret_value
assert settings.alias_field_2 == expected_secret_value
assert settings.nested_model.nested_field == expected_secret_value
Loading