@@ -42,30 +42,41 @@ class AzureKeyVaultMapping(Mapping[str, Optional[str]]):
4242 _secret_client : SecretClient
4343 _secret_names : list [str ]
4444
45- def __init__ (
46- self ,
47- secret_client : SecretClient ,
48- ) -> None :
45+ def __init__ (self , secret_client : SecretClient , case_sensitive : bool , snake_case_conversion : bool ) -> None :
4946 self ._loaded_secrets = {}
5047 self ._secret_client = secret_client
48+ self ._case_sensitive = case_sensitive
49+ self ._snake_case_conversion = snake_case_conversion
5150 self ._secret_map : dict [str , str ] = self ._load_remote ()
5251
5352 def _load_remote (self ) -> dict [str , str ]:
5453 secret_names : Iterator [str ] = (
5554 secret .name for secret in self ._secret_client .list_properties_of_secrets () if secret .name and secret .enabled
5655 )
57- return {to_snake (name ): name for name in secret_names }
56+
57+ if self ._snake_case_conversion :
58+ return {to_snake (name ): name for name in secret_names }
59+
60+ if self ._case_sensitive :
61+ return {name : name for name in secret_names }
62+
63+ return {name .lower (): name for name in secret_names }
5864
5965 def __getitem__ (self , key : str ) -> str | None :
60- key_snake = to_snake (key )
66+ new_key = key
67+
68+ if self ._snake_case_conversion :
69+ new_key = to_snake (key )
70+ elif not self ._case_sensitive :
71+ new_key = key .lower ()
6172
62- if key_snake not in self ._loaded_secrets :
63- if key_snake in self ._secret_map :
64- self ._loaded_secrets [key_snake ] = self ._secret_client .get_secret (self ._secret_map [key_snake ]).value
73+ if new_key not in self ._loaded_secrets :
74+ if new_key in self ._secret_map :
75+ self ._loaded_secrets [new_key ] = self ._secret_client .get_secret (self ._secret_map [new_key ]).value
6576 else :
6677 raise KeyError (key )
6778
68- return self ._loaded_secrets [key_snake ]
79+ return self ._loaded_secrets [new_key ]
6980
7081 def __len__ (self ) -> int :
7182 return len (self ._secret_map )
@@ -83,29 +94,44 @@ def __init__(
8394 settings_cls : type [BaseSettings ],
8495 url : str ,
8596 credential : TokenCredential ,
97+ dash_to_underscore : bool = False ,
98+ case_sensitive : bool | None = None ,
99+ snake_case_conversion : bool = False ,
86100 env_prefix : str | None = None ,
87101 env_parse_none_str : str | None = None ,
88102 env_parse_enums : bool | None = None ,
89103 ) -> None :
90104 import_azure_key_vault ()
91105 self ._url = url
92106 self ._credential = credential
107+ self ._dash_to_underscore = dash_to_underscore
108+ self ._snake_case_conversion = snake_case_conversion
93109 super ().__init__ (
94110 settings_cls ,
95- case_sensitive = False ,
111+ case_sensitive = False if snake_case_conversion else case_sensitive ,
96112 env_prefix = env_prefix ,
97- env_nested_delimiter = '__' ,
113+ env_nested_delimiter = '__' if snake_case_conversion else '--' ,
98114 env_ignore_empty = False ,
99115 env_parse_none_str = env_parse_none_str ,
100116 env_parse_enums = env_parse_enums ,
101117 )
102118
103119 def _load_env_vars (self ) -> Mapping [str , Optional [str ]]:
104120 secret_client = SecretClient (vault_url = self ._url , credential = self ._credential )
105- return AzureKeyVaultMapping (secret_client )
121+ return AzureKeyVaultMapping (
122+ secret_client = secret_client ,
123+ case_sensitive = self .case_sensitive ,
124+ snake_case_conversion = self ._snake_case_conversion ,
125+ )
106126
107127 def _extract_field_info (self , field : FieldInfo , field_name : str ) -> list [tuple [str , str , bool ]]:
108- return list ((x [0 ], x [0 ], x [2 ]) for x in super ()._extract_field_info (field , field_name ))
128+ if self ._snake_case_conversion :
129+ return list ((x [0 ], x [0 ], x [2 ]) for x in super ()._extract_field_info (field , field_name ))
130+
131+ if self ._dash_to_underscore :
132+ return list ((x [0 ], x [1 ].replace ('_' , '-' ), x [2 ]) for x in super ()._extract_field_info (field , field_name ))
133+
134+ return super ()._extract_field_info (field , field_name )
109135
110136 def __repr__ (self ) -> str :
111137 return f'{ self .__class__ .__name__ } (url={ self ._url !r} , env_nested_delimiter={ self .env_nested_delimiter !r} )'
0 commit comments