Skip to content

Commit e12349e

Browse files
committed
Remove breaking changes
1 parent 81cff9d commit e12349e

File tree

2 files changed

+47
-20
lines changed

2 files changed

+47
-20
lines changed

pydantic_settings/sources/providers/azure.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,30 +42,41 @@ class AzureKeyVaultMapping(Mapping[str, Optional[str]]):
4242
_secret_client: SecretClient
4343
_secret_names: list[str]
4444

45-
def __init__(
46-
self,
47-
secret_client: SecretClient,
48-
) -> None:
45+
def __init__(self, secret_client: SecretClient, case_sensitive: bool, snake_case_conversion: bool) -> None:
4946
self._loaded_secrets = {}
5047
self._secret_client = secret_client
48+
self._case_sensitive = case_sensitive
49+
self._snake_case_conversion = snake_case_conversion
5150
self._secret_map: dict[str, str] = self._load_remote()
5251

5352
def _load_remote(self) -> dict[str, str]:
5453
secret_names: Iterator[str] = (
5554
secret.name for secret in self._secret_client.list_properties_of_secrets() if secret.name and secret.enabled
5655
)
57-
return {to_snake(name): name for name in secret_names}
56+
57+
if self._snake_case_conversion:
58+
return {to_snake(name): name for name in secret_names}
59+
60+
if self._case_sensitive:
61+
return {name: name for name in secret_names}
62+
63+
return {name.lower(): name for name in secret_names}
5864

5965
def __getitem__(self, key: str) -> str | None:
60-
key_snake = to_snake(key)
66+
new_key = key
67+
68+
if self._snake_case_conversion:
69+
new_key = to_snake(key)
70+
elif not self._case_sensitive:
71+
new_key = key.lower()
6172

62-
if key_snake not in self._loaded_secrets:
63-
if key_snake in self._secret_map:
64-
self._loaded_secrets[key_snake] = self._secret_client.get_secret(self._secret_map[key_snake]).value
73+
if new_key not in self._loaded_secrets:
74+
if new_key in self._secret_map:
75+
self._loaded_secrets[new_key] = self._secret_client.get_secret(self._secret_map[new_key]).value
6576
else:
6677
raise KeyError(key)
6778

68-
return self._loaded_secrets[key_snake]
79+
return self._loaded_secrets[new_key]
6980

7081
def __len__(self) -> int:
7182
return len(self._secret_map)
@@ -83,29 +94,44 @@ def __init__(
8394
settings_cls: type[BaseSettings],
8495
url: str,
8596
credential: TokenCredential,
97+
dash_to_underscore: bool = False,
98+
case_sensitive: bool | None = None,
99+
snake_case_conversion: bool = False,
86100
env_prefix: str | None = None,
87101
env_parse_none_str: str | None = None,
88102
env_parse_enums: bool | None = None,
89103
) -> None:
90104
import_azure_key_vault()
91105
self._url = url
92106
self._credential = credential
107+
self._dash_to_underscore = dash_to_underscore
108+
self._snake_case_conversion = snake_case_conversion
93109
super().__init__(
94110
settings_cls,
95-
case_sensitive=False,
111+
case_sensitive=False if snake_case_conversion else case_sensitive,
96112
env_prefix=env_prefix,
97-
env_nested_delimiter='__',
113+
env_nested_delimiter='__' if snake_case_conversion else '--',
98114
env_ignore_empty=False,
99115
env_parse_none_str=env_parse_none_str,
100116
env_parse_enums=env_parse_enums,
101117
)
102118

103119
def _load_env_vars(self) -> Mapping[str, Optional[str]]:
104120
secret_client = SecretClient(vault_url=self._url, credential=self._credential)
105-
return AzureKeyVaultMapping(secret_client)
121+
return AzureKeyVaultMapping(
122+
secret_client=secret_client,
123+
case_sensitive=self.case_sensitive,
124+
snake_case_conversion=self._snake_case_conversion,
125+
)
106126

107127
def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
108-
return list((x[0], x[0], x[2]) for x in super()._extract_field_info(field, field_name))
128+
if self._snake_case_conversion:
129+
return list((x[0], x[0], x[2]) for x in super()._extract_field_info(field, field_name))
130+
131+
if self._dash_to_underscore:
132+
return list((x[0], x[1].replace('_', '-'), x[2]) for x in super()._extract_field_info(field, field_name))
133+
134+
return super()._extract_field_info(field, field_name)
109135

110136
def __repr__(self) -> str:
111137
return f'{self.__class__.__name__}(url={self._url!r}, env_nested_delimiter={self.env_nested_delimiter!r})'

tests/test_source_azure_key_vault.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,18 @@ def test___call__(self, mocker: MockerFixture) -> None:
4646
"""Test __call__."""
4747

4848
class SqlServer(BaseModel):
49-
password: str
49+
password: str = Field(..., alias='Password')
5050

5151
class AzureKeyVaultSettings(BaseSettings):
5252
"""AzureKeyVault settings."""
5353

54-
sql_server_user: str
55-
sql_server: SqlServer
54+
SqlServerUser: str
55+
# sql_server_user: str = Field(..., alias='SqlServerUser')
56+
# sql_server: SqlServer = Field(..., alias='SqlServer')
5657

5758
expected_secrets = [
5859
type('', (), {'name': 'SqlServerUser', 'enabled': True}),
59-
type('', (), {'name': 'SqlServer--Password', 'enabled': True}),
60+
# type('', (), {'name': 'SqlServer--Password', 'enabled': True}),
6061
]
6162
expected_secret_value = 'SecretValue'
6263
mocker.patch(
@@ -73,8 +74,8 @@ class AzureKeyVaultSettings(BaseSettings):
7374

7475
settings = obj()
7576

76-
assert settings['sql_server_user'] == expected_secret_value
77-
assert settings['sql_server']['password'] == expected_secret_value
77+
assert settings['SqlServerUser'] == expected_secret_value
78+
assert settings['SqlServer']['Password'] == expected_secret_value
7879

7980
def test_do_not_load_disabled_secrets(self, mocker: MockerFixture) -> None:
8081
class AzureKeyVaultSettings(BaseSettings):

0 commit comments

Comments
 (0)