Skip to content

Commit 0ecac22

Browse files
Reimplement setting source classes (#15)
Co-authored-by: Samuel Colvin <[email protected]>
1 parent 79571e3 commit 0ecac22

File tree

4 files changed

+515
-271
lines changed

4 files changed

+515
-271
lines changed

pydantic_settings/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,23 @@
11
import warnings
22

33
from .main import BaseSettings
4+
from .sources import (
5+
DotEnvSettingsSource,
6+
EnvSettingsSource,
7+
InitSettingsSource,
8+
PydanticBaseSettingsSource,
9+
SecretsSettingsSource,
10+
)
411
from .version import VERSION
512

6-
__all__ = ('BaseSettings',)
13+
__all__ = (
14+
'BaseSettings',
15+
'PydanticBaseSettingsSource',
16+
'InitSettingsSource',
17+
'SecretsSettingsSource',
18+
'EnvSettingsSource',
19+
'DotEnvSettingsSource',
20+
)
721

822
__version__ = VERSION
923
warnings.warn(

pydantic_settings/main.py

Lines changed: 40 additions & 233 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
1-
import os
1+
from __future__ import annotations as _annotations
2+
23
import warnings
3-
from pathlib import Path
4-
from typing import AbstractSet, Any, Callable, ClassVar, Dict, List, Mapping, Optional, Tuple, Type, Union
4+
from typing import AbstractSet, Any, ClassVar, Dict, List, Optional, Tuple, Type, Union
55

66
from pydantic.config import BaseConfig, Extra
77
from pydantic.fields import ModelField
88
from pydantic.main import BaseModel
9-
from pydantic.typing import StrPath, display_as_type, get_origin, is_union
10-
from pydantic.utils import deep_update, path_type, sequence_like
9+
from pydantic.typing import StrPath, display_as_type
10+
from pydantic.utils import deep_update, sequence_like
11+
12+
from .sources import (
13+
DotEnvSettingsSource,
14+
DotenvType,
15+
EnvSettingsSource,
16+
InitSettingsSource,
17+
PydanticBaseSettingsSource,
18+
SecretsSettingsSource,
19+
)
1120

1221
env_file_sentinel = str(object())
1322

14-
SettingsSourceCallable = Callable[['BaseSettings'], Dict[str, Any]]
15-
DotenvType = Union[StrPath, List[StrPath], Tuple[StrPath, ...]]
16-
17-
18-
class SettingsError(ValueError):
19-
pass
20-
2123

2224
class BaseSettings(BaseModel):
2325
"""
@@ -55,8 +57,16 @@ def _build_values(
5557
_secrets_dir: Optional[StrPath] = None,
5658
) -> Dict[str, Any]:
5759
# Configure built-in sources
58-
init_settings = InitSettingsSource(init_kwargs=init_kwargs)
60+
init_settings = InitSettingsSource(self.__class__, init_kwargs=init_kwargs)
5961
env_settings = EnvSettingsSource(
62+
self.__class__,
63+
env_nested_delimiter=(
64+
_env_nested_delimiter if _env_nested_delimiter is not None else self.__config__.env_nested_delimiter
65+
),
66+
env_prefix_len=len(self.__config__.env_prefix),
67+
)
68+
dotenv_settings = DotEnvSettingsSource(
69+
self.__class__,
6070
env_file=(_env_file if _env_file != env_file_sentinel else self.__config__.env_file),
6171
env_file_encoding=(
6272
_env_file_encoding if _env_file_encoding is not None else self.__config__.env_file_encoding
@@ -66,13 +76,20 @@ def _build_values(
6676
),
6777
env_prefix_len=len(self.__config__.env_prefix),
6878
)
69-
file_secret_settings = SecretsSettingsSource(secrets_dir=_secrets_dir or self.__config__.secrets_dir)
79+
80+
file_secret_settings = SecretsSettingsSource(
81+
self.__class__, secrets_dir=_secrets_dir or self.__config__.secrets_dir
82+
)
7083
# Provide a hook to set built-in sources priority and add / remove sources
7184
sources = self.__config__.customise_sources(
72-
init_settings=init_settings, env_settings=env_settings, file_secret_settings=file_secret_settings
85+
self.__class__,
86+
init_settings=init_settings,
87+
env_settings=env_settings,
88+
dotenv_settings=dotenv_settings,
89+
file_secret_settings=file_secret_settings,
7390
)
7491
if sources:
75-
return deep_update(*reversed([source(self) for source in sources]))
92+
return deep_update(*reversed([source() for source in sources]))
7693
else:
7794
# no one should mean to do this, but I think returning an empty dict is marginally preferable
7895
# to an informative error and much better than a confusing error
@@ -120,227 +137,17 @@ def prepare_field(cls, field: ModelField) -> None:
120137
@classmethod
121138
def customise_sources(
122139
cls,
123-
init_settings: SettingsSourceCallable,
124-
env_settings: SettingsSourceCallable,
125-
file_secret_settings: SettingsSourceCallable,
126-
) -> Tuple[SettingsSourceCallable, ...]:
127-
return init_settings, env_settings, file_secret_settings
140+
settings_cls: Type[BaseSettings],
141+
init_settings: PydanticBaseSettingsSource,
142+
env_settings: PydanticBaseSettingsSource,
143+
dotenv_settings: PydanticBaseSettingsSource,
144+
file_secret_settings: PydanticBaseSettingsSource,
145+
) -> Tuple[PydanticBaseSettingsSource, ...]:
146+
return init_settings, env_settings, dotenv_settings, file_secret_settings
128147

129148
@classmethod
130149
def parse_env_var(cls, field_name: str, raw_val: str) -> Any:
131150
return cls.json_loads(raw_val)
132151

133152
# populated by the metaclass using the Config class defined above, annotated here to help IDEs only
134153
__config__: ClassVar[Type[Config]]
135-
136-
137-
class InitSettingsSource:
138-
__slots__ = ('init_kwargs',)
139-
140-
def __init__(self, init_kwargs: Dict[str, Any]):
141-
self.init_kwargs = init_kwargs
142-
143-
def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
144-
return self.init_kwargs
145-
146-
def __repr__(self) -> str:
147-
return f'InitSettingsSource(init_kwargs={self.init_kwargs!r})'
148-
149-
150-
class EnvSettingsSource:
151-
__slots__ = ('env_file', 'env_file_encoding', 'env_nested_delimiter', 'env_prefix_len')
152-
153-
def __init__(
154-
self,
155-
env_file: Optional[DotenvType],
156-
env_file_encoding: Optional[str],
157-
env_nested_delimiter: Optional[str] = None,
158-
env_prefix_len: int = 0,
159-
):
160-
self.env_file: Optional[DotenvType] = env_file
161-
self.env_file_encoding: Optional[str] = env_file_encoding
162-
self.env_nested_delimiter: Optional[str] = env_nested_delimiter
163-
self.env_prefix_len: int = env_prefix_len
164-
165-
def __call__(self, settings: BaseSettings) -> Dict[str, Any]: # noqa C901
166-
"""
167-
Build environment variables suitable for passing to the Model.
168-
"""
169-
d: Dict[str, Any] = {}
170-
171-
if settings.__config__.case_sensitive:
172-
env_vars: Mapping[str, Optional[str]] = os.environ
173-
else:
174-
env_vars = {k.lower(): v for k, v in os.environ.items()}
175-
176-
dotenv_vars = self._read_env_files(settings.__config__.case_sensitive)
177-
if dotenv_vars:
178-
env_vars = {**dotenv_vars, **env_vars}
179-
180-
for field in settings.__fields__.values():
181-
env_val: Optional[str] = None
182-
for env_name in field.field_info.extra['env_names']:
183-
env_val = env_vars.get(env_name)
184-
if env_val is not None:
185-
break
186-
187-
is_complex, allow_parse_failure = self.field_is_complex(field)
188-
if is_complex:
189-
if env_val is None:
190-
# field is complex but no value found so far, try explode_env_vars
191-
env_val_built = self.explode_env_vars(field, env_vars)
192-
if env_val_built:
193-
d[field.alias] = env_val_built
194-
else:
195-
# field is complex and there's a value, decode that as JSON, then add explode_env_vars
196-
try:
197-
env_val = settings.__config__.parse_env_var(field.name, env_val)
198-
except ValueError as e:
199-
if not allow_parse_failure:
200-
raise SettingsError(f'error parsing env var "{env_name}"') from e
201-
202-
if isinstance(env_val, dict):
203-
d[field.alias] = deep_update(env_val, self.explode_env_vars(field, env_vars))
204-
else:
205-
d[field.alias] = env_val
206-
elif env_val is not None:
207-
# simplest case, field is not complex, we only need to add the value if it was found
208-
d[field.alias] = env_val
209-
210-
return d
211-
212-
def _read_env_files(self, case_sensitive: bool) -> Dict[str, Optional[str]]:
213-
env_files = self.env_file
214-
if env_files is None:
215-
return {}
216-
217-
if isinstance(env_files, (str, os.PathLike)):
218-
env_files = [env_files]
219-
220-
dotenv_vars = {}
221-
for env_file in env_files:
222-
env_path = Path(env_file).expanduser()
223-
if env_path.is_file():
224-
dotenv_vars.update(
225-
read_env_file(env_path, encoding=self.env_file_encoding, case_sensitive=case_sensitive)
226-
)
227-
228-
return dotenv_vars
229-
230-
def field_is_complex(self, field: ModelField) -> Tuple[bool, bool]:
231-
"""
232-
Find out if a field is complex, and if so whether JSON errors should be ignored
233-
"""
234-
if field.is_complex():
235-
allow_parse_failure = False
236-
elif is_union(get_origin(field.type_)) and field.sub_fields and any(f.is_complex() for f in field.sub_fields):
237-
allow_parse_failure = True
238-
else:
239-
return False, False
240-
241-
return True, allow_parse_failure
242-
243-
def explode_env_vars(self, field: ModelField, env_vars: Mapping[str, Optional[str]]) -> Dict[str, Any]:
244-
"""
245-
Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.
246-
247-
This is applied to a single field, hence filtering by env_var prefix.
248-
"""
249-
prefixes = [f'{env_name}{self.env_nested_delimiter}' for env_name in field.field_info.extra['env_names']]
250-
result: Dict[str, Any] = {}
251-
for env_name, env_val in env_vars.items():
252-
if not any(env_name.startswith(prefix) for prefix in prefixes):
253-
continue
254-
# we remove the prefix before splitting in case the prefix has characters in common with the delimiter
255-
env_name_without_prefix = env_name[self.env_prefix_len :]
256-
_, *keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter)
257-
env_var = result
258-
for key in keys:
259-
env_var = env_var.setdefault(key, {})
260-
env_var[last_key] = env_val
261-
262-
return result
263-
264-
def __repr__(self) -> str:
265-
return (
266-
f'EnvSettingsSource(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, '
267-
f'env_nested_delimiter={self.env_nested_delimiter!r})'
268-
)
269-
270-
271-
class SecretsSettingsSource:
272-
__slots__ = ('secrets_dir',)
273-
274-
def __init__(self, secrets_dir: Optional[StrPath]):
275-
self.secrets_dir: Optional[StrPath] = secrets_dir
276-
277-
def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
278-
"""
279-
Build fields from "secrets" files.
280-
"""
281-
secrets: Dict[str, Optional[str]] = {}
282-
283-
if self.secrets_dir is None:
284-
return secrets
285-
286-
secrets_path = Path(self.secrets_dir).expanduser()
287-
288-
if not secrets_path.exists():
289-
warnings.warn(f'directory "{secrets_path}" does not exist')
290-
return secrets
291-
292-
if not secrets_path.is_dir():
293-
raise SettingsError(f'secrets_dir must reference a directory, not a {path_type(secrets_path)}')
294-
295-
for field in settings.__fields__.values():
296-
for env_name in field.field_info.extra['env_names']:
297-
path = find_case_path(secrets_path, env_name, settings.__config__.case_sensitive)
298-
if not path:
299-
# path does not exist, we curently don't return a warning for this
300-
continue
301-
302-
if path.is_file():
303-
secret_value = path.read_text().strip()
304-
if field.is_complex():
305-
try:
306-
secret_value = settings.__config__.parse_env_var(field.name, secret_value)
307-
except ValueError as e:
308-
raise SettingsError(f'error parsing env var "{env_name}"') from e
309-
310-
secrets[field.alias] = secret_value
311-
else:
312-
warnings.warn(
313-
f'attempted to load secret file "{path}" but found a {path_type(path)} instead.',
314-
stacklevel=4,
315-
)
316-
return secrets
317-
318-
def __repr__(self) -> str:
319-
return f'SecretsSettingsSource(secrets_dir={self.secrets_dir!r})'
320-
321-
322-
def read_env_file(
323-
file_path: StrPath, *, encoding: str = None, case_sensitive: bool = False
324-
) -> Dict[str, Optional[str]]:
325-
try:
326-
from dotenv import dotenv_values
327-
except ImportError as e:
328-
raise ImportError('python-dotenv is not installed, run `pip install pydantic[dotenv]`') from e
329-
330-
file_vars: Dict[str, Optional[str]] = dotenv_values(file_path, encoding=encoding or 'utf8')
331-
if not case_sensitive:
332-
return {k.lower(): v for k, v in file_vars.items()}
333-
else:
334-
return file_vars
335-
336-
337-
def find_case_path(dir_path: Path, file_name: str, case_sensitive: bool) -> Optional[Path]:
338-
"""
339-
Find a file within path's directory matching filename, optionally ignoring case.
340-
"""
341-
for f in dir_path.iterdir():
342-
if f.name == file_name:
343-
return f
344-
elif not case_sensitive and f.name.lower() == file_name.lower():
345-
return f
346-
return None

0 commit comments

Comments
 (0)