|
1 | | -import os |
2 | | -import logging |
| 1 | +import warnings |
3 | 2 | from pathlib import Path |
4 | | -from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Tuple |
| 3 | +from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional |
5 | 4 |
|
6 | | -from botocore.exceptions import ClientError |
7 | | -from botocore.client import Config |
8 | | -import boto3 |
9 | | - |
10 | | -from pydantic import BaseSettings |
11 | | -from pydantic.typing import StrPath, get_origin, is_union |
12 | | -from pydantic.utils import deep_update |
13 | | -from pydantic.fields import ModelField |
| 5 | +from pydantic_settings.sources import EnvSettingsSource |
14 | 6 |
|
15 | 7 | if TYPE_CHECKING: |
16 | | - from mypy_boto3_ssm.client import SSMClient |
17 | | - |
18 | | - |
19 | | -logger = logging.getLogger(__name__) |
| 8 | + try: |
| 9 | + from mypy_boto3_ssm import SSMClient |
| 10 | + except ImportError: |
| 11 | + ... |
20 | 12 |
|
21 | 13 |
|
22 | | -class SettingsError(ValueError): |
23 | | - pass |
| 14 | +class AwsSsmSettingsSource(EnvSettingsSource): |
| 15 | + DEFAULT_SSM_Path = "/" |
24 | 16 |
|
| 17 | + def __call__(self) -> Dict[str, Any]: |
| 18 | + return super().__call__() |
25 | 19 |
|
26 | | -class AwsSsmSettingsSource: |
27 | | - __slots__ = ("ssm_prefix", "env_nested_delimiter") |
28 | | - |
29 | | - def __init__( |
30 | | - self, |
31 | | - ssm_prefix: Optional[StrPath], |
32 | | - env_nested_delimiter: Optional[str] = None, |
33 | | - ): |
34 | | - self.ssm_prefix: Optional[StrPath] = ssm_prefix |
35 | | - self.env_nested_delimiter: Optional[str] = env_nested_delimiter |
| 20 | + def _get_source_arg(self, name: str) -> Any: |
| 21 | + """ |
| 22 | + Helper to retrieve source arguments from the settings class or the current state. |
| 23 | + """ |
| 24 | + return next( |
| 25 | + ( |
| 26 | + val |
| 27 | + for val in [ |
| 28 | + self.settings_cls.model_config.get(name), |
| 29 | + self.current_state.get(f"_{name}"), |
| 30 | + ] |
| 31 | + if val |
| 32 | + ), |
| 33 | + None, |
| 34 | + ) |
36 | 35 |
|
37 | 36 | @property |
38 | | - def client(self) -> "SSMClient": |
39 | | - return boto3.client("ssm", config=self.client_config) |
| 37 | + def _ssm_client(self) -> "SSMClient": |
| 38 | + client = self._get_source_arg("ssm_client") |
| 39 | + if client is None: |
| 40 | + raise ValueError( |
| 41 | + f"Required configuration 'ssm_client' not set on {self.__class__.__name__}" |
| 42 | + ) |
| 43 | + return client |
40 | 44 |
|
41 | 45 | @property |
42 | | - def client_config(self) -> Config: |
43 | | - timeout = float(os.environ.get("SSM_TIMEOUT", 0.5)) |
44 | | - return Config(connect_timeout=timeout, read_timeout=timeout) |
45 | | - |
46 | | - def load_from_ssm(self, secrets_path: Path, case_sensitive: bool): |
| 46 | + def _ssm_path(self) -> str: |
| 47 | + return self._get_source_arg("ssm_path") or self.DEFAULT_SSM_Path |
47 | 48 |
|
48 | | - if not secrets_path.is_absolute(): |
49 | | - raise ValueError("SSM prefix must be absolute path") |
| 49 | + # def get_field_value( |
| 50 | + # self, field: FieldInfo, field_name: str |
| 51 | + # ) -> Tuple[Any, str, bool]: ... |
50 | 52 |
|
51 | | - logger.debug(f"Building SSM settings with prefix of {secrets_path=}") |
| 53 | + def _load_env_vars(self) -> Mapping[str, Optional[str]]: |
| 54 | + paginator = self._ssm_client.get_paginator("get_parameters_by_path") |
| 55 | + response_iterator = paginator.paginate( |
| 56 | + Path=self._ssm_path, WithDecryption=True, Recursive=True |
| 57 | + ) |
52 | 58 |
|
53 | 59 | output = {} |
54 | 60 | try: |
55 | | - paginator = self.client.get_paginator("get_parameters_by_path") |
56 | | - response_iterator = paginator.paginate( |
57 | | - Path=str(secrets_path), WithDecryption=True |
58 | | - ) |
59 | | - |
60 | 61 | for page in response_iterator: |
61 | 62 | for parameter in page["Parameters"]: |
62 | | - key = Path(parameter["Name"]).relative_to(secrets_path).as_posix() |
63 | | - output[key if case_sensitive else key.lower()] = parameter["Value"] |
| 63 | + name = Path(parameter["Name"]) |
| 64 | + key = name.relative_to(self._ssm_path).as_posix() |
64 | 65 |
|
65 | | - except ClientError: |
66 | | - logger.exception("Failed to get parameters from %s", secrets_path) |
| 66 | + if not self.case_sensitive: |
| 67 | + first_key, *rest = key.split(self.env_nested_delimiter) |
| 68 | + key = self.env_nested_delimiter.join([first_key.lower(), *rest]) |
67 | 69 |
|
68 | | - return output |
69 | | - |
70 | | - def __call__(self, settings: BaseSettings) -> Dict[str, Any]: |
71 | | - """ |
72 | | - Returns SSM values for all settings. |
73 | | - """ |
74 | | - d: Dict[str, Optional[Any]] = {} |
75 | | - |
76 | | - if self.ssm_prefix is None: |
77 | | - return d |
78 | | - |
79 | | - ssm_values = self.load_from_ssm( |
80 | | - secrets_path=Path(self.ssm_prefix), |
81 | | - case_sensitive=settings.__config__.case_sensitive, |
82 | | - ) |
| 70 | + output[key] = parameter["Value"] |
83 | 71 |
|
84 | | - # The following was lifted from https://github.com/samuelcolvin/pydantic/blob/a21f0763ee877f0c86f254a5d60f70b1002faa68/pydantic/env_settings.py#L165-L237 # noqa |
85 | | - for field in settings.__fields__.values(): |
86 | | - env_val: Optional[str] = None |
87 | | - for env_name in field.field_info.extra["env_names"]: |
88 | | - env_val = ssm_values.get(env_name) |
89 | | - if env_val is not None: |
90 | | - break |
91 | | - |
92 | | - is_complex, allow_json_failure = self._field_is_complex(field) |
93 | | - if is_complex: |
94 | | - if env_val is None: |
95 | | - # field is complex but no value found so far, try explode_env_vars |
96 | | - env_val_built = self._explode_ssm_values(field, ssm_values) |
97 | | - if env_val_built: |
98 | | - d[field.alias] = env_val_built |
99 | | - else: |
100 | | - # field is complex and there's a value, decode that as JSON, then |
101 | | - # add explode_env_vars |
102 | | - try: |
103 | | - env_val = settings.__config__.json_loads(env_val) |
104 | | - except ValueError as e: |
105 | | - if not allow_json_failure: |
106 | | - raise SettingsError( |
107 | | - f'error parsing JSON for "{env_name}"' |
108 | | - ) from e |
109 | | - |
110 | | - if isinstance(env_val, dict): |
111 | | - d[field.alias] = deep_update( |
112 | | - env_val, self._explode_ssm_values(field, ssm_values) |
113 | | - ) |
114 | | - else: |
115 | | - d[field.alias] = env_val |
116 | | - elif env_val is not None: |
117 | | - # simplest case, field is not complex, we only need to add the |
118 | | - # value if it was found |
119 | | - d[field.alias] = env_val |
120 | | - |
121 | | - return d |
122 | | - |
123 | | - def _field_is_complex(self, field: ModelField) -> Tuple[bool, bool]: |
124 | | - """ |
125 | | - Find out if a field is complex, and if so whether JSON errors should be ignored |
126 | | - """ |
127 | | - if field.is_complex(): |
128 | | - allow_json_failure = False |
129 | | - elif ( |
130 | | - is_union(get_origin(field.type_)) |
131 | | - and field.sub_fields |
132 | | - and any(f.is_complex() for f in field.sub_fields) |
133 | | - ): |
134 | | - allow_json_failure = True |
135 | | - else: |
136 | | - return False, False |
137 | | - |
138 | | - return True, allow_json_failure |
139 | | - |
140 | | - def _explode_ssm_values( |
141 | | - self, field: ModelField, env_vars: Mapping[str, Optional[str]] |
142 | | - ) -> Dict[str, Any]: |
143 | | - """ |
144 | | - Process env_vars and extract the values of keys containing |
145 | | - env_nested_delimiter into nested dictionaries. |
| 72 | + except self._ssm_client.exceptions.ClientError as e: |
| 73 | + warnings.warn(f"Unable to get parameters from {self._ssm_path!r}: {e}") |
146 | 74 |
|
147 | | - This is applied to a single field, hence filtering by env_var prefix. |
148 | | - """ |
149 | | - prefixes = [ |
150 | | - f"{env_name}{self.env_nested_delimiter}" |
151 | | - for env_name in field.field_info.extra["env_names"] |
152 | | - ] |
153 | | - result: Dict[str, Any] = {} |
154 | | - for env_name, env_val in env_vars.items(): |
155 | | - if not any(env_name.startswith(prefix) for prefix in prefixes): |
156 | | - continue |
157 | | - _, *keys, last_key = env_name.split(self.env_nested_delimiter) |
158 | | - env_var = result |
159 | | - for key in keys: |
160 | | - env_var = env_var.setdefault(key, {}) |
161 | | - env_var[last_key] = env_val |
162 | | - |
163 | | - return result |
| 75 | + return output |
164 | 76 |
|
165 | 77 | def __repr__(self) -> str: |
166 | | - return f"AwsSsmSettingsSource(ssm_prefix={self.ssm_prefix!r})" |
| 78 | + return f"AwsSsmSettingsSource(ssm_path={self._ssm_path!r}, ssm_client={self._ssm_client!r})" |
0 commit comments