Skip to content

Commit 4b436de

Browse files
authored
fix: handle extra SSM parameters (#11)
* bugfix: handle extra SSM parameters Avoid issues of `extra fields not permitted (type=value_error.extra)` when there are non-relevant params stored in SSM. * Fixes * Fix * Flake8 fix * Flake8 fix again * Cleanup docs
1 parent aee7145 commit 4b436de

File tree

2 files changed

+123
-20
lines changed

2 files changed

+123
-20
lines changed

pydantic_ssm_settings/settings.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@ def customise_sources(
2222
file_secret_settings: SecretsSettingsSource,
2323
) -> Tuple[SettingsSourceCallable, ...]:
2424

25+
ssm_settings = AwsSsmSettingsSource(
26+
ssm_prefix=file_secret_settings.secrets_dir,
27+
env_nested_delimiter=env_settings.env_nested_delimiter,
28+
)
29+
2530
return (
2631
init_settings,
2732
env_settings,
@@ -30,5 +35,5 @@ def customise_sources(
3035
# about unexpected arguments. `secrets_dir` comes from `_secrets_dir`,
3136
# one of the few special kwargs that Pydantic will allow:
3237
# https://github.com/samuelcolvin/pydantic/blob/45db4ad3aa558879824a91dd3b011d0449eb2977/pydantic/env_settings.py#L33
33-
AwsSsmSettingsSource(ssm_prefix=file_secret_settings.secrets_dir),
38+
ssm_settings,
3439
)

pydantic_ssm_settings/source.py

Lines changed: 117 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
import os
22
import logging
33
from pathlib import Path
4-
from typing import TYPE_CHECKING, Any, Dict, Optional
4+
from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Tuple
55

66
from botocore.exceptions import ClientError
77
from botocore.client import Config
88
import boto3
99

10-
from pydantic import BaseSettings, typing
10+
from pydantic import BaseSettings
11+
from pydantic.typing import StrPath, get_origin, is_union
12+
from pydantic.utils import deep_update
13+
from pydantic.fields import ModelField
1114

1215
if TYPE_CHECKING:
1316
from mypy_boto3_ssm.client import SSMClient
@@ -16,11 +19,20 @@
1619
logger = logging.getLogger(__name__)
1720

1821

22+
class SettingsError(ValueError):
23+
pass
24+
25+
1926
class AwsSsmSettingsSource:
20-
__slots__ = ("ssm_prefix",)
27+
__slots__ = ("ssm_prefix", "env_nested_delimiter")
2128

22-
def __init__(self, ssm_prefix: Optional[typing.StrPath]):
23-
self.ssm_prefix: Optional[typing.StrPath] = ssm_prefix
29+
def __init__(
30+
self,
31+
ssm_prefix: Optional[StrPath],
32+
env_nested_delimiter: Optional[str] = None,
33+
):
34+
self.ssm_prefix: Optional[StrPath] = ssm_prefix
35+
self.env_nested_delimiter: Optional[str] = env_nested_delimiter
2436

2537
@property
2638
def client(self) -> "SSMClient":
@@ -31,38 +43,124 @@ def client_config(self) -> Config:
3143
timeout = float(os.environ.get("SSM_TIMEOUT", 0.5))
3244
return Config(connect_timeout=timeout, read_timeout=timeout)
3345

34-
def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
35-
"""
36-
Returns lazy SSM values for all settings.
37-
"""
38-
secrets: Dict[str, Optional[Any]] = {}
39-
40-
if self.ssm_prefix is None:
41-
return secrets
42-
43-
secrets_path = Path(self.ssm_prefix)
46+
def load_from_ssm(self, secrets_path: Path, case_sensitive: bool):
4447

4548
if not secrets_path.is_absolute():
4649
raise ValueError("SSM prefix must be absolute path")
4750

4851
logger.debug(f"Building SSM settings with prefix of {secrets_path=}")
4952

53+
output = {}
5054
try:
5155
paginator = self.client.get_paginator("get_parameters_by_path")
5256
response_iterator = paginator.paginate(
5357
Path=str(secrets_path), WithDecryption=True
5458
)
5559

56-
output = {}
5760
for page in response_iterator:
5861
for parameter in page["Parameters"]:
5962
key = Path(parameter["Name"]).relative_to(secrets_path).as_posix()
60-
output[key] = parameter["Value"]
61-
return output
63+
output[key if case_sensitive else key.lower()] = parameter["Value"]
6264

6365
except ClientError:
6466
logger.exception("Failed to get parameters from %s", secrets_path)
65-
return {}
67+
68+
return output
69+
70+
def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
71+
"""
72+
Returns SSM values for all settings.
73+
"""
74+
d: Dict[str, Optional[Any]] = {}
75+
76+
if self.ssm_prefix is None:
77+
return d
78+
79+
ssm_values = self.load_from_ssm(
80+
secrets_path=Path(self.ssm_prefix),
81+
case_sensitive=settings.__config__.case_sensitive,
82+
)
83+
84+
# The following was lifted from https://github.com/samuelcolvin/pydantic/blob/a21f0763ee877f0c86f254a5d60f70b1002faa68/pydantic/env_settings.py#L165-L237 # noqa
85+
for field in settings.__fields__.values():
86+
env_val: Optional[str] = None
87+
for env_name in field.field_info.extra["env_names"]:
88+
env_val = ssm_values.get(env_name)
89+
if env_val is not None:
90+
break
91+
92+
is_complex, allow_json_failure = self._field_is_complex(field)
93+
if is_complex:
94+
if env_val is None:
95+
# field is complex but no value found so far, try explode_env_vars
96+
env_val_built = self._explode_ssm_values(field, ssm_values)
97+
if env_val_built:
98+
d[field.alias] = env_val_built
99+
else:
100+
# field is complex and there's a value, decode that as JSON, then
101+
# add explode_env_vars
102+
try:
103+
env_val = settings.__config__.json_loads(env_val)
104+
except ValueError as e:
105+
if not allow_json_failure:
106+
raise SettingsError(
107+
f'error parsing JSON for "{env_name}"'
108+
) from e
109+
110+
if isinstance(env_val, dict):
111+
d[field.alias] = deep_update(
112+
env_val, self._explode_ssm_values(field, ssm_values)
113+
)
114+
else:
115+
d[field.alias] = env_val
116+
elif env_val is not None:
117+
# simplest case, field is not complex, we only need to add the
118+
# value if it was found
119+
d[field.alias] = env_val
120+
121+
return d
122+
123+
def _field_is_complex(self, field: ModelField) -> Tuple[bool, bool]:
124+
"""
125+
Find out if a field is complex, and if so whether JSON errors should be ignored
126+
"""
127+
if field.is_complex():
128+
allow_json_failure = False
129+
elif (
130+
is_union(get_origin(field.type_))
131+
and field.sub_fields
132+
and any(f.is_complex() for f in field.sub_fields)
133+
):
134+
allow_json_failure = True
135+
else:
136+
return False, False
137+
138+
return True, allow_json_failure
139+
140+
def _explode_ssm_values(
141+
self, field: ModelField, env_vars: Mapping[str, Optional[str]]
142+
) -> Dict[str, Any]:
143+
"""
144+
Process env_vars and extract the values of keys containing
145+
env_nested_delimiter into nested dictionaries.
146+
147+
This is applied to a single field, hence filtering by env_var prefix.
148+
"""
149+
prefixes = [
150+
f"{env_name}{self.env_nested_delimiter}"
151+
for env_name in field.field_info.extra["env_names"]
152+
]
153+
result: Dict[str, Any] = {}
154+
for env_name, env_val in env_vars.items():
155+
if not any(env_name.startswith(prefix) for prefix in prefixes):
156+
continue
157+
_, *keys, last_key = env_name.split(self.env_nested_delimiter)
158+
env_var = result
159+
for key in keys:
160+
env_var = env_var.setdefault(key, {})
161+
env_var[last_key] = env_val
162+
163+
return result
66164

67165
def __repr__(self) -> str:
68166
return f"AwsSsmSettingsSource(ssm_prefix={self.ssm_prefix!r})"

0 commit comments

Comments
 (0)