11import os
22import logging
33from pathlib import Path
4- from typing import TYPE_CHECKING , Any , Dict , Optional
4+ from typing import TYPE_CHECKING , Any , Dict , Mapping , Optional , Tuple
55
66from botocore .exceptions import ClientError
77from botocore .client import Config
88import boto3
99
10- from pydantic import BaseSettings , typing
10+ from pydantic import BaseSettings
11+ from pydantic .typing import StrPath , get_origin , is_union
12+ from pydantic .utils import deep_update
13+ from pydantic .fields import ModelField
1114
1215if TYPE_CHECKING :
1316 from mypy_boto3_ssm .client import SSMClient
1619logger = logging .getLogger (__name__ )
1720
1821
22+ class SettingsError (ValueError ):
23+ pass
24+
25+
1926class AwsSsmSettingsSource :
20- __slots__ = ("ssm_prefix" ,)
27+ __slots__ = ("ssm_prefix" , "env_nested_delimiter" )
2128
22- def __init__ (self , ssm_prefix : Optional [typing .StrPath ]):
23- self .ssm_prefix : Optional [typing .StrPath ] = ssm_prefix
29+ def __init__ (
30+ self ,
31+ ssm_prefix : Optional [StrPath ],
32+ env_nested_delimiter : Optional [str ] = None ,
33+ ):
34+ self .ssm_prefix : Optional [StrPath ] = ssm_prefix
35+ self .env_nested_delimiter : Optional [str ] = env_nested_delimiter
2436
2537 @property
2638 def client (self ) -> "SSMClient" :
@@ -31,38 +43,124 @@ def client_config(self) -> Config:
3143 timeout = float (os .environ .get ("SSM_TIMEOUT" , 0.5 ))
3244 return Config (connect_timeout = timeout , read_timeout = timeout )
3345
34- def __call__ (self , settings : BaseSettings ) -> Dict [str , Any ]:
35- """
36- Returns lazy SSM values for all settings.
37- """
38- secrets : Dict [str , Optional [Any ]] = {}
39-
40- if self .ssm_prefix is None :
41- return secrets
42-
43- secrets_path = Path (self .ssm_prefix )
46+ def load_from_ssm (self , secrets_path : Path , case_sensitive : bool ):
4447
4548 if not secrets_path .is_absolute ():
4649 raise ValueError ("SSM prefix must be absolute path" )
4750
4851 logger .debug (f"Building SSM settings with prefix of { secrets_path = } " )
4952
53+ output = {}
5054 try :
5155 paginator = self .client .get_paginator ("get_parameters_by_path" )
5256 response_iterator = paginator .paginate (
5357 Path = str (secrets_path ), WithDecryption = True
5458 )
5559
56- output = {}
5760 for page in response_iterator :
5861 for parameter in page ["Parameters" ]:
5962 key = Path (parameter ["Name" ]).relative_to (secrets_path ).as_posix ()
60- output [key ] = parameter ["Value" ]
61- return output
63+ output [key if case_sensitive else key .lower ()] = parameter ["Value" ]
6264
6365 except ClientError :
6466 logger .exception ("Failed to get parameters from %s" , secrets_path )
65- return {}
67+
68+ return output
69+
70+ def __call__ (self , settings : BaseSettings ) -> Dict [str , Any ]:
71+ """
72+ Returns SSM values for all settings.
73+ """
74+ d : Dict [str , Optional [Any ]] = {}
75+
76+ if self .ssm_prefix is None :
77+ return d
78+
79+ ssm_values = self .load_from_ssm (
80+ secrets_path = Path (self .ssm_prefix ),
81+ case_sensitive = settings .__config__ .case_sensitive ,
82+ )
83+
84+ # The following was lifted from https://github.com/samuelcolvin/pydantic/blob/a21f0763ee877f0c86f254a5d60f70b1002faa68/pydantic/env_settings.py#L165-L237 # noqa
85+ for field in settings .__fields__ .values ():
86+ env_val : Optional [str ] = None
87+ for env_name in field .field_info .extra ["env_names" ]:
88+ env_val = ssm_values .get (env_name )
89+ if env_val is not None :
90+ break
91+
92+ is_complex , allow_json_failure = self ._field_is_complex (field )
93+ if is_complex :
94+ if env_val is None :
95+ # field is complex but no value found so far, try explode_env_vars
96+ env_val_built = self ._explode_ssm_values (field , ssm_values )
97+ if env_val_built :
98+ d [field .alias ] = env_val_built
99+ else :
100+ # field is complex and there's a value, decode that as JSON, then
101+ # add explode_env_vars
102+ try :
103+ env_val = settings .__config__ .json_loads (env_val )
104+ except ValueError as e :
105+ if not allow_json_failure :
106+ raise SettingsError (
107+ f'error parsing JSON for "{ env_name } "'
108+ ) from e
109+
110+ if isinstance (env_val , dict ):
111+ d [field .alias ] = deep_update (
112+ env_val , self ._explode_ssm_values (field , ssm_values )
113+ )
114+ else :
115+ d [field .alias ] = env_val
116+ elif env_val is not None :
117+ # simplest case, field is not complex, we only need to add the
118+ # value if it was found
119+ d [field .alias ] = env_val
120+
121+ return d
122+
123+ def _field_is_complex (self , field : ModelField ) -> Tuple [bool , bool ]:
124+ """
125+ Find out if a field is complex, and if so whether JSON errors should be ignored
126+ """
127+ if field .is_complex ():
128+ allow_json_failure = False
129+ elif (
130+ is_union (get_origin (field .type_ ))
131+ and field .sub_fields
132+ and any (f .is_complex () for f in field .sub_fields )
133+ ):
134+ allow_json_failure = True
135+ else :
136+ return False , False
137+
138+ return True , allow_json_failure
139+
140+ def _explode_ssm_values (
141+ self , field : ModelField , env_vars : Mapping [str , Optional [str ]]
142+ ) -> Dict [str , Any ]:
143+ """
144+ Process env_vars and extract the values of keys containing
145+ env_nested_delimiter into nested dictionaries.
146+
147+ This is applied to a single field, hence filtering by env_var prefix.
148+ """
149+ prefixes = [
150+ f"{ env_name } { self .env_nested_delimiter } "
151+ for env_name in field .field_info .extra ["env_names" ]
152+ ]
153+ result : Dict [str , Any ] = {}
154+ for env_name , env_val in env_vars .items ():
155+ if not any (env_name .startswith (prefix ) for prefix in prefixes ):
156+ continue
157+ _ , * keys , last_key = env_name .split (self .env_nested_delimiter )
158+ env_var = result
159+ for key in keys :
160+ env_var = env_var .setdefault (key , {})
161+ env_var [last_key ] = env_val
162+
163+ return result
66164
67165 def __repr__ (self ) -> str :
68166 return f"AwsSsmSettingsSource(ssm_prefix={ self .ssm_prefix !r} )"
0 commit comments