Skip to content

Commit 894741f

Browse files
committed
Snake case conversion in Azure Key Vault
1 parent 9c6c9b5 commit 894741f

File tree

2 files changed

+36
-35
lines changed

2 files changed

+36
-35
lines changed

pydantic_settings/sources/providers/azure.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections.abc import Iterator, Mapping
66
from typing import TYPE_CHECKING, Optional
77

8+
from pydantic.alias_generators import to_snake
89
from pydantic.fields import FieldInfo
910

1011
from .env import EnvSettingsSource
@@ -44,27 +45,27 @@ class AzureKeyVaultMapping(Mapping[str, Optional[str]]):
4445
def __init__(
4546
self,
4647
secret_client: SecretClient,
47-
case_sensitive: bool,
4848
) -> None:
4949
self._loaded_secrets = {}
5050
self._secret_client = secret_client
51-
self._case_sensitive = case_sensitive
5251
self._secret_map: dict[str, str] = self._load_remote()
5352

5453
def _load_remote(self) -> dict[str, str]:
5554
secret_names: Iterator[str] = (
5655
secret.name for secret in self._secret_client.list_properties_of_secrets() if secret.name and secret.enabled
5756
)
58-
if self._case_sensitive:
59-
return {name: name for name in secret_names}
60-
return {name.lower(): name for name in secret_names}
57+
return {to_snake(name): name for name in secret_names}
6158

6259
def __getitem__(self, key: str) -> str | None:
63-
if not self._case_sensitive:
64-
key = key.lower()
65-
if key not in self._loaded_secrets and key in self._secret_map:
66-
self._loaded_secrets[key] = self._secret_client.get_secret(self._secret_map[key]).value
67-
return self._loaded_secrets[key]
60+
key_snake = to_snake(key)
61+
62+
if key_snake not in self._loaded_secrets and key_snake in self._secret_map:
63+
self._loaded_secrets[key_snake] = self._secret_client.get_secret(self._secret_map[key_snake]).value
64+
65+
try:
66+
return self._loaded_secrets[key_snake]
67+
except Exception:
68+
raise KeyError(key)
6869

6970
def __len__(self) -> int:
7071
return len(self._secret_map)
@@ -82,34 +83,29 @@ def __init__(
8283
settings_cls: type[BaseSettings],
8384
url: str,
8485
credential: TokenCredential,
85-
dash_to_underscore: bool = False,
86-
case_sensitive: bool | None = None,
8786
env_prefix: str | None = None,
8887
env_parse_none_str: str | None = None,
8988
env_parse_enums: bool | None = None,
9089
) -> None:
9190
import_azure_key_vault()
9291
self._url = url
9392
self._credential = credential
94-
self._dash_to_underscore = dash_to_underscore
9593
super().__init__(
9694
settings_cls,
97-
case_sensitive=case_sensitive,
95+
case_sensitive=False,
9896
env_prefix=env_prefix,
99-
env_nested_delimiter='--',
97+
env_nested_delimiter='__',
10098
env_ignore_empty=False,
10199
env_parse_none_str=env_parse_none_str,
102100
env_parse_enums=env_parse_enums,
103101
)
104102

105103
def _load_env_vars(self) -> Mapping[str, Optional[str]]:
106104
secret_client = SecretClient(vault_url=self._url, credential=self._credential)
107-
return AzureKeyVaultMapping(secret_client, self.case_sensitive)
105+
return AzureKeyVaultMapping(secret_client)
108106

109107
def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
110-
if self._dash_to_underscore:
111-
return list((x[0], x[1].replace('_', '-'), x[2]) for x in super()._extract_field_info(field, field_name))
112-
return super()._extract_field_info(field, field_name)
108+
return list((x[0], x[0], x[2]) for x in super()._extract_field_info(field, field_name))
113109

114110
def __repr__(self) -> str:
115111
return f'{self.__class__.__name__}(url={self._url!r}, env_nested_delimiter={self.env_nested_delimiter!r})'

tests/test_source_azure_key_vault.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,13 @@ def test___call__(self, mocker: MockerFixture) -> None:
4646
"""Test __call__."""
4747

4848
class SqlServer(BaseModel):
49-
password: str = Field(..., alias='Password')
49+
password: str
5050

5151
class AzureKeyVaultSettings(BaseSettings):
5252
"""AzureKeyVault settings."""
5353

54-
SqlServerUser: str
55-
sql_server_user: str = Field(..., alias='SqlServerUser')
56-
sql_server: SqlServer = Field(..., alias='SqlServer')
54+
sql_server_user: str
55+
sql_server: SqlServer
5756

5857
expected_secrets = [
5958
type('', (), {'name': 'SqlServerUser', 'enabled': True}),
@@ -74,14 +73,14 @@ class AzureKeyVaultSettings(BaseSettings):
7473

7574
settings = obj()
7675

77-
assert settings['SqlServerUser'] == expected_secret_value
78-
assert settings['SqlServer']['Password'] == expected_secret_value
76+
assert settings['sql_server_user'] == expected_secret_value
77+
assert settings['sql_server']['password'] == expected_secret_value
7978

8079
def test_do_not_load_disabled_secrets(self, mocker: MockerFixture) -> None:
8180
class AzureKeyVaultSettings(BaseSettings):
8281
"""AzureKeyVault settings."""
8382

84-
SqlServerPassword: str
83+
sql_server_password: str
8584
DisabledSqlServerPassword: str
8685

8786
disabled_secret_name = 'SqlServerPassword'
@@ -108,14 +107,13 @@ def test_azure_key_vault_settings_source(self, mocker: MockerFixture) -> None:
108107
"""Test AzureKeyVaultSettingsSource."""
109108

110109
class SqlServer(BaseModel):
111-
password: str = Field(..., alias='Password')
110+
password: str
112111

113112
class AzureKeyVaultSettings(BaseSettings):
114113
"""AzureKeyVault settings."""
115114

116-
SqlServerUser: str
117-
sql_server_user: str = Field(..., alias='SqlServerUser')
118-
sql_server: SqlServer = Field(..., alias='SqlServer')
115+
sql_server_user: str
116+
sql_server: SqlServer
119117

120118
@classmethod
121119
def settings_customise_sources(
@@ -148,7 +146,6 @@ def settings_customise_sources(
148146

149147
settings = AzureKeyVaultSettings() # type: ignore
150148

151-
assert settings.SqlServerUser == expected_secret_value
152149
assert settings.sql_server_user == expected_secret_value
153150
assert settings.sql_server.password == expected_secret_value
154151

@@ -161,12 +158,17 @@ def _raise_resource_not_found_when_getting_parent_secret_name(self, secret_name:
161158

162159
return key_vault_secret
163160

164-
def test_dash_to_underscore_translation(self, mocker: MockerFixture) -> None:
165-
"""Test that dashes in secret names are mapped to underscores in field names."""
161+
def test_snake_case_translation(self, mocker: MockerFixture) -> None:
162+
"""Test that secret names are mapped to snake case in field names."""
163+
164+
class NestedModel(BaseModel):
165+
nested_field: str
166166

167167
class AzureKeyVaultSettings(BaseSettings):
168168
my_field: str
169-
alias_field: str = Field(..., alias='Secret-Alias')
169+
alias_field: str = Field(alias='Secret-Alias')
170+
alias_field_2: str = Field(alias='another-SECRET-AliaS')
171+
nested_model: NestedModel
170172

171173
@classmethod
172174
def settings_customise_sources(
@@ -182,13 +184,14 @@ def settings_customise_sources(
182184
settings_cls,
183185
'https://my-resource.vault.azure.net/',
184186
DefaultAzureCredential(),
185-
dash_to_underscore=True,
186187
),
187188
)
188189

189190
expected_secrets = [
190191
type('', (), {'name': 'my-field', 'enabled': True}),
191192
type('', (), {'name': 'Secret-Alias', 'enabled': True}),
193+
type('', (), {'name': 'another-SECRET-AliaS', 'enabled': True}),
194+
type('', (), {'name': 'NestedModel--NestedField', 'enabled': True}),
192195
]
193196
expected_secret_value = 'SecretValue'
194197

@@ -205,3 +208,5 @@ def settings_customise_sources(
205208

206209
assert settings.my_field == expected_secret_value
207210
assert settings.alias_field == expected_secret_value
211+
assert settings.alias_field_2 == expected_secret_value
212+
assert settings.nested_model.nested_field == expected_secret_value

0 commit comments

Comments
 (0)