|
10 | 10 | from abc import ABC, abstractmethod
|
11 | 11 | from argparse import SUPPRESS, ArgumentParser, HelpFormatter, Namespace, _SubParsersAction
|
12 | 12 | from collections import deque
|
13 |
| -from dataclasses import is_dataclass |
| 13 | +from dataclasses import asdict, is_dataclass |
14 | 14 | from enum import Enum
|
15 | 15 | from pathlib import Path
|
16 | 16 | from types import FunctionType
|
@@ -246,6 +246,74 @@ def __call__(self) -> dict[str, Any]:
|
246 | 246 | pass
|
247 | 247 |
|
248 | 248 |
|
| 249 | +class DefaultSettingsSource(PydanticBaseSettingsSource): |
| 250 | + """ |
| 251 | + Source class for loading default values. |
| 252 | + """ |
| 253 | + |
| 254 | + def __init__(self, settings_cls: type[BaseSettings]): |
| 255 | + super().__init__(settings_cls) |
| 256 | + self.defaults = self._get_defaults(settings_cls) |
| 257 | + |
| 258 | + def _get_defaults(self, settings_cls: type[BaseSettings]) -> dict[str, Any]: |
| 259 | + defaults: dict[str, Any] = {} |
| 260 | + if self.config.get('validate_default'): |
| 261 | + fields = ( |
| 262 | + settings_cls.__pydantic_fields__ if is_pydantic_dataclass(settings_cls) else settings_cls.model_fields |
| 263 | + ) |
| 264 | + for field_name, field_info in fields.items(): |
| 265 | + if field_info.validate_default is not False: |
| 266 | + resolved_name = self._get_resolved_name(field_name, field_info) |
| 267 | + if field_info.default not in (PydanticUndefined, None): |
| 268 | + if is_model_class(field_info.annotation): |
| 269 | + defaults[resolved_name] = field_info.default.model_dump() |
| 270 | + elif is_dataclass(field_info.annotation): |
| 271 | + defaults[resolved_name] = asdict(field_info.default) |
| 272 | + else: |
| 273 | + defaults[resolved_name] = field_info.default |
| 274 | + elif field_info.default_factory is not None: |
| 275 | + defaults[resolved_name] = field_info.default_factory |
| 276 | + return defaults |
| 277 | + |
| 278 | + def _get_resolved_name(self, field_name: str, field_info: FieldInfo) -> str: |
| 279 | + if not any((field_info.alias, field_info.validation_alias)): |
| 280 | + return field_name |
| 281 | + |
| 282 | + resolved_names: list[str] = [] |
| 283 | + is_alias_path_only: bool = True |
| 284 | + new_alias_paths: list[AliasPath] = [] |
| 285 | + for alias in (field_info.alias, field_info.validation_alias): |
| 286 | + if alias is None: |
| 287 | + continue |
| 288 | + elif isinstance(alias, str): |
| 289 | + resolved_names.append(alias) |
| 290 | + is_alias_path_only = False |
| 291 | + elif isinstance(alias, AliasChoices): |
| 292 | + for name in alias.choices: |
| 293 | + if isinstance(name, str): |
| 294 | + resolved_names.append(name) |
| 295 | + is_alias_path_only = False |
| 296 | + else: |
| 297 | + new_alias_paths.append(name) |
| 298 | + else: |
| 299 | + new_alias_paths.append(alias) |
| 300 | + for alias_path in new_alias_paths: |
| 301 | + name = cast(str, alias_path.path[0]) |
| 302 | + if not resolved_names and is_alias_path_only: |
| 303 | + resolved_names.append(name) |
| 304 | + return tuple(dict.fromkeys(resolved_names))[0] |
| 305 | + |
| 306 | + def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]: |
| 307 | + # Nothing to do here. Only implement the return statement to make mypy happy |
| 308 | + return None, '', False |
| 309 | + |
| 310 | + def __call__(self) -> dict[str, Any]: |
| 311 | + return self.defaults |
| 312 | + |
| 313 | + def __repr__(self) -> str: |
| 314 | + return f'DefaultSettingsSource(init_kwargs={self.defaults!r})' |
| 315 | + |
| 316 | + |
249 | 317 | class InitSettingsSource(PydanticBaseSettingsSource):
|
250 | 318 | """
|
251 | 319 | Source class for loading values provided during settings class initialization.
|
|
0 commit comments